aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--WORKSPACE20
-rw-r--r--configure.py12
-rw-r--r--tensorflow/BUILD10
-rw-r--r--tensorflow/api_template.__init__.py15
-rw-r--r--tensorflow/c/eager/BUILD5
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc19
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc6
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc13
-rw-r--r--tensorflow/compiler/xla/rpc/BUILD13
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service_main.cc21
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc98
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.h32
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc25
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc144
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h123
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc10
-rw-r--r--tensorflow/compiler/xla/tests/BUILD18
-rw-r--r--tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc120
-rw-r--r--tensorflow/compiler/xla/xla.proto9
-rw-r--r--tensorflow/contrib/BUILD15
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc23
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py86
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt1
-rw-r--r--tensorflow/contrib/cmake/python_protos.txt1
-rw-r--r--tensorflow/contrib/compiler/BUILD6
-rw-r--r--tensorflow/contrib/compiler/xla.py293
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc650
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc76
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD16
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py57
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py160
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py10
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py78
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py20
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py79
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops.py8
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h3
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md17
-rw-r--r--tensorflow/contrib/lite/g3doc/overview.md2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h41
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc52
-rw-r--r--tensorflow/contrib/lite/kernels/reduce_test.cc12
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py5
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc6
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.h2
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py10
-rw-r--r--tensorflow/contrib/tensorboard/BUILD31
-rw-r--r--tensorflow/contrib/tensorboard/plugins/__init__.py2
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/__init__.py24
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace.py167
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto60
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace_test.py95
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model.py65
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py157
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py35
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc2
-rw-r--r--tensorflow/core/BUILD19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt49
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt30
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt17
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc7
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc3
-rw-r--r--tensorflow/core/common_runtime/executor.cc15
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc1
-rw-r--r--tensorflow/core/grappler/graph_view.cc30
-rw-r--r--tensorflow/core/grappler/graph_view.h10
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc54
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD39
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h3
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc43
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc226
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h62
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc162
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc9
-rw-r--r--tensorflow/core/kernels/BUILD14
-rw-r--r--tensorflow/core/kernels/data/BUILD15
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc37
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h10
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc37
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc633
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc8
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h367
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h268
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.cc197
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.h58
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc38
-rw-r--r--tensorflow/core/lib/io/record_reader.cc53
-rw-r--r--tensorflow/core/lib/io/record_reader.h25
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc7
-rw-r--r--tensorflow/core/ops/array_ops.cc110
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt181
-rw-r--r--tensorflow/core/ops/dataset_ops.cc37
-rw-r--r--tensorflow/core/ops/ops.pbtxt181
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto2
-rw-r--r--tensorflow/core/util/mkl_util.h5
-rw-r--r--tensorflow/go/op/wrappers.go406
-rw-r--r--tensorflow/python/autograph/pyct/compiler.py2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info.py2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info_test.py59
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD21
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py190
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py176
-rw-r--r--tensorflow/python/data/ops/BUILD19
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py21
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py13
-rw-r--r--tensorflow/python/data/ops/multi_device_iterator_ops.py213
-rw-r--r--tensorflow/python/data/ops/optional_ops.py150
-rw-r--r--tensorflow/python/data/util/structure.py131
-rw-r--r--tensorflow/python/data/util/structure_test.py36
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py3
-rw-r--r--tensorflow/python/debug/lib/debug_graph_reconstruction_test.py3
-rw-r--r--tensorflow/python/eager/BUILD33
-rw-r--r--tensorflow/python/eager/def_function.py235
-rw-r--r--tensorflow/python/eager/def_function_test.py87
-rw-r--r--tensorflow/python/eager/function.py86
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py8
-rw-r--r--tensorflow/python/framework/function.py17
-rw-r--r--tensorflow/python/framework/function_test.py27
-rw-r--r--tensorflow/python/framework/test_util.py2
-rwxr-xr-xtensorflow/python/keras/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/BUILD12
-rw-r--r--tensorflow/python/kernel_tests/extract_volume_patches_op_test.py131
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py34
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py39
-rw-r--r--tensorflow/python/ops/distributions/util.py4
-rw-r--r--tensorflow/python/ops/embedding_ops.py8
-rw-r--r--tensorflow/python/ops/image_ops_impl.py8
-rw-r--r--tensorflow/python/ops/image_ops_test.py6
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py42
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api.py1
-rw-r--r--tensorflow/python/training/monitored_session.py24
-rw-r--r--tensorflow/python/training/quantize_training.i7
-rw-r--r--tensorflow/python/util/nest.py22
-rw-r--r--tensorflow/python/util/nest_test.py6
-rw-r--r--tensorflow/requirements.txt2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc57
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt4
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rw-r--r--tensorflow/tools/test/check_futures_test.py3
-rwxr-xr-xtensorflow/workspace.bzl16
-rw-r--r--third_party/repo.bzl15
-rw-r--r--third_party/systemlibs/absl_py.BUILD1
-rw-r--r--third_party/systemlibs/absl_py.absl.flags.BUILD11
-rw-r--r--third_party/systemlibs/absl_py.absl.testing.BUILD7
-rw-r--r--third_party/systemlibs/boringssl.BUILD21
-rw-r--r--third_party/systemlibs/double_conversion.BUILD12
-rw-r--r--third_party/systemlibs/gast.BUILD12
-rw-r--r--third_party/systemlibs/google_cloud_cpp.BUILD6
-rw-r--r--third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD7
-rw-r--r--third_party/systemlibs/googleapis.BUILD12
-rw-r--r--third_party/systemlibs/jsoncpp.BUILD2
-rw-r--r--third_party/systemlibs/syslibs_configure.bzl6
-rw-r--r--tools/bazel.rc5
170 files changed, 6322 insertions, 2562 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 17961829a6..11605871f3 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -9,11 +9,27 @@ http_archive(
"https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13
],
)
-
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-
closure_repositories()
+http_archive(
+ name = "io_bazel_rules_python",
+ strip_prefix = "rules_python-8b5d0683a7d878b28fffe464779c8a53659fc645",
+ urls = [
+ "https://github.com/bazelbuild/rules_python/archive/8b5d0683a7d878b28fffe464779c8a53659fc645.tar.gz",
+ ],
+)
+load("@io_bazel_rules_python//python:pip.bzl", "pip_repositories")
+pip_repositories()
+
+load("@io_bazel_rules_python//python:pip.bzl", "pip_import")
+pip_import(
+ name = "pip_deps",
+ requirements = "//tensorflow:requirements.txt",
+)
+load("@pip_deps//:requirements.bzl", "pip_install")
+pip_install()
+
# We must check the bazel version before trying to parse any other BUILD
# files, in case the parsing of those build files depends on the bazel
# version we require here.
diff --git a/configure.py b/configure.py
index e9d162fbd2..f0b9fada5e 100644
--- a/configure.py
+++ b/configure.py
@@ -1401,10 +1401,20 @@ def set_grpc_build_flags():
def set_system_libs_flag(environ_cp):
syslibs = environ_cp.get('TF_SYSTEM_LIBS', '')
- syslibs = ','.join(sorted(syslibs.split(',')))
if syslibs and syslibs != '':
+ if ',' in syslibs:
+ syslibs = ','.join(sorted(syslibs.split(',')))
+ else:
+ syslibs = ','.join(sorted(syslibs.split()))
write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs)
+ if 'PREFIX' in environ_cp:
+ write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX'])
+ if 'LIBDIR' in environ_cp:
+ write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR'])
+ if 'INCLUDEDIR' in environ_cp:
+ write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR'])
+
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index c8e24e3aff..3610eea42a 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -564,6 +564,7 @@ tf_cc_shared_object(
"$(location //tensorflow/c:version_script.lds)",
],
}),
+ visibility = ["//visibility:public"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@@ -588,6 +589,7 @@ tf_cc_shared_object(
"$(location //tensorflow:tf_version_script.lds)",
],
}),
+ visibility = ["//visibility:public"],
deps = [
"//tensorflow:tf_exported_symbols.lds",
"//tensorflow:tf_version_script.lds",
@@ -628,6 +630,14 @@ genrule(
continue
fi
+ if [[ $${d} == external* ]]; then
+ extname="$${d#*external/}"
+ extname="$${extname%%/*}"
+ if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then
+ continue
+ fi
+ fi
+
mkdir -p "$@/$${d}"
cp "$${f}" "$@/$${d}/"
done
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 53a72b8443..2de740e145 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -14,9 +14,9 @@
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import as _absolute_import
+from __future__ import division as _division
+from __future__ import print_function as _print_function
import os as _os
@@ -41,6 +41,11 @@ except (ImportError, AttributeError):
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
+# The templated code that replaces the placeholder above sometimes
+# sets the __all__ variable. If it does, we have to be sure to add
+# "contrib".
+if '__all__' in vars():
+ vars()['__all__'].append('contrib')
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
@@ -51,10 +56,6 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disabl
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
-del absolute_import
-del division
-del print_function
-
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 37be52f57d..3ee31a6a7a 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -68,7 +68,10 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":c_api",
"//tensorflow/c:c_api",
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 900e84ab58..e219cf3d88 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -560,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
+ if dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
+ self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
+
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(
gen_math_ops.floor_div,
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 0d9a768a6f..66676452d0 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
+// Implementation of DivNoNan. Pseudo-code:
+// if (y == 0) {
+// return 0
+// } else {
+// return x / y;
+// }
+static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto y_equals_0 = xla::Eq(y, zero);
+ auto zeros = xla::ZerosLike(x);
+ auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
+ return result;
+}
+XLA_MAKE_BINARY(DivNoNan,
+ DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 739e47778a..d5094e8ec5 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -333,10 +333,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
}
// Builds the XLA computation.
-//
-// `retvals` is the list of retvals produced by _Retval operators, in index
-// order. `variable_map` is a map from variable ID numbers to XlaOpContext
-// variable states, generated by the symbolic evaluation.
+// `args` is the list of input arguments, `retvals` is the list of retvals
+// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 0d3136b0cc..3ed3afcfce 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// regression.
flags->set_xla_cpu_enable_fast_math(true);
flags->set_xla_gpu_enable_fast_math(true);
+
+ flags->set_xla_force_host_platform_device_count(1);
}
// Allocates flag_values and flag_objects; this function must not be called more
@@ -323,6 +325,17 @@ void AllocateFlags() {
flag_values->xla_gpu_crash_on_verification_failures(),
"Crashes the program on extra verification failures, e.g. cuDNN "
"cross checking failures"),
+ tensorflow::Flag(
+ "xla_force_host_platform_device_count",
+ int32_setter_for(
+ &DebugOptions::set_xla_force_host_platform_device_count),
+ flag_values->xla_force_host_platform_device_count(),
+ "Force the host platform to pretend that there are these many "
+ "host \"devices\". All of these host devices are backed by the same"
+ "threadpool. Setting this to anything other than 1 can increase "
+ "overhead from context switching but we let the user override this "
+ "behavior to help run tests on the host that run models in parallel "
+ "across multiple devices."),
});
ParseFlagsFromEnv(*flag_objects);
}
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 97fcd37f6b..3abb3855a4 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -34,19 +34,28 @@ cc_library(
],
)
-tf_cc_binary(
- name = "grpc_service_main_cpu",
+cc_library(
+ name = "grpc_service_main_library",
srcs = ["grpc_service_main.cc"],
deps = [
":grpc_service",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings:str_format",
],
)
+tf_cc_binary(
+ name = "grpc_service_main_cpu",
+ deps = [
+ ":grpc_service_main_library",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+)
+
tf_cc_test(
name = "grpc_client_test",
srcs = ["grpc_client_test.cc"],
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index d6b5149a24..522ab99fb1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "grpcpp/server_builder.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -29,8 +30,15 @@ namespace {
int RealMain(int argc, char** argv) {
int32 port = 1685;
+ bool any_address = false;
+ string platform_str;
std::vector<tensorflow::Flag> flag_list = {
- tensorflow::Flag("port", &port, "port to listen on"),
+ tensorflow::Flag("platform", &platform_str,
+ "The XLA platform this service should be bound to"),
+ tensorflow::Flag("port", &port, "The TCP port to listen on"),
+ tensorflow::Flag(
+ "any", &any_address,
+ "Whether to listen to any host address or simply localhost"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ se::Platform* platform = nullptr;
+ if (!platform_str.empty()) {
+ platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
+ }
std::unique_ptr<xla::GRPCService> service =
- xla::GRPCService::NewService().ConsumeValueOrDie();
+ xla::GRPCService::NewService(platform).ConsumeValueOrDie();
::grpc::ServerBuilder builder;
- string server_address(absl::StrFormat("localhost:%d", port));
+ string server_address(
+ absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port));
+ builder.SetMaxReceiveMessageSize(INT_MAX);
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
std::unique_ptr<::grpc::Server> server(builder.BuildAndStart());
LOG(INFO) << "Server listening on " << server_address;
server->Wait();
-
return 0;
}
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2bc50c70cf..e800cf470c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -593,6 +593,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index bf627986a5..b7103118ac 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -462,12 +463,15 @@ cc_library(
],
copts = runtime_copts(),
deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 7e1590955a..20cf855735 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -17,19 +17,29 @@ limitations under the License.
#include <functional>
+#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace cpu {
namespace runtime {
-XfeedManager* GetXfeedManager() {
- static XfeedManager* manager = new XfeedManager;
- return manager;
+XfeedManager* GetXfeedManager(int device_ordinal) {
+ static tensorflow::gtl::FlatMap<int, XfeedManager*>* managers =
+ new tensorflow::gtl::FlatMap<int, XfeedManager*>();
+ static absl::Mutex* mutex = new absl::Mutex();
+
+ absl::MutexLock lock(mutex);
+ auto it = managers->find(device_ordinal);
+ if (it == managers->end()) {
+ it = managers->emplace(device_ordinal, new XfeedManager()).first;
+ }
+ return it->second;
}
extern const char* const kEigenMatMulF16SymbolName =
@@ -118,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
- const void* shape,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireInfeedBufferForDequeue: "
- << ShapeString(shape, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireInfeedBufferForDequeue: "
+ << ShapeString(shape, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->infeed()->BlockingDequeueBuffer();
@@ -138,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseInfeedBufferAfterDeque: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
@@ -154,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireOutfeedBufferForPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->outfeed()->BlockingDequeueBuffer();
@@ -174,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index e6345e0344..b2e760a224 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -26,6 +26,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
+#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/types.h"
@@ -80,8 +81,9 @@ extern const char* const kKeyValueSortF64SymbolName;
// prefix.
extern const char* const kXlaCpuRuntimeSymbolNamePrefix;
-// Returns the infeed manager used by the CPU runtime.
-XfeedManager* GetXfeedManager();
+// Returns the infeed manager used by the CPU runtime for the CPU device
+// `device_ordinal`. Note the device ordinal does not name a CPU
+XfeedManager* GetXfeedManager(int device_ordinal);
} // namespace runtime
} // namespace cpu
@@ -89,6 +91,18 @@ XfeedManager* GetXfeedManager();
extern "C" {
+// Some things common to all of the runtime entry points below:
+//
+// * The shape pointer and shape_length reflect values that can be deserialized
+// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass
+// reified type information from the generated program to the runtime, which
+// helps check the type safety and contract for the emitted-code/runtime
+// communication.
+//
+// * run_options is used to look up the device ordinal for the stream executor
+// we're executing under. If it is null the device ordinal is assumed to be
+// 0 (this behavior helps in writing tests).
+
// Note: in the runtime entry points below, the shape pointer and shape_length
// reflect values that can be deserialized via
// llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified
@@ -101,7 +115,8 @@ extern "C" {
// the length would be more exact, but the length check is chosen as a
// tradeoff between error checking and speed/simplicity.
extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- xla::int32 buffer_length, const void* shape, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length);
// Relinquishes the next infeed buffer that was returned by
// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call
@@ -116,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
// implemented we will add support for multiple outstanding buffers
// that can be returned out of order.
extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
// Blocks until the next outfeed buffer is available to be populated, then
// returns it.
extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length);
// Relinquishes the outfeed buffer after it has been populated.
// buffer_ptr must have been previously returned by
@@ -134,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
// acquired, i.e., there may only be one outstanding outfeed buffer in
// use by the runtime.
extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
} // extern "C"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 5519a43b2f..1cc2844470 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
@@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed(
buffers.push_back(buffer);
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers);
cleanup.release();
@@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
TransferBufferToInfeedInternal(executor, size, source));
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer});
return Status::OK();
@@ -265,7 +268,8 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
buffer_pointers.push_back(b.get());
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers);
VLOG(2) << "Waiting for buffer to be notified as populated.";
std::vector<Shape> outfed_shapes;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index c32f2533ee..c3e8020783 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -404,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value * shape_ptr,
llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
- // The signature of the acquire infeed buffer function is:
- //
- // (void*)(int32 length);
llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::FunctionType* acquire_type = llvm::FunctionType::get(
- i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
+ i8_ptr_type,
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* acquire_func;
@@ -423,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
}
acquire_func->setCallingConv(llvm::CallingConv::C);
- // The signature of the release infeed buffer function is:
- //
- // (void)(int32 length, void* buffer);
llvm::FunctionType* release_type = llvm::FunctionType::get(
- b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
+ b_.getVoidTy(),
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
+ /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* release_func;
@@ -444,9 +443,9 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer =
- Call(acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer = Call(
+ acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
@@ -458,8 +457,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
/*SrcAlign=*/1, length_32);
}
- Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
- b_.getInt32(shape_length)});
+ Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ acquired_pointer, shape_ptr, b_.getInt32(shape_length)});
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
index 8fe65f488a..cc38b81455 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
@@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) {
auto shape = ShapeUtil::MakeShape(U8, {length});
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- length, bytes.data(), bytes.size());
- __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer,
- bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
// Performs the acquire/release sequence on the outfeed, as the generated CPU
@@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) {
void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) {
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- length, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- length, buffer, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
TEST_F(InfeedManagerTest, SingleThreadedSequential) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
xfeed->infeed()->EnqueueBuffersAtomically({b});
@@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
ProcessNextBuffer(a->length());
@@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TEST_F(InfeedManagerTest, MultiThreaded) {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
const int32 length = 64;
@@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) {
TEST_F(InfeedManagerTest, OutfeedWrongShape) {
TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->outfeed()->EnqueueBuffersAtomically({b});
ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33}));
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index f528e62b17..9eee9ebbd7 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -76,54 +76,23 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
return se::DeviceMemory<uint8>(buffer_addr);
}
-// Determines whether we can safely perform a winograd non-fused convolution for
-// the given input and output shapes. This works around b/68264959, an integer
-// overflow in cuDNNv5 and cuDNNv6.
-bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
- const Shape& output_shape,
- const ConvolutionDimensionNumbers& dnums,
- se::StreamExecutor* stream_exec) {
- // Skip this check for cudnn7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
- if (version.ok() && version.ValueOrDie().major_version() >= 7) {
- return true;
- }
-
- int64 batch = input_shape.dimensions(dnums.input_batch_dimension());
- int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension());
- int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0));
- int64 in_cols =
- dnums.input_spatial_dimensions_size() == 1
- ? 1
- : input_shape.dimensions(dnums.input_spatial_dimensions(1));
- int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension());
-
- int64 total_size = CeilOfRatio(batch, int64{16}) *
- std::max(in_depths, out_depths) * in_cols * in_rows *
- sizeof(float);
-
- const int64 threshold = 1L << 31;
- return total_size < threshold;
-}
-
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
- bool with_winograd_nonfused,
se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;
+ bool succ = false;
switch (kind) {
case CudnnConvKind::kBackwardFilter:
- CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ =
+ stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kBackwardInput:
- CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kForward:
- CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
- &algorithms));
+ succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
break;
}
+ DCHECK(succ);
return algorithms;
}
@@ -282,8 +251,6 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
}
}();
- const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
@@ -292,8 +259,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// particular reason to use it, as any algorithm sufficies. It doesn't make
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
- for (const AlgorithmDesc& alg :
- GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
+ for (const AlgorithmDesc& alg : GetAlgorithms(params.kind, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 06b6d5b559..b91b2406e2 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1173,80 +1173,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for rank-1 and rank-2 shapes, rank is: "
- << rank;
TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort";
- // We need to sort and array of keys and an array of values, where the
+ // We need to sort an array of keys and an array of values, where the
// sorted order of the values is determined by the keys. The simplest(?)
// way to do this is to go to an array-of-pairs representation, sort the
// array using the keys, and then go back to pair-of-arrays.
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
- auto sort_r1 = [](const Literal& keys_literal,
- const Literal& values_literal) {
- const auto& keys_data = keys_literal.data<KeyType>();
- const auto& values_data = values_literal.data<ValueType>();
-
- using kv_pair = std::pair<KeyType, ValueType>;
- std::vector<kv_pair> key_value_vector;
- CHECK_EQ(keys_data.size(), values_data.size());
- key_value_vector.reserve(keys_data.size());
- for (int i = 0; i < keys_data.size(); ++i) {
- key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
- }
- std::sort(key_value_vector.begin(), key_value_vector.end(),
- [](const kv_pair& a, const kv_pair& b) {
- return SafeLess<KeyType>(a.first, b.first);
- });
- std::vector<KeyType> result_keys;
- std::vector<ValueType> result_values;
- for (const auto& key_value : key_value_vector) {
- result_keys.push_back(key_value.first);
- result_values.push_back(key_value.second);
- }
- Literal result_keys_literal(keys_literal.shape());
- result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
- Literal result_values_literal(values_literal.shape());
- result_values_literal.PopulateR1(
- absl::Span<const ValueType>(result_values));
- return std::make_pair(std::move(result_keys_literal),
- std::move(result_values_literal));
- };
-
- Literal result_tuple;
- if (rank == 1) {
- auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple =
- LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal keys_result_literal(keys_literal.shape());
- Literal values_result_literal(values_literal.shape());
- int64 r1_length = keys_literal.shape().dimensions(1);
- for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- TF_ASSIGN_OR_RETURN(auto values_r1_slice,
- values_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
- TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first.Reshape({1, r1_length}));
- TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
- sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
- sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
- }
- result_tuple =
- LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
+ if (rank == 0) {
+ // Nothing to sort.
+ return LiteralUtil::MakeTuple({&keys_literal, &values_literal});
}
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys_literal.shape(), zero_base,
+ AsInt64Slice(keys_literal.shape().dimensions()), increment,
+ [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the keys and values literals that correspond to
+ // exactly the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto keys_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& keys_data = keys_to_sort.data<KeyType>();
+ TF_ASSIGN_OR_RETURN(auto values_to_sort,
+ values_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& values_data = values_to_sort.data<ValueType>();
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(
+ std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ Literal sorted_keys(ShapeUtil::MakeShape(
+ keys_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal sorted_values(ShapeUtil::MakeShape(
+ values_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_values.PopulateR1(absl::Span<const ValueType>(result_values));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ std::vector<int64> start_indices(rank, 0);
+ TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped,
+ sorted_keys.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys_reshaped, start_indices, indices, slice_dimensions));
+ TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped,
+ sorted_values.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+
+ Literal result_tuple;
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
@@ -1292,15 +1297,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort,
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
- const int64 sort_dim = sort->dimensions(0);
- const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
- if (sort_dim != rank - 1) {
- return Unimplemented(
- "Trying to sort along dimension %d, which is not the last "
- "dimension",
- sort_dim);
- }
-
if (!ShapeUtil::IsTuple(sort->shape())) {
return DefaultAction(sort);
} else {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8fb17a0033..35391ecf8a 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include <cmath>
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
@@ -41,7 +43,9 @@ template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
-// "safe" less functions which are actually strict weak orders.
+// "safe" less functions which are actually strict weak orders. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0.
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
@@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
return a < b;
}
-template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
+ bool lhs_is_negative = std::signbit(a);
+ bool rhs_is_negative = std::signbit(b);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(a);
+ bool rhs_nan = std::isnan(b);
+ // Exactly one number is nan?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
}
+ return a < b;
}
-template <typename NativeT, typename std::enable_if<std::is_same<
- NativeT, Eigen::half>::value>::type* = nullptr>
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
+ return SafeLess(static_cast<float>(a), static_cast<float>(b));
}
// Templated DfsHloVisitor for use by HloEvaluator.
@@ -1527,47 +1538,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
auto keys = sort->operand(0);
- auto rank = ShapeUtil::Rank(keys->shape());
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
-
- auto sort_r1 = [this](const Literal& keys_literal) {
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
-
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- Literal result_literal(keys_literal.shape());
- result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
- return result_literal;
- };
-
- if (rank == 1) {
- parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal result_literal(keys_literal.shape());
- int64 r1_length = keys->shape().dimensions(1);
- for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result = sort_r1(r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- r1_result, {0, 0}, {row, 0}, {1, r1_length}));
- }
- parent_->evaluated_[sort] = std::move(result_literal);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
+ int64 rank = ShapeUtil::Rank(keys->shape());
+ if (rank == 0) {
+ // Nothing to sort.
+ parent_->evaluated_[sort] = keys_literal.Clone();
+ return Status::OK();
}
+ Literal result_literal(keys_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
+ increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the literal that corresponds to exactly the
+ // row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto row_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& row_data = row_to_sort.data<NativeT>();
+
+ std::vector<NativeT> result_data(row_data.begin(), row_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);
+ });
+ Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
+ {sort_dim_elements}));
+ sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
+ sorted_row.Reshape(slice_dimensions));
+ std::vector<int64> start_indices(rank, 0);
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ sorted_row_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+ parent_->evaluated_[sort] = std::move(result_literal);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 178a78ede0..c522e7ae23 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -217,9 +218,12 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
if (platform->id() == se::host::kHostPlatformId) {
// On host "devices", StreamExecutor exports a device for each hardware
// thread. Because we parallelize a single computation across threads, it
- // doesn't make sense to expose these as separate devices, so fix the number
- // of devices to one.
- device_count = 1;
+ // doesn't make sense to expose these as separate devices, so by default we
+ // fix the number of devices to one. However we do let the user override
+ // this behavior to help run tests on the host that run models in parallel
+ // across multiple devices.
+ device_count = legacy_flags::GetDebugOptionsFromFlags()
+ .xla_force_host_platform_device_count();
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
VLOG(1) << "Initializing devices";
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index fd3e3bfa94..f474ecb18c 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2168,3 +2168,21 @@ xla_test(
"//tensorflow/core:lib",
],
)
+
+tf_cc_test(
+ name = "multiple_devices_on_host_test",
+ srcs = ["multiple_devices_on_host_test.cc"],
+ args = ["--xla_force_host_platform_device_count=4"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
new file mode 100644
index 0000000000..c530591c6e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+StatusOr<XlaComputation> BuildComputation() {
+ XlaBuilder b("computation");
+ Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32);
+ return b.Build(
+ OutfeedWithToken(GetTupleElement(infeed, 0) +
+ ConstantLiteral(&b, LiteralUtil::CreateR0<int32>(1)),
+ GetTupleElement(infeed, 1), scalar_s32, ""));
+}
+
+void CompileAndExecute(
+ LocalExecutable* executable, int device_ordinal, LocalClient* client,
+ absl::Mutex* results_mutex,
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>>* results) {
+ xla::ExecutableRunOptions execute_options;
+ execute_options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ execute_options.set_device_ordinal(device_ordinal);
+ execute_options.set_allocator(
+ xla::ClientLibrary::GetXlaService(client->platform())
+ ->backend()
+ .memory_allocator());
+ StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
+ {
+ absl::MutexLock lock(results_mutex);
+ results->emplace_back(device_ordinal, std::move(result));
+ }
+}
+
+void TestWithDeviceCount(const int device_count) {
+ // Run `device_count` copies of the XLA program built by BuildComputation.
+ TF_ASSERT_OK_AND_ASSIGN(
+ se::Platform* const platform,
+ perftools::gputools::MultiPlatformManager::PlatformWithName("Host"));
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform);
+ TF_ASSERT_OK_AND_ASSIGN(
+ LocalClient* const client,
+ xla::ClientLibrary::GetOrCreateLocalClient(client_options));
+
+ TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<LocalExecutable> executable,
+ client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
+ std::vector<tensorflow::Thread*> threads;
+ absl::Mutex results_mutex;
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;
+ tensorflow::Env* env = tensorflow::Env::Default();
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ tensorflow::Thread* t = env->StartThread(
+ tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal),
+ [&executable, device_ordinal, client, &results_mutex, &results] {
+ CompileAndExecute(executable.get(), device_ordinal, client,
+ &results_mutex, &results);
+ });
+ threads.push_back(t);
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(client->TransferToInfeedLocal(
+ LiteralUtil::CreateR0<int32>(device_ordinal * 100), device_ordinal));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
+ client->TransferFromOutfeedLocal(
+ ShapeUtil::MakeShape(S32, {}), device_ordinal));
+ EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ delete threads[device_ordinal];
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(results[device_ordinal].second.status());
+ }
+}
+
+// NB! This test requires --xla_force_host_platform_device_count=4
+
+TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); }
+
+TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); }
+
+TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); }
+
+TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); }
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b53f89d63b..60d25a6407 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -200,6 +200,15 @@ message DebugOptions {
// among different algorithms.
bool xla_gpu_crash_on_verification_failures = 101;
+ // Force the host platform to pretend that there are these many host
+ // "devices". All these devices are backed by the same threadpool. Defaults
+ // to 1.
+ //
+ // Setting this to anything other than 1 can increase overhead from context
+ // switching but we let the user override this behavior to help run tests on
+ // the host that run models in parallel across multiple devices.
+ int32 xla_force_host_platform_device_count = 102;
+
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index e1af52cd96..ae5ca32bcf 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -126,11 +126,16 @@ py_library(
}) + if_not_windows_cuda([
"//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
]) + if_not_windows([
- "//tensorflow/contrib/bigtable", # depends on bigtable
- "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
- "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
- ]),
+ ]) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "//tensorflow/contrib/bigtable",
+ "//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/tensorrt:init_py",
+ "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+ ],
+ }),
)
cc_library(
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index af7006bff2..8edb5d6c64 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -739,21 +739,22 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
- std::vector<int32> non_empty_partitions;
- for (int i = 0; i < partition_ids.size() - 1; ++i) {
+ partition_boundaries.push_back(0);
+ for (int i = 1; i < partition_ids.size(); ++i) {
// Make sure the input is sorted by partition_ids;
- CHECK_LE(partition_ids(i), partition_ids(i + 1));
- if (i == 0 || partition_ids(i) != partition_ids(i - 1)) {
+ OP_REQUIRES(context, partition_ids(i - 1) <= partition_ids(i),
+ errors::InvalidArgument("Partition IDs must be sorted."));
+ if (partition_ids(i) != partition_ids(i - 1)) {
partition_boundaries.push_back(i);
- // Some partitions might only have bias feature. We don't want to split
- // those so check that the partition has at least 2 features.
- if (partition_ids(i) == partition_ids(i + 1)) {
- non_empty_partitions.push_back(partition_boundaries.size() - 1);
- }
}
}
- if (partition_ids.size() > 0) {
- partition_boundaries.push_back(partition_ids.size());
+ std::vector<int32> non_empty_partitions;
+ partition_boundaries.push_back(partition_ids.size());
+ for (int i = 0; i < partition_boundaries.size() - 1; ++i) {
+ // We want to ignore partitions with only the bias term.
+ if (partition_boundaries[i + 1] - partition_boundaries[i] >= 2) {
+ non_empty_partitions.push_back(i);
+ }
}
int num_elements = non_empty_partitions.size();
Tensor* output_partition_ids_t = nullptr;
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index 94ea7bc2eb..c050c2ed7f 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -577,6 +577,92 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(gains), 0)
self.assertEqual(len(splits), 0)
+ def testLastOneEmpty(self):
+ with self.cached_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 0 | 1,2 |
+ # i1 | (-0.5, 0.07) | 0 | |
+ # i2 | (1.2, 0.2) | 0 | 2 |
+ # i3 | (4.0, 0.13) | 1 | |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [0, 0, 0, 1]
+ indices = [[0, 0], [0, 1], [2, 0]]
+ values = array_ops.constant([1, 2, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([0], partitions)
+
+ # Check the split on partition 0.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight = -0.9848484848484846
+
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain = 0.46043165467625885
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertEqual(2, split_node.feature_id)
+
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0],
+ 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 1c432b6e0b..c0763f4c0e 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -406,7 +406,6 @@ tensorflow/contrib/summary
tensorflow/contrib/tensorboard
tensorflow/contrib/tensorboard/plugins
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tensorboard/plugins/trace
# TODO(sami): Add cmake implementations.
# tensorflow/contrib/tensorrt/python
# tensorflow/contrib/tensorrt/python/ops
diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt
index cf1ee2ad76..42afbd9105 100644
--- a/tensorflow/contrib/cmake/python_protos.txt
+++ b/tensorflow/contrib/cmake/python_protos.txt
@@ -12,7 +12,6 @@ tensorflow/contrib/mpi_collectives
tensorflow/contrib/session_bundle
tensorflow/contrib/tensor_forest/proto
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tensorboard/plugins/trace
tensorflow/contrib/tpu/proto
tensorflow/contrib/tpu/profiler
tensorflow/contrib/training/python/training
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index 3b0e8f6cda..f51bfc1b22 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -5,7 +5,10 @@ package(default_visibility = [":friends"])
package_group(
name = "friends",
includes = ["//tensorflow/compiler/jit:friends"],
- packages = ["//tensorflow/..."],
+ packages = [
+ "//tensorflow/...",
+ "//third_party/py/tensor2tensor/...",
+ ],
)
load("//tensorflow:tensorflow.bzl", "tf_py_test")
@@ -59,6 +62,7 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:summary_op_util",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 0aae695f92..1e30525159 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -19,17 +19,22 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import summary_op_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
+from tensorflow.python.util import tf_decorator
_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5
@@ -353,3 +358,291 @@ def _compile_internal(computation, inputs=None):
array_ops.identity(outputs[i], name='output_%d' % i)
for i in xrange(output_arity)
]
+
+
+@contextlib.contextmanager
+def _disable_summary_context():
+ """Enters a context where all summary ops are skipped.
+
+ Summaries are not yet supported in xla.compile(). So we provide this context
+ manager that can skip creating summary ops. This is a temporary workaround due
+ to XLA not supporting summary ops.
+
+ Yields:
+ None.
+ """
+ origional_skip_summary_func = summary_op_util.skip_summary
+ summary_op_util.skip_summary = lambda: True
+
+ try:
+ yield
+ finally:
+ summary_op_util.skip_summary = origional_skip_summary_func
+
+
+class _CapturedObject(object):
+ """A placeholder to capture an object."""
+
+ def __init__(self):
+ self._object = None
+
+ def capture(self, o):
+ if self._object:
+ raise RuntimeError(
+ 'InternalError: _CapturedObject can capture only once. Please file '
+ 'bug.')
+
+ self._object = o
+
+ def get(self):
+ return self._object
+
+
+def _get_scaffold(captured_scaffold_fn):
+ """Retrieves the Scaffold from `captured_scaffold_fn`."""
+ scaffold_fn = captured_scaffold_fn.get()
+
+ if not scaffold_fn:
+ return None
+
+ scaffold = scaffold_fn()
+ if scaffold is None:
+ raise ValueError(
+ 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
+
+ return scaffold
+
+
+class _ModelFnWrapper(object):
+ """_ModelFnWrapper supports executing model_fn with XLA."""
+
+ def __init__(self, function):
+ self._model_fn = function
+
+ def __call__(self, features, labels, mode, params):
+
+ # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
+ # compilation, we use this params['use_tpu'] as a hint. When it is set to
+ # True, model_fn is called without compilation.
+ # Note that this condition isn't accurate for the case of exporting a model.
+ # In that case we should ideally not compile so that user can see detailed
+ # graph. However, we don't have enough information to tell whether model_fn
+ # is being called for export mode or not.
+ # TODO(ycao): Make this condition more accurate when implementing PREDICT
+ # mode.
+ if params.get('use_tpu'):
+ return self._call_model_fn(features, labels, mode, params)
+
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_step, captured_scaffold_fn = self._make_train_step(
+ features, labels, params)
+ with _disable_summary_context():
+ (loss,) = compile(train_step)
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ train_op=array_ops.identity(loss),
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
+ self._make_eval_step(features, labels, params))
+ with _disable_summary_context():
+ outputs = compile(eval_step)
+ loss = outputs[0]
+
+ # Calculate eval_metric_ops if eval_metric_fn is set and captured.
+ eval_metric_fn = captured_eval_metric_fn.get()
+ if eval_metric_fn:
+ eval_metric_fn_tensors = outputs[1:]
+ eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
+ else:
+ eval_metric_ops = None
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=eval_metric_ops,
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ else:
+ raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are'
+ ' supported' % mode)
+
+ def _make_train_step(self, features, labels, params):
+ """Creates a single step of training for xla.compile()."""
+ captured_scaffold_fn = _CapturedObject()
+
+ def train_step():
+ """A single step of training."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.TRAIN, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ # train_step will be run by xla.compile(). xla.compile() only supports
+ # tensor output while train_op can be either an operation or a tensor.
+ # Even though xla.compile() automatically adds operation-typed train_op as
+ # control dependency of other tensor outputs, it doesn't do so for
+ # tensor-typed train_op. Thus, we need to set it explicitly here.
+ with ops.control_dependencies([estimator_spec.train_op]):
+ return array_ops.identity(estimator_spec.loss)
+
+ return train_step, captured_scaffold_fn
+
+ def _make_eval_step(self, features, labels, params):
+ """Creates a single step of evaluation for xla.compile()."""
+ captured_eval_metric_fn = _CapturedObject()
+ captured_scaffold_fn = _CapturedObject()
+
+ def eval_step():
+ """A single step of evaluation."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.EVAL, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ eval_metric_fn = None
+ eval_metric_fn_tensors = []
+ try:
+ if estimator_spec.eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
+ except AttributeError:
+ pass
+
+ # If a dictionary is provided, we need to convert it into a list sorted
+ # according to order of eval_metric_fn positional arguments.
+ if isinstance(eval_metric_fn_tensors, dict):
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+ eval_metric_fn_tensors = [
+ eval_metric_fn_tensors[i] for i in eval_metric_fn_args
+ ]
+
+ captured_eval_metric_fn.capture(eval_metric_fn)
+
+ return tuple([estimator_spec.loss] + eval_metric_fn_tensors)
+
+ return eval_step, captured_eval_metric_fn, captured_scaffold_fn
+
+ def _call_model_fn(self, features, labels, mode, params):
+ """Calls the model_fn with required parameters."""
+ model_fn_args = function_utils.fn_args(self._model_fn)
+ kwargs = {}
+
+ if 'labels' in model_fn_args:
+ kwargs['labels'] = labels
+ elif labels is not None:
+ raise ValueError(
+ 'model_fn does not take labels, but input_fn returns labels.')
+ if 'mode' in model_fn_args:
+ kwargs['mode'] = mode
+
+ if 'params' in model_fn_args:
+ kwargs['params'] = params
+
+ return self._verify_estimator_spec(
+ self._model_fn(features=features, **kwargs))
+
+ def _verify_estimator_spec(self, estimator_spec):
+ """Verifies estimator spec contains correct data."""
+ # TODO(ycao): Implement estimator spec verification for other modes.
+
+ try:
+ if estimator_spec.scaffold:
+ logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
+ '. Please use TPUEstimatorSpec.scaffold_fn instead.')
+ except AttributeError:
+ pass
+
+ try:
+ if estimator_spec.eval_metric_ops:
+ raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
+ 'XLA compilation. Please use '
+ 'TPUEstimatorSpec.eval_metrics instead.')
+ except AttributeError:
+ pass
+
+ if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
+ # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
+ # check that eval_metrics contains eval_metric_fn and
+ # eval_metric_fn_tensors with matching arguments.
+ try:
+ eval_metrics = estimator_spec.eval_metrics
+ except AttributeError:
+ eval_metrics = None
+
+ if eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+
+ if isinstance(eval_metric_fn_tensors, dict):
+ missing_tensors = [
+ i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
+ ]
+ additional_tensors = [
+ i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
+ ]
+
+ if missing_tensors:
+ raise ValueError('Arguments %s are needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics) but '
+ 'they are not provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ '.' % missing_tensors)
+
+ if additional_tensors:
+ raise ValueError('Arguments %s are provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ ' but they are not needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics).' %
+ additional_tensors)
+
+ return estimator_spec
+
+
+def estimator_model_fn(target_model_fn=None):
+ """estimator_model_fn decorates a model_fn to be compiled for execution.
+
+ Currently only it only works with `TPUEstimator`. If you need to use it with
+ base `Estimator`, please add `tf.enable_resource_variables()` at beginning of
+ your program.
+
+ Example 1, decorating model_fn:
+ ```
+ @xla.estimator_model_fn()
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+
+ est = Estimator(model_fn=model_fn, ...)
+ est.train(...)
+
+ ```
+
+ Example 2, decorator as function:
+ ```
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+ est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
+ est.train(...)
+ ```
+
+ Args:
+ target_model_fn: model_fn to be decorated. This is only needed when
+ decorator is used in function call form (example 2).
+
+ Returns:
+ Decorated target_model_fn.
+ """
+
+ def decorated(function):
+ return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
+
+ return decorated(target_model_fn) if target_model_fn else decorated
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 39f23f7b24..96f1dd0059 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -476,656 +476,6 @@ class IteratorGetDeviceOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
IteratorGetDeviceOp);
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
-string SanitizeThreadSuffix(string suffix) {
- string clean;
- for (int i = 0; i < suffix.size(); ++i) {
- const char ch = suffix[i];
- if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
- (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
- clean += ch;
- } else {
- clean += '_';
- }
- }
- return clean;
-}
-
-struct HostBufferElement {
- Status status;
- bool end_of_sequence;
- std::vector<Tensor> value;
-};
-
-using MultiDeviceIteratorCallback =
- std::function<void(const HostBufferElement&)>;
-
-class MultiDeviceIterator : public ResourceBase {
- public:
- MultiDeviceIterator(const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes,
- const std::vector<string>& devices,
- std::unique_ptr<FunctionLibraryDefinition> flib_def,
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
- FunctionLibraryRuntime* lib)
- : output_types_(output_types),
- output_shapes_(output_shapes),
- devices_(devices),
- flib_def_(std::move(flib_def)),
- pflr_(std::move(pflr)),
- lib_(lib) {
- CHECK_NOTNULL(lib_);
- }
-
- string DebugString() override {
- return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
- " devices");
- }
-
- Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
- int64* incarnation_id) {
- if (iterator) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_types_, iterator->output_dtypes()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
- }
-
- mutex_lock l(mu_);
- if (multi_device_buffer_) {
- multi_device_buffer_->Reset();
- }
-
- ++incarnation_id_;
- *incarnation_id = incarnation_id_;
-
- multi_device_buffer_.reset(
- new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
- std::move(iterator)));
- return Status::OK();
- }
-
- void GetNextFromShard(IteratorContext* ctx, int shard_num,
- int64 incarnation_id,
- MultiDeviceIteratorCallback callback) {
- if (lib_ != nullptr) {
- ctx->set_lib(lib_);
- }
- tf_shared_lock l(mu_);
- multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
- std::move(callback));
- }
-
- const DataTypeVector& output_types() const { return output_types_; }
-
- const std::vector<PartialTensorShape>& output_shapes() const {
- return output_shapes_;
- }
-
- std::shared_ptr<const FunctionLibraryDefinition> function_library() {
- tf_shared_lock l(mu_);
- return lib_def_;
- }
-
- FunctionLibraryRuntime* const lib() {
- tf_shared_lock l(mu_);
- return lib_;
- }
-
- private:
- // A private class that uses a background thread to keep a per device buffer
- // full.
- class MultiDeviceBuffer {
- public:
- MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
- std::unique_ptr<IteratorBase> host_iterator)
- : buffer_(size),
- size_(size),
- max_buffer_size_(max_buffer_size),
- incarnation_id_(incarnation_id),
- host_iterator_(std::move(host_iterator)) {}
-
- ~MultiDeviceBuffer() {
- {
- mutex_lock l(mu_);
- if (!background_thread_started_) return;
- }
- Reset();
- }
-
- void Reset() LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
-
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
- }
- }
- RunPendingCallbacks();
- }
-
- void GetNextFromShard(IteratorContext* ctx, int shard_num,
- int64 incarnation_id,
- MultiDeviceIteratorCallback callback) {
- HostBufferElement elem;
- if (incarnation_id_ != incarnation_id) {
- elem.status = errors::InvalidArgument("Invalid incarnation id");
- callback(elem);
- return;
- }
-
- bool produced_output = false;
- {
- mutex_lock l(mu_);
- if (cancelled_) {
- elem.status = errors::Cancelled("Cancelled Multidevice iterator");
- callback(elem);
- return;
- }
-
- EnsureBackgroundThreadStarted(ctx);
-
- if (!buffer_[shard_num].data.empty()) {
- produced_output = true;
- std::swap(elem, buffer_[shard_num].data.front());
- buffer_[shard_num].data.pop_front();
- // Wake up background thread if it is blocked on this element.
- if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
- buffer_[shard_num].cond_var.notify_all();
- }
- } else {
- if (background_thread_finished_) {
- produced_output = true;
- elem.end_of_sequence = true;
- } else {
- buffer_[shard_num].callbacks.push_back(std::move(callback));
- callback = nullptr;
- }
- }
- }
-
- if (produced_output) {
- callback(elem);
- }
- }
-
- private:
- void EnsureBackgroundThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!background_thread_) {
- background_thread_.reset(ctx->env()->StartThread(
- {}, "multi_device_iterator_background_thread",
- std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
- this, new IteratorContext(*ctx))));
- }
- }
-
- void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
- // Run all remaining callbacks.
- std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
- std::vector<HostBufferElement> cancellation_elements;
- {
- mutex_lock l(mu_);
-
- for (int i = 0; i < size_; ++i) {
- while (!buffer_[i].callbacks.empty()) {
- if (buffer_[i].data.empty()) {
- HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
- cancellation_elements.push_back(std::move(elem));
- } else {
- cancellation_elements.push_back(
- std::move(buffer_[i].data.front()));
- buffer_[i].data.pop_front();
- }
- cancellation_callbacks.push_back(
- std::move(buffer_[i].callbacks.front()));
- buffer_[i].callbacks.pop_front();
- }
- }
- }
- for (int i = 0; i < cancellation_callbacks.size(); ++i) {
- cancellation_callbacks[i](cancellation_elements[i]);
- }
- }
-
- void BackgroundThread(IteratorContext* ctx) {
- {
- mutex_lock l(mu_);
- background_thread_started_ = true;
- }
- std::unique_ptr<IteratorContext> cleanup(ctx);
- int shard_to_fetch = 0;
- while (true) {
- HostBufferElement elem;
- MultiDeviceIteratorCallback callback = nullptr;
- bool end_of_iterator = false;
-
- {
- mutex_lock l(mu_);
- while (!cancelled_ &&
- buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
- buffer_[shard_to_fetch].cond_var.wait(l);
- }
-
- if (cancelled_) {
- background_thread_finished_ = true;
- shutdown_cond_var_.notify_all();
- return;
- }
- }
-
- elem.status =
- host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
-
- if (elem.status.ok() && elem.end_of_sequence) {
- end_of_iterator = true;
- }
-
- {
- mutex_lock l(mu_);
- // Try to find a callback, else just push stuff into buffer.
- if (!buffer_[shard_to_fetch].callbacks.empty()) {
- callback = buffer_[shard_to_fetch].callbacks.front();
- buffer_[shard_to_fetch].callbacks.pop_front();
- } else {
- buffer_[shard_to_fetch].data.push_back(std::move(elem));
- elem = HostBufferElement();
- }
- }
-
- if (callback) {
- (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
- }
-
- // Finish off the thread if we reach the end of the iterator. Runs
- // pending callbacks.
- if (end_of_iterator) {
- {
- mutex_lock l(mu_);
- background_thread_finished_ = true;
- shutdown_cond_var_.notify_all();
- }
- RunPendingCallbacks();
- return;
- }
- shard_to_fetch = (shard_to_fetch + 1) % size_;
- }
- }
-
- struct HostBuffer {
- condition_variable cond_var;
- std::deque<HostBufferElement> data;
- std::deque<MultiDeviceIteratorCallback> callbacks;
- };
-
- mutex mu_;
- std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
- bool background_thread_finished_ GUARDED_BY(mu_) = false;
- bool background_thread_started_ GUARDED_BY(mu_) = false;
- bool cancelled_ GUARDED_BY(mu_) = false;
- condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
-
- std::vector<HostBuffer> buffer_;
-
- const size_t size_;
- const int64 max_buffer_size_;
- const int64 incarnation_id_;
- const std::unique_ptr<IteratorBase> host_iterator_;
- };
-
- mutex mu_;
- const DataTypeVector output_types_;
- const std::vector<PartialTensorShape> output_shapes_;
- const std::vector<string> devices_;
- const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
- std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
-
- int64 incarnation_id_ GUARDED_BY(mu_) = 0;
- std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
-};
-
-// Just creates a MultiDeviceIterator and returns it.
-class MultiDeviceIteratorHandleOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
- }
-
- // The resource is deleted from the resource manager only when it is private
- // to kernel.
- ~MultiDeviceIteratorHandleOp() override {
- if (resource_ != nullptr) {
- resource_->Unref();
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->template Delete<MultiDeviceIterator>(cinfo_.container(),
- cinfo_.name())
- .ok()) {
- // Do nothing; the resource can have been deleted by session resets.
- }
- }
- }
- }
-
- void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- if (resource_ == nullptr) {
- FunctionLibraryRuntime* lib;
- std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
- OP_REQUIRES_OK(context, context->function_library()->Clone(
- &flib_def, &pflr, &lib));
- ResourceMgr* mgr = context->resource_manager();
- OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
-
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(
- context,
- mgr->LookupOrCreate<MultiDeviceIterator>(
- cinfo_.container(), cinfo_.name(), &resource,
- [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- *ret = new MultiDeviceIterator(
- output_types_, output_shapes_, devices_,
- std::move(flib_def), std::move(pflr), lib);
- return Status::OK();
- }));
-
- Status s = VerifyResource(resource);
- if (TF_PREDICT_FALSE(!s.ok())) {
- resource->Unref();
- context->SetStatus(s);
- return;
- }
-
- resource_ = resource;
- }
- }
- OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
- context, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<MultiDeviceIterator>()));
- }
-
- private:
- // During the first Compute(), resource is either created or looked up using
- // shared_name. In the latter case, the resource found should be verified if
- // it is compatible with this op's configuration. The verification may fail in
- // cases such as two graphs asking queues of the same shared name to have
- // inconsistent capacities.
- Status VerifyResource(MultiDeviceIterator* resource) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_types_, resource->output_types()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
- return Status::OK();
- }
-
- mutex mu_;
- ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
- MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
- const int graph_def_version_;
- string name_;
- string container_;
- std::vector<string> devices_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
- MultiDeviceIteratorHandleOp);
-
-// Calls init on the MultiDeviceIterator.
-class MultiDeviceIteratorInitOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor* tensor_max_buffer_size;
- OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
- int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
-
- DatasetBase* dataset;
- OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
- core::ScopedUnref unref(resource);
-
- std::unique_ptr<IteratorBase> iterator;
- IteratorContext iter_ctx(ctx);
- iter_ctx.set_lib(resource->lib());
- OP_REQUIRES_OK(
- ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
- int64 incarnation_id;
- OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
- &incarnation_id));
- Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
- tensor_incarnation_id.scalar<int64>()() = incarnation_id;
- OP_REQUIRES_OK(ctx,
- ctx->set_output("incarnation_id", tensor_incarnation_id));
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
- MultiDeviceIteratorInitOp);
-
-// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
-// TODO(rohanj): Implement using BackgroundWorker that Derek built?
-class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
- public:
- explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
- : AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
- strings::StrCat("multi_device_iterator_get_next_thread_",
- SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)) {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- const Tensor* tensor_shard_num;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
- int32 shard_num = tensor_shard_num->scalar<int32>()();
-
- const Tensor* tensor_incarnation_id;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
- int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
-
- MultiDeviceIterator* iterator;
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
- thread_pool_->Schedule(std::bind(
- [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.function_library = iterator->function_library();
- DeviceBase* device = ctx->function_library()->device();
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- IteratorContext iter_ctx(std::move(params));
-
- MultiDeviceIteratorCallback callback = std::bind(
- [ctx](const HostBufferElement& elem, DoneCallback done) {
- // iterator->Unref();
- Status s = elem.status;
- if (!s.ok()) {
- ctx->SetStatus(s);
- } else if (elem.end_of_sequence) {
- ctx->SetStatus(errors::OutOfRange("End of sequence"));
- } else {
- for (int i = 0; i < elem.value.size(); ++i) {
- ctx->set_output(i, elem.value[i]);
- }
- }
- done();
- },
- std::placeholders::_1, std::move(done));
-
- iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
- callback);
- iterator->Unref();
- },
- std::move(done)));
- }
-
- private:
- std::unique_ptr<thread::ThreadPool> thread_pool_;
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
- MultiDeviceIteratorGetNextFromShardOp);
-
-class MultiDeviceIteratorToStringHandleOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& resource_handle_t = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
- errors::InvalidArgument("resource_handle must be a scalar"));
-
- // Validate that the handle corresponds to a real resource, and
- // that it is an MultiDeviceIterator.
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
- resource->Unref();
-
- Tensor* string_handle_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &string_handle_t));
- string_handle_t->scalar<string>()() =
- resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
- }
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
- MultiDeviceIteratorToStringHandleOp);
-
-class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- OP_REQUIRES(
- ctx,
- output_types_.empty() || output_shapes_.empty() ||
- output_types_.size() == output_shapes_.size(),
- errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
- "are set, they must have the same length."));
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& string_handle_t = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
- errors::InvalidArgument("string_handle must be a scalar"));
-
- ResourceHandle resource_handle;
- OP_REQUIRES(
- ctx,
- resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
- errors::InvalidArgument(
- "Could not parse string_handle as a valid ResourceHandle"));
-
- OP_REQUIRES(
- ctx, resource_handle.device() == ctx->device()->attributes().name(),
- errors::InvalidArgument("Attempted create an iterator on device \"",
- ctx->device()->attributes().name(),
- "\" from handle defined on device \"",
- resource_handle.device(), "\""));
-
- // Validate that the handle corresponds to a real resource, and
- // that it is an MultiDeviceIterator.
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
- core::ScopedUnref unref_iterator(resource);
- if (!output_types_.empty()) {
- OP_REQUIRES_OK(ctx,
- VerifyTypesMatch(output_types_, resource->output_types()));
- }
- if (!output_shapes_.empty()) {
- OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
- resource->output_shapes()));
- }
-
- Tensor* resource_handle_t;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
- resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
- }
-
- private:
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
- MultiDeviceIteratorFromStringHandleOp);
-
} // namespace
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index ad410e17fe..d1a771f005 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -151,82 +151,6 @@ Resets the FunctionBufferingResource.
function_buffer_resource: The FunctionBufferingResource handle.
)doc");
-REGISTER_OP("MultiDeviceIterator")
- .Output("handle: resource")
- .Attr("devices: list(string) >= 1")
- .Attr("shared_name: string")
- .Attr("container: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Doc(R"doc(
-Creates a MultiDeviceIterator resource.
-
-handle: Handle to the resource created.
-devices: A list of devices the iterator works across.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorInit")
- .Input("dataset: variant")
- .Input("multi_device_iterator: resource")
- .Input("max_buffer_size: int64")
- .Output("incarnation_id: int64")
- .Doc(R"doc(
-Initializes the multi device iterator with the given dataset.
-max_buffer_size: The maximum size of the host side per device buffer to keep.
-incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator
- is running.
-dataset: Dataset to be iterated upon.
-multi_device_iterator: A MultiDeviceIteratorResource.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
- .Input("multi_device_iterator: resource")
- .Input("shard_num: int32")
- .Input("incarnation_id: int64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Doc(R"doc(
-Gets next element for the provided shard number.
-
-multi_device_iterator: A MultiDeviceIterator resource.
-shard_num: Integer representing which shard to fetch data for.
-incarnation_id: Which incarnation of the MultiDeviceIterator is running.
-components: Result of the get_next on the dataset.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorToStringHandle")
- .Input("multi_device_iterator: resource")
- .Output("string_handle: string")
- .Doc(R"doc(
-Produces a string handle for the given MultiDeviceIterator.
-
-multi_device_iterator: A MultiDeviceIterator resource.
-string_handle: A string representing the resource.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorFromStringHandle")
- .Input("string_handle: string")
- .Output("multi_device_iterator: resource")
- .Attr("output_types: list(type) >= 0 = []")
- .Attr("output_shapes: list(shape) >= 0 = []")
- .Doc(R"doc(
-Generates a MultiDeviceIterator resource from its provided string handle.
-
-string_handle: String representing the resource.
-multi_device_iterator: A MultiDeviceIterator resource.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
REGISTER_OP("ThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index b3187bf61b..a2fc244ced 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -110,6 +110,22 @@ py_test(
)
py_test(
+ name = "noop_elimination_test",
+ size = "small",
+ srcs = ["noop_elimination_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "optimize_dataset_op_test",
size = "small",
srcs = ["optimize_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
new file mode 100644
index 0000000000..507feda3ad
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class NoopEliminationTest(test.TestCase):
+
+ def testNoopElimination(self):
+ a = constant_op.constant(1, dtype=dtypes.int64)
+ b = constant_op.constant(2, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+
+ dataset = dataset_ops.Dataset.range(5)
+ dataset = dataset.apply(
+ optimization.assert_next(
+ ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
+ dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
+ 0).repeat(1).prefetch(0)
+ dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ self.assertAllEqual(result, x)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 5b17511e41..33a64ea767 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -944,164 +943,5 @@ class CopyToDeviceTest(test.TestCase):
sess.run(elem_value_t)
-class MultiDeviceIteratorTest(test.TestCase):
-
- def testNoGetNext(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
-
- def testBasic(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testOneOnSameDevice(self):
- with ops.device("/cpu:0"):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:0", "/cpu:1"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testRepeatDevices(self):
- with ops.device("/cpu:0"):
- dataset = dataset_ops.Dataset.range(20)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
- elements = multi_device_iterator.get_next()
- elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 20, 4):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- self.assertEqual(i + 2, sess.run(elem_on_3))
- self.assertEqual(i + 3, sess.run(elem_on_4))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
- sess.run(elem_on_3)
- sess.run(elem_on_4)
-
- def testNotFullyDivisible(self):
- dataset = dataset_ops.Dataset.range(9)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 8, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- self.assertEqual(8, sess.run(elem_on_1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testUneven(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- for i in range(0, 10, 2):
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testMultipleInitializations(self):
- with ops.device("/cpu:0"):
- epoch = array_ops.placeholder(dtypes.int64, shape=[])
- dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
- dataset2 = dataset_ops.Dataset.range(1000)
- dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
- init_op = multi_device_iterator.initializer
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- for i in range(1000):
- sess.run(init_op, feed_dict={epoch: i})
- self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
-
- def testBasicGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testUnevenGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- for i in range(0, 10, 2):
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 77079d0df9..c900b41e14 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
index = {}
+ unique_var_name = ops.get_default_graph().unique_name(
+ kwargs["name"], mark_as_used=False).rstrip("/")
collective_instance_key = self._collective_keys.get_instance_key(
- key_id=kwargs["name"])
+ key_id=unique_var_name)
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
initial_value = kwargs["initial_value"]
@@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
+ if i == 0:
+ actual_var_name = v.name.split(":")[0]
+ assert unique_var_name == actual_var_name, "%r vs %r" % (
+ unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index
@@ -229,8 +235,6 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
if not session_config or not self._cluster_spec:
return
- session_config.isolate_session_state = True
-
assert self._task_type
assert self._task_id is not None
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 36e9761073..33ffbf6abe 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -34,9 +35,14 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import training_util
class CollectiveAllReduceStrategyTestBase(
@@ -146,6 +152,56 @@ class CollectiveAllReduceStrategyTestBase(
self.assertLess(error_after, error_before)
return error_after < error_before
+ def _test_complex_model(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_object(task_type, task_id, num_gpus)
+
+ def model_fn():
+ """Mnist model with synthetic input."""
+ data_format = 'channels_last'
+ input_shape = [28, 28, 1]
+ l = keras.layers
+ max_pool = l.MaxPooling2D((2, 2), (2, 2),
+ padding='same',
+ data_format=data_format)
+ model = keras.Sequential([
+ l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
+ l.Conv2D(
+ 32,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=nn.relu), max_pool,
+ l.Conv2D(
+ 64,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=nn.relu), max_pool,
+ l.Flatten(),
+ l.Dense(1024, activation=nn.relu),
+ l.Dropout(0.4),
+ l.Dense(10)
+ ])
+ image = random_ops.random_uniform([2, 28, 28])
+ label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
+ logits = model(image, training=True)
+ loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
+ optimizer = adam.AdamOptimizer(learning_rate=1e-4)
+ train_op = optimizer.minimize(loss,
+ training_util.get_or_create_global_step())
+ return train_op
+
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess:
+ with d.scope():
+ train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(d.unwrap(train_op))
+
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ return True
+
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -206,6 +262,14 @@ class DistributedCollectiveAllReduceStrategyTest(
self._cluster_spec,
num_gpus=num_gpus)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class DistributedCollectiveAllReduceStrategyTestWithChief(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -236,6 +300,14 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
self._cluster_spec,
num_gpus=num_gpus)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class LocalCollectiveAllReduceStrategy(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -246,6 +318,12 @@ class LocalCollectiveAllReduceStrategy(
return
self._test_minimize_loss_graph(None, None, num_gpus)
+ def testComplexModel(self, num_gpus=2):
+ # Collective ops doesn't support strategy with one device.
+ if context.num_gpus() < num_gpus:
+ return
+ self._test_complex_model(None, None, num_gpus)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index a3e1b96a68..490371477a 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -114,7 +114,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
self.assertEqual([v.numpy() for v in left._index.values()],
list(right._index.values()))
else:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertEqual(
sess.run(list(left._index.values())), list(right._index.values()))
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index c5acb7ced4..559de97bb1 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -20,8 +20,6 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -126,20 +124,6 @@ class AutoShardDatasetTest(test.TestCase):
# contain records in order of files.
self._verifySimpleShardingOutput(dataset, self._record)
- def testParallelInterleave(self):
- dataset = dataset_ops.Dataset.from_tensor_slices(
- self._createTFRecordFiles())
- dataset = dataset.apply(interleave_ops.parallel_interleave(
- readers.TFRecordDataset,
- cycle_length=4,
- block_length=self._num_records))
- dataset = input_ops.auto_shard_dataset(
- dataset, self._num_shards, self._shard_index)
-
- # Since block_length == num records in each file, the output will still
- # contain records in order of files.
- self._verifySimpleShardingOutput(dataset, self._record)
-
def testListfiles(self):
filenames = self._createTFRecordFiles()
file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
@@ -171,8 +155,8 @@ class AutoShardDatasetTest(test.TestCase):
dataset = dataset.prefetch(buffer_size=batch_size)
dataset = dataset.shuffle(2 * self._num_files * self._num_records)
dataset = dataset.repeat(num_epochs)
- dataset = dataset.apply(batching.map_and_batch(
- lambda x: x, batch_size=batch_size))
+ dataset = dataset.map(lambda x: x)
+ dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=None)
# Auto shard.
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 152431d1b2..3fd9f12c61 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -24,7 +24,6 @@ import pickle
from google.protobuf.any_pb2 import Any
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
@@ -68,7 +67,7 @@ def _to_any_wrapped_tensor_info(tensor):
return any_buf
-def make_input_layer_with_layer_annotations(original_input_layer, mode):
+def make_input_layer_with_layer_annotations(original_input_layer):
"""Make an input_layer replacement function that adds layer annotations."""
def input_layer_with_layer_annotations(features,
@@ -137,42 +136,38 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode):
if cols_to_output_tensors is not None:
cols_to_output_tensors = local_cols_to_output_tensors
- if mode and mode == model_fn.ModeKeys.PREDICT:
- # Only annotate in PREDICT mode.
-
- # Annotate features.
- # These are the parsed Tensors, before embedding.
-
- # Only annotate features used by FeatureColumns.
- # We figure which ones are used by FeatureColumns by creating a parsing
- # spec and looking at the keys.
- spec = feature_column_lib.make_parse_example_spec(feature_columns)
- for key in spec.keys():
- tensor = features[key]
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.keys(
- LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.values(
- LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
- _to_any_wrapped_tensor_info(tensor))
-
- # Annotate feature columns.
- for column in feature_columns:
- # TODO(cyfoo): Find a better way to serialize and deserialize
- # _FeatureColumn.
- ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
- serialize_feature_column(column))
-
- for column, tensor in local_cols_to_output_tensors.items():
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.keys(
- LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
- column.name)
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.values(
- LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
- _to_any_wrapped_tensor_info(tensor))
+ # Annotate features.
+ # These are the parsed Tensors, before embedding.
+
+ # Only annotate features used by FeatureColumns.
+ # We figure which ones are used by FeatureColumns by creating a parsing
+ # spec and looking at the keys.
+ spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ for key in spec.keys():
+ tensor = ops.convert_to_tensor(features[key])
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ # Annotate feature columns.
+ for column in feature_columns:
+ # TODO(cyfoo): Find a better way to serialize and deserialize
+ # _FeatureColumn.
+ ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+ serialize_feature_column(column))
+
+ for column, tensor in local_cols_to_output_tensors.items():
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES), column.name)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
return input_layer
@@ -302,8 +297,8 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
feature_column_lib, 'input_layer',
- make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
- mode)):
+ make_input_layer_with_layer_annotations(
+ feature_column_lib.input_layer)):
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
@@ -423,8 +418,8 @@ def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
feature_column_lib, 'input_layer',
- make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
- mode)):
+ make_input_layer_with_layer_annotations(
+ feature_column_lib.input_layer)):
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 60e1d85ea9..17ee8c0733 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -112,9 +112,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
dtype = sparse_weights.dtype if sparse_weights is not None else None
if isinstance(embedding_weights, variables.PartitionedVariable):
embedding_weights = list(embedding_weights)
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
[sparse_weights])
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
index 6fdcf78b69..21ad39a6bf 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
@@ -80,8 +80,7 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
interpreter->Invoke();
auto output = interpreter->typed_tensor<float>(2);
- auto output_number_of_pixels =
- wanted_height * wanted_height * wanted_channels;
+ auto output_number_of_pixels = wanted_height * wanted_width * wanted_channels;
for (int i = 0; i < output_number_of_pixels; i++) {
if (s->input_floating)
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index a4267eee4c..279764ce96 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,6 +1,23 @@
# List of Hosted Models
+# AutoML mobile image classification models (Float Models)
+
+Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^
+------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------:
+MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms
+MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms
+MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
+MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms
+MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms
+MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms
+MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms
+MnasNet_1.0_192| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms
+MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
+
+^ Performance numbers are generated on Pixel-1 using single thread large BIG core.
+
+
## Image classification (Float Models)
Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md
index 8cf43496df..9d035a6921 100644
--- a/tensorflow/contrib/lite/g3doc/overview.md
+++ b/tensorflow/contrib/lite/g3doc/overview.md
@@ -25,7 +25,7 @@ models.
TensorFlow Lite defines a new model file format, based on
[FlatBuffers](https://google.github.io/flatbuffers/). FlatBuffers is an
-open-sourced, efficient cross platform serialization library. It is similar to
+efficient open-source cross-platform serialization library. It is similar to
[protocol buffers](https://developers.google.com/protocol-buffers/?hl=en), but
the primary difference is that FlatBuffers does not need a parsing/unpacking
step to a secondary representation before you can access data, often coupled
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index bb1d30b216..5bfa3bd084 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -4661,12 +4661,15 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis for quantized values.
template <typename T, typename U>
-inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
- const int* input_dims, const int input_num_dims,
- T* output_data, int32 output_zero_point, float output_scale,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis, U* temp_sum) {
+inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
+ float input_scale, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ int32 output_zero_point, float output_scale,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ bool compute_sum) {
// Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
@@ -4708,14 +4711,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
if (num_elements_in_axis > 0) {
const float scale = input_scale / output_scale;
- const float bias = -input_zero_point * scale;
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- float float_mean = static_cast<float>(temp_sum[idx]) /
- static_cast<float>(num_elements_in_axis);
-
- // Convert to float value.
- output_data[idx] =
- static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ if (compute_sum) {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) +
+ output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ } else {
+ const float bias = -input_zero_point * scale + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Convert to float value.
+ output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) +
+ output_zero_point;
+ }
}
}
return true;
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index d94d821e87..4732a37a65 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -215,7 +215,7 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
return PrepareSimple(context, node);
}
-TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
// reduce_mean requires a buffer to store intermediate sum result.
@@ -274,7 +274,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
} else {
TF_LITE_ENSURE(
context,
- reference_ops::Mean<>(
+ reference_ops::QuantizedMeanOrSum<>(
GetTensorData<uint8_t>(op_context.input),
op_context.input->params.zero_point,
op_context.input->params.scale, op_context.input->dims->data,
@@ -286,7 +286,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis),
- GetTensorData<int>(temp_sum)));
+ GetTensorData<int>(temp_sum), /*compute_sum=*/false));
}
break;
default:
@@ -416,19 +416,57 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ const auto& input = op_context.input;
+ const auto& output = op_context.output;
+ if (input->type != kTfLiteUInt8 ||
+ (input->params.scale == output->params.scale &&
+ input->params.zero_point == output->params.zero_point)) {
+ return EvalGeneric<kReference, kSum>(context, node);
+ } else {
+ // Rescaling 8bit reduce sum.
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
+ }
+
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::QuantizedMeanOrSum<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point, op_context.input->params.scale,
+ op_context.input->dims->data, op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale, op_context.output->dims->data,
+ op_context.output->dims->size, GetTensorData<int>(op_context.axis),
+ num_axis, op_context.params->keep_dims,
+ GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+ GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
+ }
+
+ return kTfLiteOk;
+}
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareMean,
+ reduce::PrepareMeanOrSum,
reduce::EvalMean<reduce::kReference>};
return &r;
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {
- reduce::Init, reduce::Free, reduce::PrepareSimple,
- reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMeanOrSum, reduce::EvalSum};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 6d289b14d8..fb2ec58ab2 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -488,6 +488,18 @@ TEST(ConstUint8SumOpTest, NotKeepDims) {
ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
}
+TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) {
+ float kQuantizedTolerance = GetTolerance(0.0, 2.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {1.2, 1.2}, kQuantizedTolerance)));
+}
+
TEST(ConstUint8SumOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
index fcce52a07a..a5621b44cd 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -66,10 +66,11 @@ class LossScaleOptimizer(optimizer.Optimizer):
# Choose a loss scale manager which decides how to pick the right loss scale
# throughout the training process.
- loss_scale_manger = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
+ loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
# Wraps the original optimizer in a LossScaleOptimizer.
- loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager)
+ loss_scale_optimizer =
+ tf.contrib.mixed_precision.LossScaleOptimizer(opt, loss_scale_manager)
# Call minimize() on the loss scale optimizer.
train_op = loss_scale_optimizer.minimize(loss)
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index 6a7f5efecd..b9967fe76d 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -136,8 +136,8 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync(
MPIRendezvousMgr* mgr =
reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
- mgr->QueueRequest(parsed.FullKey().ToString(), step_id_,
- std::move(request_call), rendezvous_call);
+ mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call),
+ rendezvous_call);
}
MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
@@ -258,7 +258,7 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
std::function<MPISendTensorCall*()> res = std::bind(
send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
- SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res));
+ SendQueueEntry req(string(parsed.FullKey()), std::move(res));
this->QueueSendRequest(req);
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index 5596601ddb..90140fcab3 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -71,7 +71,7 @@ class MPISendTensorCall {
void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id,
const bool is_dead) {
- mRes_.set_key(parsed.FullKey().ToString());
+ mRes_.set_key(string(parsed.FullKey()));
mRes_.set_step_id(step_id);
mRes_.mutable_response()->set_is_dead(is_dead);
mRes_.mutable_response()->set_send_start_micros(
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index c3db71359c..efaf63086f 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import copy
from tensorflow.contrib.recurrent.python.ops import recurrent
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -62,7 +61,7 @@ class _FunctionalRnnCell(object):
assert initial_state is not None
# TODO(drpng): Dtype needs to be configurable.
- input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+ input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
# See _index.
like_inputs_t = nest.map_structure(
lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
@@ -144,7 +143,10 @@ class _FunctionalRnnCell(object):
@property
def extended_initial_state(self):
if self._prepend_output:
- return [array_ops.zeros(self._output_shape), self._state_template]
+ return [array_ops.zeros(
+ self._output_shape,
+ dtype=_GetDTypesFromStructure(self._state_template)[0]),
+ self._state_template]
else:
# The base case, where the output is just the hidden state.
return self._state_template
@@ -185,7 +187,7 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
lengths = array_ops.tile(
array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
is_less = math_ops.cast(
- math_ops.less(output_time, lengths), dtype=dtypes.float32)
+ math_ops.less(output_time, lengths), dtype=tf_output.dtype)
keep_mask = array_ops.tile(
array_ops.expand_dims(is_less, -1),
[1, 1, vector_size])
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index 2b6a2b2f3c..7f0b3255ed 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -32,7 +32,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":projector",
- ":trace",
],
)
@@ -60,33 +59,3 @@ py_test(
"//tensorflow/python:summary",
],
)
-
-# API methods and protos in `tf.contrib.tensorboard.plugins.trace` package.
-py_library(
- name = "trace",
- srcs = glob(
- ["plugins/trace/**/*.py"],
- exclude = ["**/*test*"],
- ),
- srcs_version = "PY2AND3",
- deps = [
- ":protos_all_py",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:lib",
- "//tensorflow/python:platform",
- ],
-)
-
-py_test(
- name = "trace_test",
- size = "small",
- srcs = ["plugins/trace/trace_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":trace",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/contrib/tensorboard/plugins/__init__.py b/tensorflow/contrib/tensorboard/plugins/__init__.py
index 41aa77910c..4ba469eb52 100644
--- a/tensorflow/contrib/tensorboard/plugins/__init__.py
+++ b/tensorflow/contrib/tensorboard/plugins/__init__.py
@@ -20,4 +20,4 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.tensorboard.plugins
from tensorflow.contrib.tensorboard.plugins import projector
-from tensorflow.contrib.tensorboard.plugins import trace
+
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/__init__.py b/tensorflow/contrib/tensorboard/plugins/trace/__init__.py
deleted file mode 100644
index 2c99f4077e..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Public API for the Trace plugin."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=wildcard-import
-from tensorflow.contrib.tensorboard.plugins.trace.trace import *
-from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import *
-# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py
deleted file mode 100644
index 07e5316b8b..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Stores debugging information regarding TensorFlow model."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import parser
-import re
-import token
-
-from google.protobuf import json_format
-
-from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import TraceInfo
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
-
-# List of regex patterns that match files in the core tensorflow library.
-TF_LIB_REGEX_FPATHS = [os.sep + os.path.join('tensorflow', 'python')]
-
-LEFT_TOKENS = [token.LPAR, token.LSQB, token.LBRACE]
-RIGHT_TOKENS = [token.RPAR, token.RSQB, token.RBRACE]
-TOKENS = LEFT_TOKENS + RIGHT_TOKENS
-
-
-def store_trace_info(output_file_path,
- graph=None,
- ignore_regex_fpaths=None):
- """Collects and stores trace information for a TensorFlow model.
-
- The output proto is stored in json format.
-
- Args:
- output_file_path: The path where to store the output proto.
- graph: Optional. The data flow graph. Defaults to `tf.get_default_graph()`.
- ignore_regex_fpaths: Optional. Files whose path matches any of the regexes
- in this list will be ignored. Defaults to patterns that match the core
- tensorflow python library.
- """
- graph = graph or ops.get_default_graph()
-
- if not ignore_regex_fpaths:
- ignore_regex_fpaths = TF_LIB_REGEX_FPATHS
-
- trace_info = TraceInfo()
- # Extract trace information for every op in the graph.
- source_fpaths = set()
- for op in graph.get_operations():
- op_info = trace_info.ops.add()
- op_info.name = op.name
- op_info.op_type = op.type
- op_info.device = op.device
- for trace in op.traceback:
- fname, lineno, _, _ = trace
- # Ignore traces in specified file paths.
- if os.path.isabs(fname) and not _ignore_file_path(fname,
- ignore_regex_fpaths):
- line_trace = op_info.traceback.add()
- line_trace.file_path = fname
- line_trace.line_number = lineno
- source_fpaths.add(fname)
- _add_data_from_tensors(op.inputs, op_info.inputs)
- _add_data_from_tensors(op.outputs, op_info.outputs)
-
- # Read the source files involved in the graph construction.
- for fpath in source_fpaths:
- file_info = trace_info.files.add()
-
- with gfile.Open(fpath, 'r') as f:
- source = f.read()
-
- file_info.file_path = fpath
- file_info.source_code = source
-
- line2start = find_multiline_statements(source)
-
- for key, value in line2start.items():
- file_info.multiline_statements[key] = value
-
- # Make sure the directory for the output file exists.
- output_file_path = os.path.expanduser(output_file_path)
- output_dir = os.path.dirname(output_file_path)
- if not gfile.Exists(output_dir):
- gfile.MakeDirs(output_dir)
-
- # Store the debug information.
- with gfile.Open(output_file_path, 'w') as f:
- f.write(json_format.MessageToJson(trace_info))
-
-
-def find_multiline_statements(source):
- """Parses the python source and finds multiline statements.
-
- Based on counting the number of open and closed parenthesis on each line.
-
- Args:
- source: The source code string.
-
- Returns:
- A dict that maps a line index A to a line index B, where A is the end of a
- multiline statement and B is the start. Line indexing is 0-based.
- """
- # Get the AST.
- tree = parser.suite(source)
- line2paren_count = [0] * (source.count('\n') + 1)
- _count_brackets_braces_parenthesis(tree.totuple(True), line2paren_count)
-
- line2start = {}
- for end in range(len(line2paren_count)):
- if line2paren_count[end] >= 0:
- # This is not the end of a multiline statement.
- continue
- cumulative_paren_count = 0
- for start in range(end, -1, -1):
- cumulative_paren_count += line2paren_count[start]
- if cumulative_paren_count == 0:
- line2start[end] = start
- break
- return line2start
-
-
-def _add_data_from_tensors(tensors, info):
- for t in tensors:
- tensor_info = info.add()
-
- shape = t.get_shape()
- if shape.ndims:
- shape = [(-1 if s is None else s) for s in shape.as_list()]
- tensor_info.shape.extend(shape)
- tensor_info.dtype = t.dtype.name
- tensor_info.num_bytes_per_elem = t.dtype.size
-
- for c in t.consumers():
- tensor_info.consumers.append(c.name)
-
-
-def _ignore_file_path(fname, ignore_regex_fpaths):
- for regex_pattern in ignore_regex_fpaths:
- if re.search(regex_pattern, fname):
- return True
- return False
-
-
-def _count_brackets_braces_parenthesis(node, line2par):
- if isinstance(node[1], tuple):
- for child in node[1:]:
- _count_brackets_braces_parenthesis(child, line2par)
- else:
- tok = node[0]
- if tok in TOKENS:
- lineno = node[2]
- line2par[lineno - 1] += (1 if tok in LEFT_TOKENS else -1)
- return line2par
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
deleted file mode 100644
index 9f20becb0f..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-syntax = "proto3";
-
-package tensorflow.contrib.tensorboard;
-
-message TraceInfo {
- repeated OpInfo ops = 1;
- repeated FileInfo files = 2;
-}
-
-message OpInfo {
- string name = 1;
- string op_type = 2;
- string device = 3;
- repeated LineTrace traceback = 4;
- repeated TensorInfo inputs = 5;
- repeated TensorInfo outputs = 6;
-}
-
-message LineTrace {
- // Absolute file path.
- string file_path = 1;
- // 1-based line number.
- uint32 line_number = 2;
-}
-
-message TensorInfo {
- // Size of the tensor for each dimension. Value of -1 denotes "unknown"
- // size for that dimension.
- repeated int32 shape = 1;
- // The data type of the tensor.
- string dtype = 2;
- // Number of bytes per element in the tensor.
- uint32 num_bytes_per_elem = 3;
- // List of operation names that consume this tensor.
- repeated string consumers = 4;
-}
-
-message FileInfo {
- // Absolute file path to the source code.
- string file_path = 1;
- string source_code = 2;
- // Map from end of statement to start of statement. End and start are 0-based
- // line indexes.
- map<uint32, uint32> multiline_statements = 3;
-}
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
deleted file mode 100644
index d580f04c5f..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tensorflow.contrib.tensorboard.plugins.trace package."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tempfile
-
-from google.protobuf import json_format
-
-from tensorflow.contrib.tensorboard.plugins import trace
-from tensorflow.python.framework import constant_op
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-
-
-class TraceTest(test.TestCase):
-
- def setUp(self):
- self._temp_dir = tempfile.mkdtemp()
- self._temp_trace_json = self._temp_dir + 'trace.json'
-
- def tearDown(self):
- gfile.DeleteRecursively(self._temp_dir)
-
- def testEmptyGraph(self):
- trace_info = self._store_and_read_trace_info()
- self.assertEqual(len(trace_info.ops), 0)
-
- def testHasSourceCodeOfThisFile(self):
- constant_op.constant(0)
- trace_info = self._store_and_read_trace_info()
-
- self.assertTrue(trace_info.files)
- for file_info in trace_info.files:
- if file_info.file_path.endswith('trace_test.py'):
- return
- self.fail('trace_test file not found in the trace info json')
-
- def testHasTheConstantOp(self):
- constant_op.constant(0)
- trace_info = self._store_and_read_trace_info()
-
- self.assertTrue(trace_info.ops)
-
- for op in trace_info.ops:
- if op.op_type == 'Const':
- return
- self.fail('Could not find operation of type `Const` in the graph')
-
- def testMultilineStatements(self):
- source = """def test():
- a(4,
- 3,
- 1)
-
- b(3, 4, 5)
-
- c((4, 3),
- (),
- )
- """
- line2start = trace.find_multiline_statements(source)
-
- self.assertEqual(line2start[3], 1)
- self.assertEqual(line2start[9], 7)
- self.assertEqual(len(line2start), 2)
-
- def _store_and_read_trace_info(self):
- trace.store_trace_info(self._temp_trace_json)
- trace_info = trace.TraceInfo()
-
- with gfile.Open(self._temp_trace_json) as f:
- text = f.read()
- json_format.Parse(text, trace_info)
-
- return trace_info
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 1d27fffc62..9bbe87e301 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -191,6 +191,43 @@ class ARModel(model.TimeSeriesModel):
Note that this class can also be used to regress against time only by setting
the input_window_size to zero.
+
+ Each periodicity in the `periodicities` arg is divided by the
+ `num_time_buckets` into time buckets that are represented as features added
+ to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_time_buckets` is how often
+ the data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the time bucket it falls under. If it doesn't fall
+ under a feature's associated time bucket, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_time_buckets` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd time bucket for periodicity 9 (2nd period is 9-18 and 3rd
+ time bucket is 15-18)
+ - It's in the 2nd time bucket for periodicity 12 (2nd period is 12-24 and
+ 2nd time bucket is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timebucket#), feature value
+ P9_T1, 0 # not in first time bucket
+ P9_T2, 0 # not in second time bucket
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd time bucket
+ P12_T1, 0 # not in first time bucket
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd time bucket
+ P12_T3, 0 # not in third time bucket
"""
SQUARED_LOSS = "squared_loss"
NORMAL_LIKELIHOOD_LOSS = "normal_likelihood_loss"
@@ -208,7 +245,9 @@ class ARModel(model.TimeSeriesModel):
Args:
periodicities: periodicities of the input data, in the same units as the
- time feature. Note this can be a single value or a list of values for
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
multiple periodicities.
input_window_size: Number of past time steps of data to look at when doing
the regression.
@@ -218,21 +257,18 @@ class ARModel(model.TimeSeriesModel):
prediction_model_factory: A callable taking arguments `num_features`,
`input_window_size`, and `output_window_size` and returning a
`tf.keras.Model`. The `Model`'s `call()` takes two arguments: an input
- window and an output window, and returns a dictionary of
- predictions. See `FlatPredictionModel` for an example. Example usage:
+ window and an output window, and returns a dictionary of predictions.
+ See `FlatPredictionModel` for an example. Example usage:
- ```python
- model = ar_model.ARModel(
- periodicities=2, num_features=3,
- prediction_model_factory=functools.partial(
- FlatPredictionModel,
- hidden_layer_sizes=[10, 10]))
- ```
+ ```python model = ar_model.ARModel( periodicities=2, num_features=3,
+ prediction_model_factory=functools.partial( FlatPredictionModel,
+ hidden_layer_sizes=[10, 10])) ```
The default model computes predictions as a linear function of flattened
input and output windows.
num_time_buckets: Number of buckets into which to divide (time %
- periodicity) for generating time based features.
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
loss: Loss function to use for training. Currently supported values are
SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
@@ -240,10 +276,9 @@ class ARModel(model.TimeSeriesModel):
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
exogenous_feature_columns: A list of `tf.feature_column`s (for example
- `tf.feature_column.embedding_column`) corresponding to exogenous
- features which provide extra information to the model but are not part
- of the series to be predicted. Passed to
- `tf.feature_column.input_layer`.
+ `tf.feature_column.embedding_column`) corresponding to
+ features which provide extra information to the model but are not part
+ of the series to be predicted.
"""
self._model_factory = prediction_model_factory
self.input_window_size = input_window_size
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 0ddc4b4144..af68aa03cf 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models import s
from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filtering_postprocessor import StateInterpolatingAnomalyDetector
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
@@ -386,6 +387,162 @@ class ARRegressor(TimeSeriesRegressor):
config=config)
+# TODO(b/113684821): Add detailed documentation on what the input_fn should do.
+# Add an example of making and returning a Dataset object. Determine if
+# endogenous features can be passed in as FeatureColumns. Move ARModel's loss
+# functions into a more general location.
+class LSTMAutoRegressor(TimeSeriesRegressor):
+ """An Estimator for an LSTM autoregressive model.
+
+ LSTMAutoRegressor is a window-based model, inputting fixed windows of length
+ `input_window_size` and outputting fixed windows of length
+ `output_window_size`. These two parameters must add up to the window_size
+ of data returned by the `input_fn`.
+
+ Each periodicity in the `periodicities` arg is divided by the `num_timesteps`
+ into timesteps that are represented as time features added to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_timesteps` is how often the
+ data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the timestep it falls under. If it doesn't fall
+ under a feature's associated timestep, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_timesteps` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd timestep for periodicity 9 (2nd period is 9-18 and 3rd
+ timestep is 15-18)
+ - It's in the 2nd timestep for periodicity 12 (2nd period is 12-24 and
+ 2nd timestep is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timestep#), feature value
+ P9_T1, 0 # not in first timestep
+ P9_T2, 0 # not in second timestep
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd timestep
+ P12_T1, 0 # not in first timestep
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd timestep
+ P12_T3, 0 # not in third timestep
+
+ Example Code:
+
+ ```python
+ extra_feature_columns = (
+ feature_column.numeric_column("exogenous_variable"),
+ )
+
+ estimator = LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=5,
+ model_dir="/path/to/model/dir",
+ num_features=1,
+ extra_feature_columns=extra_feature_columns,
+ num_timesteps=50,
+ num_units=10,
+ optimizer=tf.train.ProximalAdagradOptimizer(...))
+
+ # Input builders
+ def input_fn_train():
+ return {
+ "times": tf.range(15)[None, :],
+ "values": tf.random_normal(shape=[1, 15, 1])
+ }
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval():
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+
+ def input_fn_predict():
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+ """
+
+ def __init__(self,
+ periodicities,
+ input_window_size,
+ output_window_size,
+ model_dir=None,
+ num_features=1,
+ extra_feature_columns=None,
+ num_timesteps=10,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ num_units=128,
+ optimizer="Adam",
+ config=None):
+ """Initialize the Estimator.
+
+ Args:
+ periodicities: periodicities of the input data, in the same units as the
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
+ multiple periodicities.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting this value to > 1 empirically seems to give a better fit.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ num_features: The dimensionality of the time series (default value is
+ one for univariate, more than one for multivariate).
+ extra_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to features which
+ provide extra information to the model but are not part of the series to
+ be predicted.
+ num_timesteps: Number of buckets into which to divide (time %
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
+ loss: Loss function to use for training. Currently supported values are
+ SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
+ NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
+ SQUARED_LOSS, the evaluation loss is reported based on un-scaled
+ observations and predictions, while the training loss is computed on
+ normalized data.
+ num_units: The size of the hidden state in the encoder and decoder LSTM
+ cells.
+ optimizer: string, `tf.train.Optimizer` object, or callable that defines
+ the optimizer algorithm to use for training. Defaults to the Adam
+ optimizer with a learning rate of 0.01.
+ config: Optional `estimator.RunConfig` object to configure the runtime
+ settings.
+ """
+ optimizer = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=0.01)
+ model = ar_model.ARModel(
+ periodicities=periodicities,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ num_features=num_features,
+ exogenous_feature_columns=extra_feature_columns,
+ num_time_buckets=num_timesteps,
+ loss=loss,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel, num_units=num_units))
+ state_manager = state_management.FilteringOnlyStateManager()
+ super(LSTMAutoRegressor, self).__init__(
+ model=model,
+ state_manager=state_manager,
+ optimizer=optimizer,
+ model_dir=model_dir,
+ config=config,
+ head_type=ts_head_lib.OneShotPredictionHead)
+
+
class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 83260fc59a..6ec7184c68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -226,5 +226,40 @@ class TimeSeriesRegressorTest(test.TestCase):
input_pipeline.NumpyReader(numpy_data)),
steps=1)
+ def test_ar_lstm_regressor(self):
+ dtype = dtypes.float32
+ model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ estimator = estimators.LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=6,
+ model_dir=model_dir,
+ num_features=1,
+ extra_feature_columns=exogenous_feature_columns,
+ num_units=10,
+ config=_SeedRunConfig())
+ times = numpy.arange(20, dtype=numpy.int64)
+ values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: times,
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
+ batch_size=16, window_size=16)
+ eval_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
+ batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=1)
+ evaluation = estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ self.assertAllEqual(evaluation["loss"], evaluation["average_loss"])
+ self.assertAllEqual([], evaluation["loss"].shape)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index 61469686e4..5b72b1604a 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -77,7 +77,7 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
}
namespace {
-std::once_call reg_mem_visitors_call;
+std::once_flag reg_mem_visitors_call;
} // namespace
Status VerbsServer::Init(ServiceInitFunction service_func,
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 85b6d4ff68..d914fdb96c 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2501,7 +2501,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
- includes = ["../../external/com_google_absl"],
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
deps = [
":lib",
":lib_internal",
@@ -2587,11 +2592,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_headers_lib",
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
extra_deps = [
- # ABSL headers get dropped, so we add them back here.
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
- includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2601,7 +2607,12 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
- includes = ["../../external/com_google_absl"],
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
diff --git a/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
new file mode 100644
index 0000000000..3c8a455983
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
@@ -0,0 +1,49 @@
+op {
+ graph_op_name: "ExtractVolumePatches"
+ in_arg {
+ name: "input"
+ description: <<END
+5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`.
+END
+ }
+ out_arg {
+ name: "patches"
+ description: <<END
+5-D Tensor with shape `[batch, out_planes, out_rows, out_cols,
+ksize_planes * ksize_rows * ksize_cols * depth]` containing patches
+with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized
+in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols`
+are the dimensions of the output patches.
+END
+ }
+ attr {
+ name: "ksizes"
+ description: <<END
+The size of the sliding window for each dimension of `input`.
+END
+ }
+ attr {
+ name: "strides"
+ description: <<END
+1-D of length 5. How far the centers of two consecutive patches are in
+`input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`.
+END
+ }
+ attr {
+ name: "padding"
+ description: <<END
+The type of padding algorithm to use.
+
+We specify the size-related attributes as:
+
+```python
+ ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
+ strides = [1, stride_planes, strides_rows, strides_cols, 1]
+```
+END
+ }
+ summary: <<END
+Extract `patches` from `input` and put them in the "depth" output
+dimension. 3D extension of `extract_image_patches`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
new file mode 100644
index 0000000000..4b0a5d8f65
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "MultiDeviceIterator"
+ out_arg {
+ name: "handle"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "devices"
+ description: <<END
+A list of devices the iterator works across.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name
+across multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Creates a MultiDeviceIterator resource."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
new file mode 100644
index 0000000000..adaacd8ab7
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "MultiDeviceIteratorFromStringHandle"
+ in_arg {
+ name: "string_handle"
+ description: <<END
+String representing the resource.
+END
+ }
+ out_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Generates a MultiDeviceIterator resource from its provided string handle."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
new file mode 100644
index 0000000000..f9be9188cc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
@@ -0,0 +1,41 @@
+op {
+ graph_op_name: "MultiDeviceIteratorGetNextFromShard"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ in_arg {
+ name: "shard_num"
+ description: <<END
+Integer representing which shard to fetch data for.
+END
+ }
+ in_arg {
+ name: "incarnation_id"
+ description: <<END
+Which incarnation of the MultiDeviceIterator is running.
+END
+ }
+ out_arg {
+ name: "components"
+ description: <<END
+Result of the get_next on the dataset.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Gets next element for the provided shard number."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
new file mode 100644
index 0000000000..6b54fa1307
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "MultiDeviceIteratorInit"
+ in_arg {
+ name: "dataset"
+ description: <<END
+Dataset to be iterated upon.
+END
+ }
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIteratorResource.
+END
+ }
+ in_arg {
+ name: "max_buffer_size"
+ description: <<END
+The maximum size of the host side per device buffer to keep.
+END
+ }
+ out_arg {
+ name: "incarnation_id"
+ description: <<END
+An int64 indicating which incarnation of the MultiDeviceIterator
+is running.
+END
+ }
+ summary: "Initializes the multi device iterator with the given dataset."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
new file mode 100644
index 0000000000..1f1fdf99b4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
@@ -0,0 +1,17 @@
+op {
+ graph_op_name: "MultiDeviceIteratorToStringHandle"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ out_arg {
+ name: "string_handle"
+ description: <<END
+A string representing the resource.
+END
+ }
+ summary: "Produces a string handle for the given MultiDeviceIterator."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index cf3d1f0b79..d800a86199 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -347,7 +347,12 @@ namespace {
static Status WrappedTensorDeviceCopy(
const Tensor& from, Tensor* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
- if (DMAHelper::CanUseDMA(&from)) {
+ if (from.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ } else if (DMAHelper::CanUseDMA(&from)) {
TF_RETURN_IF_ERROR(copy(from, to));
} else {
*to = from;
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 0b096a14a3..2ed4f69f90 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_min_graph_nodes(-1);
+ options.config.mutable_graph_options()
+ ->mutable_rewrite_options()
+ ->set_pin_to_host_optimization(RewriterConfig::OFF);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 98719542c0..7cef34ac52 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -143,6 +143,8 @@ struct NodeItem {
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
bool is_merge : 1; // True iff IsMerge(node)
bool is_enter : 1; // True iff IsEnter(node)
+ bool is_constant_enter : 1; // True iff IsEnter(node) and
+ // node->GetAttr("is_constant") == true.
bool is_exit : 1; // True iff IsExit(node)
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
bool is_sink : 1; // True iff IsSink(node)
@@ -626,6 +628,14 @@ Status ExecutorImpl::Initialize() {
item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
item->is_merge = IsMerge(n);
item->is_enter = IsEnter(n);
+ if (item->is_enter) {
+ bool is_constant_enter;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
+ item->is_constant_enter = is_constant_enter;
+ } else {
+ item->is_constant_enter = false;
+ }
item->is_exit = IsExit(n);
item->is_control_trigger = IsControlTrigger(n);
item->is_sink = IsSink(n);
@@ -1988,15 +1998,12 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
- bool is_constant;
- const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
- DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
{
const NodeItem* item = impl_->gview_.node(node->id());
mutex_lock l(output_frame->mu);
- if (is_constant) {
+ if (item->is_constant_enter) {
// Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready);
} else {
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 7171ae059b..3b1d7d8347 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->set_shape_optimization(RewriterConfig::OFF);
rewriter_config->set_remapping(RewriterConfig::OFF);
+ rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index a6b6b6f8b2..b8d8243174 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -14,11 +14,41 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
+ ++output_arg_id) {
+ if (port_id < 0) {
+ return -1;
+ } else if (port_id == 0) {
+ return output_arg_id;
+ }
+
+ const auto& output_arg = op.output_arg(output_arg_id);
+ if (!output_arg.number_attr().empty()) {
+ const int n = node.attr().at(output_arg.number_attr()).i();
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ }
+ if (port_id < n) {
+ return output_arg_id;
+ }
+ port_id -= n;
+ } else {
+ --port_id;
+ }
+ }
+
+ return -1;
+}
+
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..ec946ca3b5 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 958eb921fb..30512d9d47 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -25,6 +25,60 @@ namespace {
class GraphViewTest : public ::testing::Test {};
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& a_node_def = *graph_view.GetNode("a");
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+ const OpDef* a_op_def = nullptr;
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+ for (int num_splits : {1, 2}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+ ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+ int arg_id = -1;
+ if (port_id < num_splits * 3) {
+ arg_id = port_id / num_splits;
+ }
+ EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+ }
+ }
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 261dee4382..960d1addb3 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -518,6 +518,7 @@ cc_library(
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ ":pin_to_host_optimizer",
":remapper",
":scoped_allocator_optimizer",
":shape_optimizer",
@@ -883,3 +884,41 @@ tf_cc_test(
"//tensorflow/core/grappler/utils:grappler_test",
],
)
+
+cc_library(
+ name = "pin_to_host_optimizer",
+ srcs = ["pin_to_host_optimizer.cc"],
+ hdrs = [
+ "pin_to_host_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "pin_to_host_optimizer_test",
+ srcs = ["pin_to_host_optimizer_test.cc"],
+ deps = [
+ ":pin_to_host_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index b3f60e34f9..2dd9ee822e 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper(
} // namespace
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op("Placeholder");
+ SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
+ (*node.mutable_attr())["dtype"].set_type(dtype);
+ TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
+ shape->set_unknown_rank(false);
+ return graph->AddNode(std::move(node));
+}
+
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 1652afcd9e..b117482db2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,6 +37,9 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
+// Adds Placeholder node for given type.
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
+
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index a26f1000a3..cf5a19bab1 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -33,25 +33,27 @@ namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
- const NodeDef& count_node = *graph.GetNode(take_node.input(1));
+ const auto& count_node = *graph.GetNode(take_node.input(1));
+ if (count_node.op() != "Const") return false;
// We are looking only for 'take' with negative count.
return count_node.attr().at("value").tensor().int64_val(0) < 0;
}
+bool IsConstNodeWithValue(const NodeDef& node, int value) {
+ if (node.op() != "Const") return false;
+ return node.attr().at("value").tensor().int64_val(0) == value;
+}
+
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
if (skip_node.op() != "SkipDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(skip_node.input(1));
// We are looking only for skip(0) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 0;
+ return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
}
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(repeat_node.input(1));
// We are looking only for repeat(1) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 1;
+ return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
}
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
index f445e75aa7..be1a66df75 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
@@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
GetCommonAttributes(), graph);
}
+NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
+ MutableGraphView *graph) {
+ NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
+ return graph_utils::AddNode("", node_type,
+ {std::move(input_node), node_count->name()},
+ GetCommonAttributes(), graph);
+}
+
NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
NodeDef *node_filename =
graph_utils::AddScalarConstNode<StringPiece>("", graph);
@@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(*kTakeNode, *kSkipNode,
*kRepeatNode)));
+struct NoOpPlaceholdersTest
+ : ::testing::TestWithParam<std::tuple<string, string>> {};
+
+TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
+ GrapplerItem item;
+ MutableGraphView graph(&item.graph);
+
+ static_assert(std::tuple_size<NodesTypes>::value == 2,
+ "Make sure to include everything in the test");
+ const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
+ std::get<1>(GetParam())};
+ NodeDef *range_node = MakeRangeNode(&graph);
+ std::vector<string> nodes_to_keep;
+ nodes_to_keep.reserve(noop_nodes.size());
+ NodeDef *previous = range_node;
+
+ for (const auto &noop_node : noop_nodes) {
+ NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
+ nodes_to_keep.push_back(node->name());
+ previous = node;
+ }
+
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ for (const auto &noop_node_name : nodes_to_keep)
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DoNotRemovePlaceholders, NoOpPlaceholdersTest,
+ ::testing::Combine(
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"),
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset")));
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 4b0cbfaa82..3992b45c64 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -105,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
+ MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -133,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
+ if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+ }
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -215,6 +220,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers(
TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
optimizers->push_back(std::move(custom_optimizer));
} else {
+ // If there are no custom optimizers with given name, try to initalize a
+ // default optimizer. This way, custom configurable optimizers can be
+ // mixed with default optimizers in any order.
+ auto optimizer = MakeNewOptimizer(optimizer_config.name());
+ if (optimizer) {
+ VLOG(2) << "Registered default graph optimizer: "
+ << optimizer_config.name();
+ optimizers->push_back(std::move(optimizer));
+ continue;
+ }
VLOG(2) << "Can't register an optimizer by name: "
<< optimizer_config.name();
}
@@ -468,6 +483,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() == RewriterConfig::ON ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
new file mode 100644
index 0000000000..c8f9311b2e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -0,0 +1,226 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// Find KernelDef for `node`.
+Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
+ // Try find KernelDef for node.device, else GPU or CPU.
+ for (const DeviceType& device :
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
+ Status s = FindKernelDef(device, node, kdef, nullptr);
+ if (s.ok()) {
+ return Status::OK();
+ }
+ }
+
+ return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Check if all node's inputs are pinned to CPU memory.
+bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
+ // Loop through all the inputs excluding the controlling nodes.
+ for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
+ // Check if (the fanin) op's device is on CPU.
+ if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Check if (the fanin) op's output port is pinned to HostMemory.
+ const OpDef* fanin_odef = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
+ return false;
+ }
+
+ const int output_arg_id =
+ OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
+ << node.DebugString() << "\n"
+ << fanin_odef->DebugString();
+ return false;
+ }
+
+ const KernelDef* fanin_kdef = nullptr;
+ s = TryFindKernelDef(*fanin.node, &fanin_kdef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
+ return false;
+ }
+
+ bool fanin_pinned = false;
+ for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
+ if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
+ fanin_pinned = true;
+ break;
+ }
+ }
+
+ if (!fanin_pinned) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check if Tensor is integer and small size.
+
+ // Check type to be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
+ const NodeDef& node) {
+ for (const auto& prop : properties.GetInputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+
+ for (const auto& prop : properties.GetOutputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device) {
+ // Force this node onto the CPU.
+ if (device.empty() && has_device_cpu) {
+ return "/device:CPU:0";
+ } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ // Sometimes the cluster can have:
+ // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+ // and we need to handle them properly.
+ for (const auto& device_match :
+ {std::pair<string, string>("GPU", "CPU:0"),
+ std::pair<string, string>("/device", "/device:CPU:0")}) {
+ const string device_host =
+ strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ device_match.second);
+ if (devices.find(device_host) != devices.end()) {
+ return device_host;
+ }
+ }
+ }
+
+ // We couldn't find an appropriate Host device, return original device.
+ return device;
+}
+
+// All the nodes that should be blacklisted and not swapped.
+bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
+} // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ GraphProperties properties(item);
+ bool has_properties = false;
+ GraphView graph(optimized_graph);
+
+ gtl::FlatSet<string> devices;
+ if (cluster) {
+ const std::vector<string> device_names = cluster->GetDeviceNames();
+ devices.insert(device_names.begin(), device_names.end());
+ } else {
+ devices = {"/device:CPU:0"};
+ }
+
+ const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+ // Topologically sort the graph, so that we traverse the nodes in order. This
+ // will help us discover producer->consumer chains of Host ops.
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+ for (auto& node : *optimized_graph->mutable_node()) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Skip these node types.
+ if (internal::IsBlacklisted(node)) {
+ continue;
+ }
+
+ // Check the node can be run on CPU.
+ Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
+ if (!s.ok()) {
+ continue;
+ }
+
+ // Check all input's are pinned to CPU.
+ if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
+ continue;
+ }
+
+ if (!has_properties) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ has_properties = true;
+ }
+
+ // Check all inputs and outputs are integers and small.
+ if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ continue;
+ }
+
+ // Try and swap the device to Host.
+ node.set_device(
+ internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+ }
+ return Status::OK();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
new file mode 100644
index 0000000000..d557a03463
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Try and find an appropriate Host device in `devices` given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device);
+} // end namespace internal
+
+// Optimize TensorFlow ops that should be swapped into the CPU to avoid
+// excessive cpu<->gpu memcpy/sync.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of
+// gpu->gpu->gpu may have been better/faster. We should probably fix this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+ PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+ explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+
+ ~PinToHostOptimizer() override {}
+
+ string name() const override { return "pin_to_host_optimizer"; };
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ RewriterConfig::Toggle opt_level_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
new file mode 100644
index 0000000000..339ddfd1b5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+ gtl::FlatSet<string> devices = {};
+ EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+ devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+ "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+ "/device:CPU:0");
+
+ devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_CPU:0");
+
+ devices = {"/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_GPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // Reverse the graph, and hence rely on the optimizer to sort it.
+ std::reverse(item.graph.mutable_node()->begin(),
+ item.graph.mutable_node()->end());
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GrapplerItem item;
+ item.fetch = {"a", "b"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ ++found;
+ }
+ EXPECT_EQ(found, 2);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 910b0acaef..6266733f3e 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
// optimizations interfering in the comparison.
RewriterConfig* cfg =
options_.config.mutable_graph_options()->mutable_rewrite_options();
- cfg->set_constant_folding(RewriterConfig::OFF);
+ // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+ // off.
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+ cfg->set_constant_folding(RewriterConfig::OFF);
+ cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_dependency_optimization(RewriterConfig::OFF);
- cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
- cfg->set_debug_stripper(RewriterConfig::OFF);
+ cfg->set_loop_optimization(RewriterConfig::OFF);
+ cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 08245e6ea0..ab69925d04 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -217,6 +217,19 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "extract_volume_patches_op",
+ prefix = "extract_volume_patches_op",
+ deps = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":ops_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
@@ -622,6 +635,7 @@ cc_library(
":diag_op",
":edit_distance_op",
":extract_image_patches_op",
+ ":extract_volume_patches_op",
":gather_nd_op",
":gather_op",
":guarantee_const_op",
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index b3c359010d..87efdff789 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -628,6 +628,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.cc"],
+ deps = [
+ ":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
@@ -722,6 +736,7 @@ tf_kernel_library(
":map_dataset_op",
":map_defun_op",
":model_dataset_op",
+ ":multi_device_iterator_ops",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e7ac368ae3..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -44,5 +44,42 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 234856ea39..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -27,6 +27,16 @@ Status MakeIteratorFromInputElement(
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
+// Returns Status::OK() if `expected` and `received` types match,
+// errors::InvalidArgument otherwise.
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received);
+
+// Returns Status::OK() if `expected` and `received` shapes are compatible,
+// errors::InvalidArgument otherwise.
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received);
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 30c6585ba2..c0bc507ec0 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -44,43 +44,6 @@ namespace {
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
} // namespace
class IteratorResource : public ResourceBase {
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
new file mode 100644
index 0000000000..5f143967d9
--- /dev/null
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -0,0 +1,633 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct HostBufferElement {
+ Status status;
+ bool end_of_sequence;
+ std::vector<Tensor> value;
+};
+
+using MultiDeviceIteratorCallback =
+ std::function<void(const HostBufferElement&)>;
+
+class MultiDeviceIterator : public ResourceBase {
+ public:
+ MultiDeviceIterator(const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ const std::vector<string>& devices,
+ std::unique_ptr<FunctionLibraryDefinition> flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ FunctionLibraryRuntime* lib)
+ : output_types_(output_types),
+ output_shapes_(output_shapes),
+ devices_(devices),
+ flib_def_(std::move(flib_def)),
+ pflr_(std::move(pflr)),
+ lib_(lib) {
+ DCHECK(lib_ != nullptr);
+ }
+
+ string DebugString() override {
+ return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
+ " devices");
+ }
+
+ Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
+ int64* incarnation_id) {
+ if (iterator) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, iterator->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
+ }
+
+ mutex_lock l(mu_);
+ if (multi_device_buffer_) {
+ multi_device_buffer_->Reset();
+ }
+
+ ++incarnation_id_;
+ *incarnation_id = incarnation_id_;
+
+ multi_device_buffer_.reset(
+ new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
+ std::move(iterator)));
+ return Status::OK();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ if (lib_ != nullptr) {
+ ctx->set_lib(lib_);
+ }
+ tf_shared_lock l(mu_);
+ multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
+ std::move(callback));
+ }
+
+ const DataTypeVector& output_types() const { return output_types_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ std::shared_ptr<const FunctionLibraryDefinition> function_library() {
+ tf_shared_lock l(mu_);
+ return lib_def_;
+ }
+
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
+ private:
+ // A private class that uses a background thread to keep a per device buffer
+ // full.
+ class MultiDeviceBuffer {
+ public:
+ MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
+ std::unique_ptr<IteratorBase> host_iterator)
+ : buffer_(size),
+ size_(size),
+ max_buffer_size_(max_buffer_size),
+ incarnation_id_(incarnation_id),
+ host_iterator_(std::move(host_iterator)) {}
+
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
+
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (background_thread_finished_) {
+ return;
+ }
+
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
+
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
+ }
+ RunPendingCallbacks();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ HostBufferElement elem;
+ if (incarnation_id_ != incarnation_id) {
+ elem.status = errors::InvalidArgument("Invalid incarnation id");
+ callback(elem);
+ return;
+ }
+
+ bool produced_output = false;
+ {
+ mutex_lock l(mu_);
+ if (cancelled_) {
+ elem.status = errors::Cancelled("Cancelled Multidevice iterator");
+ callback(elem);
+ return;
+ }
+
+ EnsureBackgroundThreadStarted(ctx);
+
+ if (!buffer_[shard_num].data.empty()) {
+ produced_output = true;
+ std::swap(elem, buffer_[shard_num].data.front());
+ buffer_[shard_num].data.pop_front();
+ // Wake up background thread if it is blocked on this element.
+ if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
+ buffer_[shard_num].cond_var.notify_all();
+ }
+ } else {
+ if (background_thread_finished_) {
+ produced_output = true;
+ elem.end_of_sequence = true;
+ } else {
+ buffer_[shard_num].callbacks.push_back(std::move(callback));
+ callback = nullptr;
+ }
+ }
+ }
+
+ if (produced_output) {
+ callback(elem);
+ }
+ }
+
+ private:
+ void EnsureBackgroundThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!background_thread_) {
+ background_thread_.reset(ctx->env()->StartThread(
+ {}, "multi_device_iterator_background_thread",
+ std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
+ this, new IteratorContext(*ctx))));
+ }
+ }
+
+ void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
+ // Run all remaining callbacks.
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
+ std::vector<HostBufferElement> cancellation_elements;
+ {
+ mutex_lock l(mu_);
+
+ for (int i = 0; i < size_; ++i) {
+ while (!buffer_[i].callbacks.empty()) {
+ if (buffer_[i].data.empty()) {
+ HostBufferElement elem;
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ cancellation_elements.push_back(std::move(elem));
+ } else {
+ cancellation_elements.push_back(
+ std::move(buffer_[i].data.front()));
+ buffer_[i].data.pop_front();
+ }
+ cancellation_callbacks.push_back(
+ std::move(buffer_[i].callbacks.front()));
+ buffer_[i].callbacks.pop_front();
+ }
+ }
+ }
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_elements[i]);
+ }
+ }
+
+ void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ int shard_to_fetch = 0;
+ while (true) {
+ HostBufferElement elem;
+ MultiDeviceIteratorCallback callback = nullptr;
+ bool end_of_iterator = false;
+
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].cond_var.wait(l);
+ }
+
+ if (cancelled_) {
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ return;
+ }
+ }
+
+ elem.status =
+ host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
+
+ if (elem.status.ok() && elem.end_of_sequence) {
+ end_of_iterator = true;
+ }
+
+ {
+ mutex_lock l(mu_);
+ // Try to find a callback, else just push stuff into buffer.
+ if (!buffer_[shard_to_fetch].callbacks.empty()) {
+ callback = buffer_[shard_to_fetch].callbacks.front();
+ buffer_[shard_to_fetch].callbacks.pop_front();
+ } else {
+ buffer_[shard_to_fetch].data.push_back(std::move(elem));
+ elem = HostBufferElement();
+ }
+ }
+
+ if (callback) {
+ (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
+ }
+
+ // Finish off the thread if we reach the end of the iterator. Runs
+ // pending callbacks.
+ if (end_of_iterator) {
+ {
+ mutex_lock l(mu_);
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ }
+ RunPendingCallbacks();
+ return;
+ }
+ shard_to_fetch = (shard_to_fetch + 1) % size_;
+ }
+ }
+
+ struct HostBuffer {
+ condition_variable cond_var;
+ std::deque<HostBufferElement> data;
+ std::deque<MultiDeviceIteratorCallback> callbacks;
+ };
+
+ mutex mu_;
+ std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
+ bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
+
+ std::vector<HostBuffer> buffer_;
+
+ const size_t size_;
+ const int64 max_buffer_size_;
+ const int64 incarnation_id_;
+ const std::unique_ptr<IteratorBase> host_iterator_;
+ };
+
+ mutex mu_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<string> devices_;
+ const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
+ std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
+
+ int64 incarnation_id_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
+};
+
+// Just creates a MultiDeviceIterator and returns it.
+class MultiDeviceIteratorHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+ }
+
+ // The resource is deleted from the resource manager only when it is private
+ // to kernel.
+ ~MultiDeviceIteratorHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MultiDeviceIterator>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<MultiDeviceIterator>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MultiDeviceIterator(
+ output_types_, output_shapes_, devices_,
+ std::move(flib_def), std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MultiDeviceIterator>()));
+ }
+
+ private:
+ // During the first Compute(), resource is either created or looked up using
+ // shared_name. In the latter case, the resource found should be verified if
+ // it is compatible with this op's configuration. The verification may fail in
+ // cases such as two graphs asking queues of the same shared name to have
+ // inconsistent capacities.
+ Status VerifyResource(MultiDeviceIterator* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+ string name_;
+ string container_;
+ std::vector<string> devices_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
+ MultiDeviceIteratorHandleOp);
+
+// Calls init on the MultiDeviceIterator.
+class MultiDeviceIteratorInitOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* tensor_max_buffer_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
+ int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
+ core::ScopedUnref unref(resource);
+
+ std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
+ int64 incarnation_id;
+ OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
+ &incarnation_id));
+ Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
+ tensor_incarnation_id.scalar<int64>()() = incarnation_id;
+ OP_REQUIRES_OK(ctx,
+ ctx->set_output("incarnation_id", tensor_incarnation_id));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
+ MultiDeviceIteratorInitOp);
+
+// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
+// TODO(rohanj): Implement using BackgroundWorker that Derek built?
+class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
+ public:
+ explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("multi_device_iterator_get_next_thread_",
+ SanitizeThreadSuffix(name())),
+ 1 /* num_threads */, false /* low_latency_hint */)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor* tensor_shard_num;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
+ int32 shard_num = tensor_shard_num->scalar<int32>()();
+
+ const Tensor* tensor_incarnation_id;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
+ int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
+
+ MultiDeviceIterator* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ thread_pool_->Schedule(std::bind(
+ [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ MultiDeviceIteratorCallback callback = std::bind(
+ [ctx](const HostBufferElement& elem, DoneCallback done) {
+ // iterator->Unref();
+ Status s = elem.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (elem.end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < elem.value.size(); ++i) {
+ ctx->set_output(i, elem.value[i]);
+ }
+ }
+ done();
+ },
+ std::placeholders::_1, std::move(done));
+
+ iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
+ callback);
+ iterator->Unref();
+ },
+ std::move(done)));
+ }
+
+ private:
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
+ MultiDeviceIteratorGetNextFromShardOp);
+
+class MultiDeviceIteratorToStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ resource->Unref();
+
+ Tensor* string_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+ string_handle_t->scalar<string>()() =
+ resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorToStringHandleOp);
+
+class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES(
+ ctx,
+ output_types_.empty() || output_shapes_.empty() ||
+ output_types_.size() == output_shapes_.size(),
+ errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+ "are set, they must have the same length."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& string_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+ errors::InvalidArgument("string_handle must be a scalar"));
+
+ ResourceHandle resource_handle;
+ OP_REQUIRES(
+ ctx,
+ resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+ errors::InvalidArgument(
+ "Could not parse string_handle as a valid ResourceHandle"));
+
+ OP_REQUIRES(
+ ctx, resource_handle.device() == ctx->device()->attributes().name(),
+ errors::InvalidArgument("Attempted create an iterator on device \"",
+ ctx->device()->attributes().name(),
+ "\" from handle defined on device \"",
+ resource_handle.device(), "\""));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
+ core::ScopedUnref unref_iterator(resource);
+ if (!output_types_.empty()) {
+ OP_REQUIRES_OK(ctx,
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ }
+ if (!output_shapes_.empty()) {
+ OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
+ resource->output_shapes()));
+ }
+
+ Tensor* resource_handle_t;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
+ resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorFromStringHandleOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index 346e4ceebd..2ab5c83082 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -213,6 +213,14 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
+ if (t.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ }
+ }
+ for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 37414ddca3..6a9a2accd8 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -113,6 +113,11 @@ class TensorContractionInputMapper<
m_num_patches = tensor.impl().dimensions()[NumDims - 5];
}
+ // Strides for navigating through the single patch.
+ m_patch_plane_stride = m_patch_depth;
+ m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
+ m_patch_col_stride = m_patch_rows * m_patch_row_stride;
+
// Strides for the output tensor.
// IMPORTANT: These strides are used to locate an element in a patch at a
// depth zero (channel), which is not quite the same as "traditional"
@@ -166,6 +171,13 @@ class TensorContractionInputMapper<
m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+ m_fastPatchPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_stride);
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
+
m_fastInputPlaneStride =
internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
m_fastInputRowStride =
@@ -195,6 +207,10 @@ class TensorContractionInputMapper<
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+ m_patch_plane_stride = base_mapper.m_patch_plane_stride;
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_rowStride = base_mapper.m_rowStride;
m_colStride = base_mapper.m_colStride;
m_patchStride = base_mapper.m_patchStride;
@@ -234,6 +250,9 @@ class TensorContractionInputMapper<
m_outputPlanesRows = base_mapper.m_outputPlanesRows;
m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
@@ -305,9 +324,9 @@ class TensorContractionInputMapper<
}
EIGEN_DEVICE_FUNC
- EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
EIGEN_DEVICE_FUNC
- EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; }
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
EIGEN_DEVICE_FUNC
@@ -594,7 +613,12 @@ class TensorContractionInputMapper<
Index m_patch_cols; // number of columns in the patch
Index m_num_patches; // number of patches to extract
- // Strides for the output tensor.
+ // Strides for navigating through the single patch.
+ Index m_patch_plane_stride;
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+
+ // Strides for the output tensor (depth is not the part of the stride).
Index m_rowStride;
Index m_colStride;
Index m_patchStride;
@@ -637,6 +661,10 @@ class TensorContractionInputMapper<
// Fast representation of various divisors.
internal::TensorIntDivisor<Index> m_fastNumPatches;
+ internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
internal::TensorIntDivisor<Index> m_fastInputRowStride;
internal::TensorIntDivisor<Index> m_fastInputColStride;
@@ -750,13 +778,62 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
+ // plane and depth index respectively that fits into the peeled_k elements
+ // starting at m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
+ const Index row) const {
+ const Index max_plane = fastPatchPlaneStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride() -
+ row * patchRowStride());
+ return std::min<Index>(1 + max_plane, patchPlanes());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one, that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input plane
+ // stride. Bunch of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
- return m_base_mapper.m_patch_depth;
+ eigen_assert(m_base_mapper.m_patch_depth ==
+ m_base_mapper.m_planeInputStride &&
+ "Patch depth must be equal to plane input stride.");
+ return m_base_mapper.m_planeInputStride;
}
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchPlanes() const {
- return m_base_mapper.m_patch_planes;
+ eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
+ "Patch planes must be equal to row stride.");
+ return m_base_mapper.m_rowStride;
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchRows() const {
@@ -768,6 +845,36 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ return m_base_mapper.m_patch_row_stride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ return m_base_mapper.m_fastPatchRowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -832,8 +939,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -859,14 +965,14 @@ class TensorContractionSubMapper<
// matrix" constructed from extracted volume patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
-// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
-// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
-// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
-// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
-// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
-// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
-// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
-// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
@@ -910,7 +1016,11 @@ struct gemm_pack_rhs<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
SubMapper;
+
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -919,9 +1029,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -934,81 +1041,58 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
-
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
- const Index patch_planes = rhs.patchPlanes();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
- startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
-
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- Eigen::divup(
- peeled_k - c * patch_rows * patch_planes * patch_depth,
- patch_planes * patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet do not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
- const Index startPlane =
- ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
- const Index max_planes = std::min<Index>(
- Eigen::divup(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth, // row
- patch_depth) +
- startPlane,
- patch_planes);
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
- const bool pad_row0 = dm0.padRow(r);
- const bool pad_row1 = dm1.padRow(r);
- const bool pad_row2 = dm2.padRow(r);
- const bool pad_row3 = dm3.padRow(r);
+ const bool pad_row0 = pad_col0 || dm0.padRow(r);
+ const bool pad_row1 = pad_col1 || dm1.padRow(r);
+ const bool pad_row2 = pad_col2 || dm2.padRow(r);
+ const bool pad_row3 = pad_col3 || dm3.padRow(r);
- for (Index p = startPlane; p < max_planes; ++p) {
- eigen_assert(k < peeled_k);
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
- const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
- const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
- const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
- const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+ const bool pad0 = pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_row3 || dm3.padPlane(p);
const Index idx0 = dm0.baseIndex(p, r, c);
const Index idx1 = dm1.baseIndex(p, r, c);
const Index idx2 = dm2.baseIndex(p, r, c);
const Index idx3 = dm3.baseIndex(p, r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow) && (p == startPlane))
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
? rhs.depthOffset()
: 0;
- const Index max_depth = std::min<Index>(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth - // row
- p * patch_depth + // plane
- startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
-
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -1031,20 +1115,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple planes, rows or columns, so we have to go
+ // though the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = dm0.loadPacketStandard(k);
@@ -1060,7 +1136,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -1079,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1118,6 +1196,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1126,9 +1207,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
@@ -1143,56 +1221,39 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
-
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
- const Index patch_planes = rhs.patchPlanes();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
- startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
-
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- Eigen::divup(
- peeled_k - c * patch_rows * patch_planes * patch_depth,
- patch_planes * patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet do not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
- const Index startPlane =
- ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
- const Index max_planes = std::min<Index>(
- Eigen::divup(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth, // row
- patch_depth) +
- startPlane,
- patch_planes);
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
const bool pad_row0 = dm0.padRow(r);
const bool pad_row1 = dm1.padRow(r);
const bool pad_row2 = dm2.padRow(r);
const bool pad_row3 = dm3.padRow(r);
- for (Index p = startPlane; p < max_planes; ++p) {
- eigen_assert(k < peeled_k);
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
@@ -1204,20 +1265,14 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(p, r, c);
const Index idx3 = dm3.baseIndex(p, r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow) && (p == startPlane))
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
? rhs.depthOffset()
: 0;
- const Index max_depth = std::min<Index>(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth - // row
- p * patch_depth + // plane
- startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
-
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -1242,21 +1297,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
@@ -1275,6 +1318,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -1294,7 +1339,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1333,6 +1378,8 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
Index depth, Index cols, Index stride = 0,
@@ -1340,8 +1387,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1369,7 +1414,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index 8bd362db45..e926d73f87 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -105,12 +105,17 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3];
} else {
- const int NumDims = tensor.impl().dimensions().size();
+ const size_t NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
m_num_patches = tensor.impl().dimensions()[NumDims - 4];
}
+
+ // Strides for navigating through the single patch.
+ m_patch_row_stride = patch_depth;
+ m_patch_col_stride = patch_rows * m_patch_row_stride;
+
m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
m_patch_col_inflate_strides = tensor.impl().colInflateStride();
@@ -139,6 +144,10 @@ class TensorContractionInputMapper<
m_rowPaddingTop = tensor.impl().rowPaddingTop();
m_colPaddingLeft = tensor.impl().colPaddingLeft();
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
m_fastInputRowStride =
internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
m_fastInputColStride =
@@ -154,6 +163,10 @@ class TensorContractionInputMapper<
: m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
@@ -176,6 +189,8 @@ class TensorContractionInputMapper<
m_rowPaddingTop = base_mapper.m_rowPaddingTop;
m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
m_fastNumPatches = base_mapper.m_fastNumPatches;
@@ -450,8 +465,15 @@ class TensorContractionInputMapper<
rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
}
- Index m_patch_cols; // number of colums in the patch
- Index m_num_patches; // number of patches to extract.
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract.
+
+ // Strides for navigating through the single patch.
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
Index m_patch_row_inflate_strides; // the strides for row inflation in the
// image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the
@@ -585,6 +607,40 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
+ // index respectively that fits into the peeled_k elements starting at
+ // m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one, that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input row stride.
+ // Bunch of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
return m_base_mapper.m_rowInputStride;
@@ -599,6 +655,28 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -639,8 +717,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -665,14 +742,14 @@ class TensorContractionSubMapper<
// matrix" constructed from extracted image patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
-// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
-// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
-// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
-// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
-// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
-// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
-// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
-// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
@@ -717,9 +794,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -728,9 +805,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -743,30 +817,27 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows, if we know that a single
+ // packet do not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -777,14 +848,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -806,19 +876,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
@@ -835,6 +895,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -892,9 +954,9 @@ struct gemm_pack_rhs<
Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -903,9 +965,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
@@ -919,30 +978,27 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows if we know that a single
+ // packet do not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -953,14 +1009,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -984,22 +1039,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple rows or columns, so we have to go
+ // though the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -1017,7 +1062,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -1036,7 +1083,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1074,8 +1121,7 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1084,8 +1130,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1113,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc
new file mode 100644
index 0000000000..52cd078a35
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/*
+See extract_image_patches_op* files and docs for extract_image_patches in
+../ops/image_ops.cc.
+
+Rates are not supported as of now, but the comments hint how to edit the code
+when rates are to be added.
+*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+static inline void ParseAttributeVec5(OpKernelConstruction* context,
+ const string& attr_name,
+ std::vector<int32>* attr) {
+ OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
+ OP_REQUIRES(
+ context, (*attr)[0] == 1 && (*attr)[4] == 1,
+ errors::Unimplemented("Only support ", attr_name, " across space."));
+ OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1,
+ errors::OutOfRange(attr_name, " is out of range."));
+}
+
+template <typename Device, typename T>
+class ExtractVolumePatchesOp : public UnaryOp<T> {
+ public:
+ explicit ExtractVolumePatchesOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ ParseAttributeVec5(context, "ksizes", &ksizes_);
+ ParseAttributeVec5(context, "strides", &strides_);
+ // ParseAttributeVec5(context, "rates", &rates_);
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_planes, in_rows, in_cols, channels ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input.shape().DebugString()));
+
+ const int batch = input.dim_size(0);
+ const int in_planes = input.dim_size(1);
+ const int in_rows = input.dim_size(2);
+ const int in_cols = input.dim_size(3);
+ const int depth = input.dim_size(4);
+
+ const int ksize_planes = ksizes_[1];
+ const int ksize_rows = ksizes_[2];
+ const int ksize_cols = ksizes_[3];
+
+ const int stride_planes = strides_[1];
+ const int stride_rows = strides_[2];
+ const int stride_cols = strides_[3];
+
+ /*
+ // TODO(hsgkim): enable rates
+ // Rates are disabled as of now due to Eigen's definitions of
+ // `extract_volume_patch` functions; none of them accept rates
+ // as its argument and rates are fixed to (1, 1, 1, 1, 1). A
+ // workaround has to be found for this.
+ // In order to enable rates, uncomment the following lines and use
+ // ksize_*_eff instead of ksize_* for the second argument of
+ // GetWindowedOutputSize calls.
+
+ const int rate_planes = rates_[1];
+ const int rate_rows = rates_[2];
+ const int rate_cols = rates_[3];
+
+ const int ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ int64 out_planes = 0, out_rows = 0, out_cols = 0;
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_planes, ksize_planes, stride_planes,
+ padding_, &out_planes, &pad_planes));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_rows, ksize_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ const std::vector<int64> out_sizes = {
+ batch, out_planes, out_rows, out_cols,
+ ksize_planes * ksize_rows * ksize_cols * depth};
+ TensorShape out_shape(out_sizes);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ functor::ExtractVolumePatchesForward<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+ ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
+ /* rate_planes, rate_rows, rate_cols, */
+ BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> ksizes_;
+ std::vector<int32> strides_;
+ // std::vector<int32> rates_;
+
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<CPUDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// clang-format off
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
+ int patch_planes, int patch_rows, int patch_cols, \
+ int stride_planes, int stride_rows, int stride_cols, \
+ /* int rate_planes, int rate_rows, int rate_cols, */ \
+ const Eigen::PaddingType& padding, \
+ typename TTypes<T, 5>::Tensor output); \
+ extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h
new file mode 100644
index 0000000000..7e0502b770
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct ExtractVolumePatchesForward {
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int patch_planes, int patch_rows, int patch_cols,
+ int stride_planes, int stride_rows, int stride_cols,
+ /* int rate_planes, int rate_rows, int rate_cols, */
+ const Eigen::PaddingType& padding,
+ typename TTypes<T, 5>::Tensor output) {
+ const int64 N = std::max(input.size(), output.size());
+ if (N <= std::numeric_limits<Index32>::max()) {
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
+ To32Bit(input)
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output_32bit.dimensions());
+ } else {
+ output.device(d) =
+ input
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output.dimensions());
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
diff --git a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
new file mode 100644
index 0000000000..c636493602
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index f93ebea771..e22adcd569 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -108,6 +108,59 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
return Status::OK();
}
+Status RecordReader::GetMetadata(Metadata* md) {
+ if (!md) {
+ return errors::InvalidArgument(
+ "Metadata object call to GetMetadata() was null");
+ }
+
+ // Compute the metadata of the TFRecord file if not cached.
+ if (!cached_metadata_) {
+ TF_RETURN_IF_ERROR(input_stream_->Reset());
+
+ int64 data_size = 0;
+ int64 entries = 0;
+
+ // Within the loop, we always increment offset positively, so this
+ // loop should be guaranteed to either return after reaching EOF
+ // or encountering an error.
+ uint64 offset = 0;
+ string record;
+ while (true) {
+ // Read header, containing size of data.
+ Status s = ReadChecksummed(offset, sizeof(uint64), &record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ // We should reach out of range when the record file is complete.
+ break;
+ }
+ return s;
+ }
+
+ // Read the length of the data.
+ const uint64 length = core::DecodeFixed64(record.data());
+
+ // Skip reading the actual data since we just want the number
+ // of records and the size of the data.
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
+ offset += kHeaderSize + length + kFooterSize;
+
+ // Increment running stats.
+ data_size += length;
+ ++entries;
+ }
+
+ cached_metadata_.reset(new Metadata());
+ cached_metadata_->stats.entries = entries;
+ cached_metadata_->stats.data_size = data_size;
+ cached_metadata_->stats.file_size =
+ data_size + (kHeaderSize + kFooterSize) * entries;
+ }
+
+ md->stats = cached_metadata_->stats;
+ return Status::OK();
+}
+
Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index 11af1366b0..17444660d4 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -66,6 +66,18 @@ class RecordReader {
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
+ // Statistics (sizes are in units of bytes)
+ struct Stats {
+ int64 file_size = -1;
+ int64 data_size = -1;
+ int64 entries = -1; // Number of values
+ };
+
+ // Metadata for the TFRecord file.
+ struct Metadata {
+ Stats stats;
+ };
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -79,6 +91,17 @@ class RecordReader {
// OUT_OF_RANGE for end of file, or something else for an error.
Status ReadRecord(uint64* offset, string* record);
+ // Return the metadata of the Record file.
+ //
+ // The current implementation scans the file to completion,
+ // skipping over the data regions, to extract the metadata once
+ // on the first call to GetStats(). An improved implementation
+ // would change RecordWriter to write the metadata into TFRecord
+ // so that GetMetadata() could be a const method.
+ //
+ // 'metadata' must not be nullptr.
+ Status GetMetadata(Metadata* md);
+
private:
Status ReadChecksummed(uint64 offset, size_t n, string* result);
@@ -86,6 +109,8 @@ class RecordReader {
std::unique_ptr<InputStreamInterface> input_stream_;
bool last_read_failed_;
+ std::unique_ptr<Metadata> cached_metadata_;
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 13bea1f8f1..a88d34d293 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) {
EXPECT_EQ("abc", record);
TF_CHECK_OK(reader.ReadRecord(&offset, &record));
EXPECT_EQ("defg", record);
+
+ io::RecordReader::Metadata md;
+ TF_ASSERT_OK(reader.GetMetadata(&md));
+ EXPECT_EQ(2, md.stats.entries);
+ EXPECT_EQ(7, md.stats.data_size);
+ // Two entries have 16 bytes of header/footer each.
+ EXPECT_EQ(39, md.stats.file_size);
}
}
}
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index c24950643f..442686c92a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2595,6 +2595,116 @@ REGISTER_OP("ExtractImagePatches")
// --------------------------------------------------------------------------
+// To enable rates, uncomment all lines commented below and use ksize_*_eff
+// as the second parameter of all GetWindowedOutputSizeVerbose calls instead
+// of ksize_*.
+REGISTER_OP("ExtractVolumePatches")
+ .Input("input: T")
+ .Output("patches: T")
+ .Attr("ksizes: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ /* .Attr("rates: list(int) >= 5") */
+ .Attr("T: realnumbertype")
+ .Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+
+ std::vector<int32> ksizes;
+ TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
+ if (ksizes.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the ksizes attribute to contain 5 "
+ "values, but got: ",
+ ksizes.size());
+ }
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the stride attribute to contain 5 "
+ "values, but got: ",
+ strides.size());
+ }
+
+ /*
+ // TODO(hsgkim): Enable rates.
+ // See extract_volume_patches_op.cc for why rates are disabled now.
+
+ std::vector<int32> rates;
+ TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
+ if (rates.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the rates attribute to contain 5 "
+ "values, but got: ",
+ rates.size());
+ }
+ */
+
+ int32 ksize_planes = ksizes[1];
+ int32 ksize_rows = ksizes[2];
+ int32 ksize_cols = ksizes[3];
+
+ int32 stride_planes = strides[1];
+ int32 stride_rows = strides[2];
+ int32 stride_cols = strides[3];
+
+ /*
+ int32 rate_planes = rates[1];
+ int32 rate_rows = rates[2];
+ int32 rate_cols = rates[3];
+
+ int32 ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle output_depth_dim;
+ TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
+ ksize_planes * ksize_rows * ksize_cols,
+ &output_depth_dim));
+
+ if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
+ !c->ValueKnown(in_cols_dim)) {
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim, output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, ksize_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, ksize_rows, stride_rows, padding, &output_rows,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, ksize_cols, stride_cols, padding, &output_cols,
+ &padding_before, &padding_after));
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ });
+
+// --------------------------------------------------------------------------
+
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index e30a111096..b02ea64ac9 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21902,6 +21902,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -35273,6 +35326,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 4d3f272c1b..1ada623cf5 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -932,4 +932,41 @@ REGISTER_OP("MapDefun")
return Status::OK();
});
+REGISTER_OP("MultiDeviceIterator")
+ .Output("handle: resource")
+ .Attr("devices: list(string) >= 1")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorInit")
+ .Input("dataset: variant")
+ .Input("multi_device_iterator: resource")
+ .Input("max_buffer_size: int64")
+ .Output("incarnation_id: int64")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
+ .Input("multi_device_iterator: resource")
+ .Input("shard_num: int32")
+ .Input("incarnation_id: int64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("MultiDeviceIteratorToStringHandle")
+ .Input("multi_device_iterator: resource")
+ .Output("string_handle: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorFromStringHandle")
+ .Input("string_handle: string")
+ .Output("multi_device_iterator: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 594edfd7f0..4c5a472e9f 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10187,6 +10187,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -16812,6 +16865,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 07f984ceea..0e780eacc9 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -75,6 +75,8 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
+ // Force small ops onto the CPU (default is OFF).
+ Toggle pin_to_host_optimization = 18;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 680211edff..cf7ffd8149 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -34,9 +34,8 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
-// Using pragma message since #warning doesn't work with all compilers
-#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.")
-#pragma message("Please use MKL DNN (the default option for --config=mkl)")
+#error \
+ "Compiling for INTEL MKL ML only is no longer supported.Please use MKL DNN (the default option for --config=mkl)"
#endif
#ifdef INTEL_MKL_ML_ONLY
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 1d72bcd2b6..e6e07c8437 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3770,6 +3770,68 @@ func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf
return op.Output(0)
}
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
+// layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesGetEnsembleStates",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
+// Creates a tree ensemble model and returns a handle to it.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble resource to be created.
+// stamp_token: Token to use as the initial value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the tree ensemble.
+//
+// Returns the created operation.
+func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCreateEnsemble",
+ Input: []tf.Input{
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Checks whether a tree ensemble has been initialized.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble resouce.
+//
+// Returns output boolean on whether it is initialized or not.
+func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IsBoostedTreesEnsembleInitialized",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along sparse segments of a tensor.
//
// Read
@@ -5755,26 +5817,6 @@ func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Checks whether a tree ensemble has been initialized.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble resouce.
-//
-// Returns output boolean on whether it is initialized or not.
-func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IsBoostedTreesEnsembleInitialized",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// CastAttr is an optional argument to Cast.
type CastAttr func(optionalAttr)
@@ -19714,27 +19756,6 @@ func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Out
return op.Output(0)
}
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Returns the element-wise min of two SparseTensors.
//
// Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
@@ -21078,6 +21099,147 @@ func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_
return op.Output(0)
}
+// Deserializes a serialized tree ensemble config and replaces current tree
+//
+// ensemble.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+// stamp_token: Token to use as the new value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the ensemble.
+//
+// Returns the created operation.
+func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesDeserializeEnsemble",
+ Input: []tf.Input{
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Transforms a tf.Example proto (as a string) into typed tensors.
+//
+// Arguments:
+// serialized: A vector containing a batch of binary serialized Example protos.
+// dense_defaults: A list of Tensors (some may be empty), whose length matches
+// the length of `dense_keys`. dense_defaults[j] provides default values
+// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
+// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
+// The input type is inferred from dense_defaults[j], even when it's empty.
+// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
+// then the shape of dense_defaults[j] must match that of dense_shapes[j].
+// If dense_shapes[j] has an undefined major dimension (variable strides dense
+// feature), dense_defaults[j] must contain a single element:
+// the padding element.
+// num_sparse: The number of sparse features to be parsed from the example. This
+// must match the lengths of `sparse_keys` and `sparse_types`.
+// sparse_keys: A list of `num_sparse` strings.
+// The keys expected in the Examples' features associated with sparse values.
+// dense_keys: The keys expected in the Examples' features associated with dense
+// values.
+// sparse_types: A list of `num_sparse` types; the data types of data in each
+// Feature given in sparse_keys.
+// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// dense_shapes: The shapes of data in each Feature given in dense_keys.
+// The length of this list must match the length of `dense_keys`. The
+// number of elements in the Feature corresponding to dense_key[j] must
+// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
+// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
+// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
+// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
+// D1, .., DN), where M is the number of blocks of elements of length
+// D1 * .... * DN, in the input.
+func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
+ opspec := tf.OpSpec{
+ Type: "ParseSingleExample",
+ Input: []tf.Input{
+ serialized, tf.OutputList(dense_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ return sparse_indices, sparse_values, sparse_shapes, dense_values
+}
+
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the entire contents of a file as a value.
+//
+// To use, enqueue filenames in a Queue. The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
+//
+// Returns The handle to reference the Reader.
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "WholeFileReaderV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Pop the element at the top of the stack.
//
// Arguments:
@@ -30734,27 +30896,6 @@ func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, val
return op.Output(0)
}
-// Creates a tree ensemble model and returns a handle to it.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble resource to be created.
-// stamp_token: Token to use as the initial value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the tree ensemble.
-//
-// Returns the created operation.
-func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesCreateEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Applies sparse addition to `input` using individual values or slices
//
// from `updates` according to indices `indices`. The updates are non-aliasing:
@@ -32575,144 +32716,3 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1), op.Output(2)
}
-
-// Transforms a tf.Example proto (as a string) into typed tensors.
-//
-// Arguments:
-// serialized: A vector containing a batch of binary serialized Example protos.
-// dense_defaults: A list of Tensors (some may be empty), whose length matches
-// the length of `dense_keys`. dense_defaults[j] provides default values
-// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
-// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
-// The input type is inferred from dense_defaults[j], even when it's empty.
-// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
-// then the shape of dense_defaults[j] must match that of dense_shapes[j].
-// If dense_shapes[j] has an undefined major dimension (variable strides dense
-// feature), dense_defaults[j] must contain a single element:
-// the padding element.
-// num_sparse: The number of sparse features to be parsed from the example. This
-// must match the lengths of `sparse_keys` and `sparse_types`.
-// sparse_keys: A list of `num_sparse` strings.
-// The keys expected in the Examples' features associated with sparse values.
-// dense_keys: The keys expected in the Examples' features associated with dense
-// values.
-// sparse_types: A list of `num_sparse` types; the data types of data in each
-// Feature given in sparse_keys.
-// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// dense_shapes: The shapes of data in each Feature given in dense_keys.
-// The length of this list must match the length of `dense_keys`. The
-// number of elements in the Feature corresponding to dense_key[j] must
-// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
-// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
-// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
-// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
-// D1, .., DN), where M is the number of blocks of elements of length
-// D1 * .... * DN, in the input.
-func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
- opspec := tf.OpSpec{
- Type: "ParseSingleExample",
- Input: []tf.Input{
- serialized, tf.OutputList(dense_defaults),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- return sparse_indices, sparse_values, sparse_shapes, dense_values
-}
-
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue. The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "WholeFileReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Deserializes a serialized tree ensemble config and replaces current tree
-//
-// ensemble.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-// stamp_token: Token to use as the new value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the ensemble.
-//
-// Returns the created operation.
-func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesDeserializeEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
- }
- return scope.AddOperation(opspec)
-}
diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
index 9e1b6bdbe8..37f3e72f6e 100644
--- a/tensorflow/python/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -108,7 +108,7 @@ def ast_to_object(nodes,
indices = (-1,)
if include_source_map:
- source_map = origin_info.source_map(nodes, source, f.name, indices)
+ source_map = origin_info.create_source_map(nodes, source, f.name, indices)
# TODO(mdan): Try flush() and delete=False instead.
if delete_on_exit:
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 4c7c4165ef..102bd42c91 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -75,7 +75,7 @@ class OriginInfo(
# TODO(mdan): This source map should be a class - easier to refer to.
-def source_map(nodes, code, filename, indices_in_code):
+def create_source_map(nodes, code, filename, indices_in_code):
"""Creates a source map between an annotated AST and the code it compiles to.
Args:
diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
index 6b9c30dbd0..3b1d5f2040 100644
--- a/tensorflow/python/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -27,49 +27,41 @@ from tensorflow.python.platform import test
class OriginInfoTest(test.TestCase):
- def test_source_map(self):
+ def test_create_source_map(self):
def test_fn(x):
- if x > 0:
- x += 1
- return x
-
- node, source = parser.parse_entity(test_fn)
+ return x + 1
+
+ node, _ = parser.parse_entity(test_fn)
+ fake_origin = origin_info.OriginInfo(
+ loc=origin_info.Location('fake_filename', 3, 7),
+ function_name='fake_function_name',
+ source_code_line='fake source line',
+ comment=None)
fn_node = node.body[0]
- origin_info.resolve(fn_node, source)
-
- # Insert a traced line.
- new_node = parser.parse_str('x = abs(x)').body[0]
- anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
- fn_node.body.insert(0, new_node)
+ anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
+ converted_code = compiler.ast_to_source(fn_node)
- # Insert an untraced line.
- fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- modified_source = compiler.ast_to_source(fn_node)
+ loc = origin_info.LineLocation('test_filename', 2)
+ self.assertIn(loc, source_map)
+ self.assertIs(source_map[loc], fake_origin)
- source_map = origin_info.source_map(fn_node, modified_source,
- 'test_filename', [0])
+ def test_source_map_no_origin(self):
- loc = origin_info.LineLocation('test_filename', 1)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, 'def test_fn(x):')
- self.assertEqual(origin.loc.lineno, 1)
+ def test_fn(x):
+ return x + 1
- # The untraced line, inserted second.
- loc = origin_info.LineLocation('test_filename', 2)
- self.assertFalse(loc in source_map)
+ node, _ = parser.parse_entity(test_fn)
+ fn_node = node.body[0]
+ converted_code = compiler.ast_to_source(fn_node)
- # The traced line, inserted first.
- loc = origin_info.LineLocation('test_filename', 3)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- loc = origin_info.LineLocation('test_filename', 4)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ self.assertEqual(len(source_map), 0)
def test_resolve(self):
@@ -79,6 +71,7 @@ class OriginInfoTest(test.TestCase):
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
+
origin_info.resolve(fn_node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 5e8f5d6e8e..45f40cd183 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 21)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 24)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 3e08c1587e..138141f4fc 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -12,6 +12,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 17d4fec662..28ee3ebaa6 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -394,6 +394,7 @@ cuda_py_test(
size = "small",
srcs = ["optional_ops_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:optional_ops",
@@ -408,6 +409,26 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "small",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
tf_py_test(
name = "window_dataset_op_test",
size = "small",
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
new file mode 100644
index 0000000000..056664b83b
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -0,0 +1,190 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MultiDeviceIterator tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MultiDeviceIteratorTest(test.TestCase):
+
+ def testNoGetNext(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+
+ def testBasic(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testOneOnSameDevice(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:0", "/cpu:1"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testRepeatDevices(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(20)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
+ elements = multi_device_iterator.get_next()
+ elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 20, 4):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(i + 2, sess.run(elem_on_3))
+ self.assertEqual(i + 3, sess.run(elem_on_4))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+ sess.run(elem_on_3)
+ sess.run(elem_on_4)
+
+ def testNotFullyDivisible(self):
+ dataset = dataset_ops.Dataset.range(9)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 8, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(8, sess.run(elem_on_1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUneven(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testMultipleInitializations(self):
+ with ops.device("/cpu:0"):
+ epoch = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
+ dataset2 = dataset_ops.Dataset.range(1000)
+ dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+ init_op = multi_device_iterator.initializer
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ for i in range(1000):
+ sess.run(init_op, feed_dict={epoch: i})
+ self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
+
+ def testBasicGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUnevenGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index c344513e71..706a65fe55 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -17,11 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,14 +35,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase):
+class OptionalTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual(37.0, self.evaluate(opt.get_value()))
@@ -50,15 +49,6 @@ class OptionalTest(test.TestCase):
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
})
- self.assertEqual({
- "a": dtypes.float32,
- "b": (dtypes.string, dtypes.string)
- }, opt.output_types)
- self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes)
- self.assertEqual({
- "a": ops.Tensor,
- "b": (ops.Tensor, ops.Tensor)
- }, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual({
"a": 37.0,
@@ -76,46 +66,29 @@ class OptionalTest(test.TestCase):
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=np.array([2, 2]))
opt = optional_ops.Optional.from_value((st_0, st_1))
- self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types)
- self.assertEqual(([1], [2, 2]), opt.output_shapes)
- self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
- opt.output_classes)
+ self.assertTrue(self.evaluate(opt.has_value()))
+ val_0, val_1 = opt.get_value()
+ for expected, actual in [(st_0, val_0), (st_1, val_1)]:
+ self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
+ self.assertAllEqual(expected.values, self.evaluate(actual.values))
+ self.assertAllEqual(expected.dense_shape,
+ self.evaluate(actual.dense_shape))
@test_util.run_in_graph_and_eager_modes
def testFromNone(self):
- opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(),
- dtypes.float32, ops.Tensor)
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
+ value_structure = structure.TensorStructure(dtypes.float32, [])
+ opt = optional_ops.Optional.none_from_structure(value_structure)
+ self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.float32, [1])))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.int32, [])))
self.assertFalse(self.evaluate(opt.has_value()))
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
- def testStructureMismatchError(self):
- tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar())
- tuple_output_types = (dtypes.float32, dtypes.float32)
- tuple_output_classes = (ops.Tensor, ops.Tensor)
-
- dict_output_shapes = {
- "a": tensor_shape.scalar(),
- "b": tensor_shape.scalar()
- }
- dict_output_types = {"a": dtypes.float32, "b": dtypes.float32}
- dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor}
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, tuple_output_types, dict_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, dict_output_types, tuple_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- dict_output_shapes, tuple_output_types, tuple_output_classes)
-
@test_util.run_in_graph_and_eager_modes
def testCopyToGPU(self):
if not test_util.is_gpu_available():
@@ -126,17 +99,15 @@ class OptionalTest(test.TestCase):
(constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure(
- tensor_shape.scalar(), dtypes.float32, ops.Tensor)
+ structure.TensorStructure(dtypes.float32, []))
with ops.device("/gpu:0"):
gpu_optional_with_value = optional_ops._OptionalImpl(
array_ops.identity(optional_with_value._variant_tensor),
- optional_with_value.output_shapes, optional_with_value.output_types,
- optional_with_value.output_classes)
+ optional_with_value.value_structure)
gpu_optional_none = optional_ops._OptionalImpl(
array_ops.identity(optional_none._variant_tensor),
- optional_none.output_shapes, optional_none.output_types,
- optional_none.output_classes)
+ optional_none.value_structure)
gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
gpu_optional_with_value_values = gpu_optional_with_value.get_value()
@@ -148,14 +119,101 @@ class OptionalTest(test.TestCase):
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
- def testIteratorGetNextAsOptional(self):
- ds = dataset_ops.Dataset.range(3)
+ def _assertElementValueEqual(self, expected, actual):
+ if isinstance(expected, dict):
+ self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
+ for k in expected.keys():
+ self._assertElementValueEqual(expected[k], actual[k])
+ elif isinstance(expected, sparse_tensor.SparseTensorValue):
+ self.assertAllEqual(expected.indices, actual.indices)
+ self.assertAllEqual(expected.values, actual.values)
+ self.assertAllEqual(expected.dense_shape, actual.dense_shape)
+ else:
+ self.assertAllEqual(expected, actual)
+
+ # pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Tensor", lambda: constant_op.constant(37.0),
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
+ dense_shape=[1]),
+ structure.SparseTensorStructure(dtypes.int32, [1])),
+ ("Nest", lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.TensorStructure(dtypes.string, [1]),
+ structure.TensorStructure(dtypes.string, []))})),
+ ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+ optional_ops.OptionalStructure(
+ structure.TensorStructure(dtypes.float32, []))),
+ )
+ def testOptionalStructure(self, tf_value_fn, expected_value_structure):
+ tf_value = tf_value_fn()
+ opt = optional_ops.Optional.from_value(tf_value)
+
+ self.assertTrue(
+ expected_value_structure.is_compatible_with(opt.value_structure))
+ self.assertTrue(
+ opt.value_structure.is_compatible_with(expected_value_structure))
+
+ opt_structure = structure.Structure.from_value(opt)
+ self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
+ self.assertTrue(opt_structure.is_compatible_with(opt_structure))
+ self.assertTrue(opt_structure._value_structure.is_compatible_with(
+ expected_value_structure))
+ self.assertEqual([dtypes.variant], opt_structure._flat_types)
+ self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)
+
+ # All OptionalStructure objects are not compatible with a non-optional
+ # value.
+ non_optional_structure = structure.Structure.from_value(
+ constant_op.constant(42.0))
+ self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))
+
+ # Assert that the optional survives a round-trip via _from_tensor_list()
+ # and _to_tensor_list().
+ round_trip_opt = opt_structure._from_tensor_list(
+ opt_structure._to_tensor_list(opt))
+ if isinstance(tf_value, optional_ops.Optional):
+ self.assertEqual(
+ self.evaluate(tf_value.get_value()),
+ self.evaluate(round_trip_opt.get_value().get_value()))
+ else:
+ self.assertEqual(
+ self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
+
+ @parameterized.named_parameters(
+ ("Tensor", np.array([1, 2, 3], dtype=np.int32),
+ lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
+ ("SparseTensor", sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
+ False),
+ ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
+ "b": sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32),
+ dense_shape=[2, 2])},
+ lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
+ "b": sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
+ dense_shape=[2, 2])}, False),
+ )
+ def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
+ if not works_on_gpu and test.is_gpu_available():
+ self.skipTest("Test case not yet supported on GPU.")
+ ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
iterator = ds.make_initializable_iterator()
next_elem = iterator_ops.get_next_as_optional(iterator)
- self.assertTrue(isinstance(next_elem, optional_ops.Optional))
- self.assertEqual(ds.output_types, next_elem.output_types)
- self.assertEqual(ds.output_shapes, next_elem.output_shapes)
- self.assertEqual(ds.output_classes, next_elem.output_classes)
+ self.assertIsInstance(next_elem, optional_ops.Optional)
+ self.assertTrue(
+ next_elem.value_structure.is_compatible_with(
+ structure.Structure.from_value(tf_value_fn())))
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
with self.cached_session() as sess:
@@ -169,10 +227,10 @@ class OptionalTest(test.TestCase):
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
sess.run(iterator.initializer)
- for i in range(3):
+ for _ in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
- self.assertEqual(i, elem_value)
+ self._assertElementValueEqual(np_value, elem_value)
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 57517afae8..76bf2470b1 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:smart_cond",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
@@ -63,6 +64,7 @@ py_library(
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
],
@@ -77,8 +79,23 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/util:structure",
+ ],
+)
+
+py_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 93b3a7b93b..7c20c049f5 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1007,8 +1007,25 @@ class Dataset(object):
return ParallelMapDataset(self, map_func, num_parallel_calls)
def flat_map(self, map_func):
- """Maps `map_func` across this dataset and flattens the result.
+ """Maps `map_func` across this dataset and flattens the result.
+
+ Use `flat_map` if you want to make sure that the order of your dataset
+ stays the same. For example, to flatten a dataset of batches into a
+ dataset of their elements:
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset. '[...]' represents a tensor.
+ a = {[1,2,3,4,5], [6,7,8,9], [10]}
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+ {[1,2,3,4,5,6,7,8,9,10]}
+ ```
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
+ `tf.data.Dataset.interleave(cycle_length=1)`
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@@ -1043,7 +1060,7 @@ class Dataset(object):
elements are produced. `cycle_length` controls the number of input elements
that are processed concurrently. If you set `cycle_length` to 1, this
transformation will handle one input element at a time, and will produce
- identical results = to `tf.data.Dataset.flat_map`. In general,
+ identical results to `tf.data.Dataset.flat_map`. In general,
this transformation will apply `map_func` to `cycle_length` input elements,
open iterators on the returned `Dataset` objects, and cycle through them
producing `block_length` consecutive elements from each iterator, and
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 8f8e026df9..cae00cdbfc 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.compat import compat
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -85,10 +86,10 @@ class Iterator(checkpointable.CheckpointableBase):
initializer: A `tf.Operation` that should be run to initialize this
iterator.
output_types: A nested structure of `tf.DType` objects corresponding to
- each component of an element of this dataset.
+ each component of an element of this iterator.
output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of an element of this dataset.
- output_classes: A nested structure of Python `type` object corresponding
+ corresponding to each component of an element of this iterator.
+ output_classes: A nested structure of Python `type` objects corresponding
to each component of an element of this iterator.
"""
self._iterator_resource = iterator_resource
@@ -670,6 +671,6 @@ def get_next_as_optional(iterator):
output_shapes=nest.flatten(
sparse.as_dense_shapes(iterator.output_shapes,
iterator.output_classes))),
- output_shapes=iterator.output_shapes,
- output_types=iterator.output_types,
- output_classes=iterator.output_classes)
+ structure.Structure._from_legacy_structure(iterator.output_types,
+ iterator.output_shapes,
+ iterator.output_classes))
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
new file mode 100644
index 0000000000..84e8abbd83
--- /dev/null
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -0,0 +1,213 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class _PerDeviceGenerator(dataset_ops.Dataset):
+ """A `dummy` generator dataset."""
+
+ def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
+ source_device, target_device, output_shapes, output_types,
+ output_classes):
+ self._target_device = target_device
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+ self._output_classes = output_classes
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes))
+
+ multi_device_iterator_string_handle = (
+ gen_dataset_ops.multi_device_iterator_to_string_handle(
+ multi_device_iterator_resource))
+
+ @function.Defun()
+ def _init_func():
+ return multi_device_iterator_string_handle
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ multi_device_iterator = (
+ gen_dataset_ops.multi_device_iterator_from_string_handle(
+ string_handle=string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+ return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
+ multi_device_iterator=multi_device_iterator,
+ shard_num=shard_num,
+ incarnation_id=incarnation_id,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(unused_string_handle):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+class MultiDeviceIterator(object):
+ """An iterator over multiple devices."""
+
+ def __init__(self,
+ dataset,
+ devices,
+ max_buffer_size=1,
+ prefetch_buffer_size=1,
+ source_device="/cpu:0"):
+ """Constructs a MultiDeviceIterator.
+
+ Args:
+ dataset: The input dataset to be iterated over.
+ devices: The list of devices to fetch data to.
+ max_buffer_size: Maximum size of the host side per device buffer to keep.
+ prefetch_buffer_size: if > 1, then we setup a buffer on each device
+ to prefetch into.
+ source_device: The host device to place the `dataset` on.
+ """
+ self._dataset = dataset
+ self._devices = devices
+ self._source_device = source_device
+ self._source_device_tensor = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._dataset.output_shapes,
+ self._dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._dataset.output_types,
+ self._dataset.output_classes))
+
+ # Create the MultiDeviceIterator.
+ with ops.device(self._source_device):
+ self._multi_device_iterator_resource = (
+ gen_dataset_ops.multi_device_iterator(
+ devices=self._devices,
+ shared_name="",
+ container="",
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+
+ # The incarnation ID is used to ensure consistency between the per-device
+ # iterators and the multi-device iterator.
+ self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+ self._dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._multi_device_iterator_resource,
+ max_buffer_size=max_buffer_size)
+
+ # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+ # initialize the device side of the pipeline. This would allow the
+ # MultiDeviceIterator to choose, for example, to move some transformations
+ # into the device side from its input. It might be useful in rewriting.
+ # Create the per device iterators.
+ self._device_iterators = []
+ i = 0
+ for device in self._devices:
+ ds = _PerDeviceGenerator(
+ i, self._multi_device_iterator_resource, self._incarnation_id,
+ self._source_device_tensor, device, self._dataset.output_shapes,
+ self._dataset.output_types, self._dataset.output_classes)
+ if prefetch_buffer_size > 0:
+ ds = ds.prefetch(prefetch_buffer_size)
+ with ops.device(device):
+ self._device_iterators.append(ds.make_initializable_iterator())
+ i += 1
+
+ device_iterator_initializers = [
+ iterator.initializer for iterator in self._device_iterators
+ ]
+ self._initializer = control_flow_ops.group(*device_iterator_initializers)
+
+ def get_next(self):
+ result = []
+ i = 0
+ for device in self._devices:
+ with ops.device(device):
+ result.append(self._device_iterators[i].get_next())
+ i += 1
+ return result
+
+ @property
+ def initializer(self):
+ return self._initializer
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index b75b98dc72..3bbebd7878 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -19,11 +19,9 @@ from __future__ import print_function
import abc
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
@@ -67,36 +65,14 @@ class Optional(object):
raise NotImplementedError("Optional.get_value()")
@abc.abstractproperty
- def output_classes(self):
- """Returns the class of each component of this optional.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_classes")
-
- @abc.abstractproperty
- def output_shapes(self):
- """Returns the shape of each component of this optional.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_shapes")
-
- @abc.abstractproperty
- def output_types(self):
- """Returns the type of each component of this optional.
+ def value_structure(self):
+ """The structure of the components of this optional.
Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of this optional.
+ A `Structure` object representing the structure of the components of this
+ optional.
"""
- raise NotImplementedError("Optional.output_types")
+ raise NotImplementedError("Optional.value_structure")
@staticmethod
def from_value(value):
@@ -108,48 +84,30 @@ class Optional(object):
Returns:
An `Optional` that wraps `value`.
"""
- # TODO(b/110122868): Consolidate this destructuring logic with the
- # similar code in `Dataset.from_tensors()`.
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
- value = nest.pack_sequence_as(value, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(value))
- ])
-
- encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
- output_classes = sparse.get_classes(value)
- output_shapes = nest.pack_sequence_as(
- value, [t.get_shape() for t in nest.flatten(value)])
- output_types = nest.pack_sequence_as(
- value, [t.dtype for t in nest.flatten(value)])
+ value_structure = structure.Structure.from_value(value)
+ encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access
return _OptionalImpl(
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
- output_shapes, output_types, output_classes)
+ value_structure)
@staticmethod
- def none_from_structure(output_shapes, output_types, output_classes):
+ def none_from_structure(value_structure):
"""Returns an `Optional` that has no value.
- NOTE: This method takes arguments that define the structure of the value
+ NOTE: This method takes an argument that defines the structure of the value
that would be contained in the returned `Optional` if it had a value.
Args:
- output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of this optional.
- output_types: A nested structure of `tf.DType` objects corresponding to
- each component of this optional.
- output_classes: A nested structure of Python `type` objects corresponding
- to each component of this optional.
+ value_structure: A `Structure` object representing the structure of the
+ components of this optional.
Returns:
An `Optional` that has no value.
"""
- return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes,
- output_types, output_classes)
+ return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure)
class _OptionalImpl(Optional):
@@ -159,20 +117,9 @@ class _OptionalImpl(Optional):
`Optional.__init__()` in the public API.
"""
- def __init__(self, variant_tensor, output_shapes, output_types,
- output_classes):
- # TODO(b/110122868): Consolidate the structure validation logic with the
- # similar logic in `Iterator.from_structure()` and
- # `Dataset.from_generator()`.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- nest.assert_same_structure(output_types, output_shapes)
- nest.assert_same_structure(output_types, output_classes)
+ def __init__(self, variant_tensor, value_structure):
self._variant_tensor = variant_tensor
- self._output_shapes = output_shapes
- self._output_types = output_types
- self._output_classes = output_classes
+ self._value_structure = value_structure
def has_value(self, name=None):
return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
@@ -182,28 +129,55 @@ class _OptionalImpl(Optional):
# in `Iterator.get_next()` and `StructuredFunctionWrapper`.
with ops.name_scope(name, "OptionalGetValue",
[self._variant_tensor]) as scope:
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(
- self._output_types,
- gen_dataset_ops.optional_get_value(
- self._variant_tensor,
- name=scope,
- output_types=nest.flatten(
- sparse.as_dense_types(self._output_types,
- self._output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self._output_shapes,
- self._output_classes)))),
- self._output_types, self._output_shapes, self._output_classes)
+ # pylint: disable=protected-access
+ return self._value_structure._from_tensor_list(
+ gen_dataset_ops.optional_get_value(
+ self._variant_tensor,
+ name=scope,
+ output_types=self._value_structure._flat_types,
+ output_shapes=self._value_structure._flat_shapes))
@property
- def output_classes(self):
- return self._output_classes
+ def value_structure(self):
+ return self._value_structure
+
+
+class OptionalStructure(structure.Structure):
+ """Represents an optional potentially containing a structured value."""
+
+ def __init__(self, value_structure):
+ self._value_structure = value_structure
@property
- def output_shapes(self):
- return self._output_shapes
+ def _flat_shapes(self):
+ return [tensor_shape.scalar()]
@property
- def output_types(self):
- return self._output_types
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, other):
+ # pylint: disable=protected-access
+ return (isinstance(other, OptionalStructure) and
+ self._value_structure.is_compatible_with(other._value_structure))
+
+ def _to_tensor_list(self, value):
+ return [value._variant_tensor] # pylint: disable=protected-access
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "OptionalStructure corresponds to a single tf.variant scalar.")
+ # pylint: disable=protected-access
+ return _OptionalImpl(flat_value[0], self._value_structure)
+
+ @staticmethod
+ def from_value(value):
+ return OptionalStructure(value.value_structure)
+
+
+# pylint: disable=protected-access
+structure.Structure._register_custom_converter(Optional,
+ OptionalStructure.from_value)
+# pylint: enable=protected-access
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
index c5764b8dfe..a90ca258c0 100644
--- a/tensorflow/python/data/util/structure.py
+++ b/tensorflow/python/data/util/structure.py
@@ -28,6 +28,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import sparse_ops
+_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}
+
+
class Structure(object):
"""Represents structural information, such as type and shape, about a value.
@@ -64,12 +67,10 @@ class Structure(object):
raise NotImplementedError("Structure._flat_shapes")
@abc.abstractmethod
- def is_compatible_with(self, value):
- """Returns `True` if `value` is compatible with this structure.
+ def is_compatible_with(self, other):
+ """Returns `True` if `other` is compatible with this structure.
- A value `value` is compatible with a structure `s` if
- `Structure.from_value(value)` would return a structure `t` that is a
- "subtype" of `s`. A structure `t` is a "subtype" of `s` if:
+ A structure `t` is a "subtype" of `s` if:
* `s` and `t` are instances of the same `Structure` subclass.
* The nested structures (if any) of `s` and `t` are the same, according to
@@ -83,10 +84,10 @@ class Structure(object):
`tf.TensorShape.is_compatible_with`.
Args:
- value: A potentially structured value.
+ other: A `Structure`.
Returns:
- `True` if `value` matches this structure, otherwise `False`.
+ `True` if `other` is a subtype of this structure, otherwise `False`.
"""
raise NotImplementedError("Structure.is_compatible_with()")
@@ -98,7 +99,7 @@ class Structure(object):
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure.
- Requires: `self.is_compatible_with(value)`.
+ Requires: `self.is_compatible_with(Structure.from_value(value))`.
Args:
value: A value with compatible structure.
@@ -137,9 +138,8 @@ class Structure(object):
TypeError: If a structure cannot be built for `value`, because its type
or one of its component types is not supported.
"""
-
- # TODO(b/110122868): Add support for custom types, Dataset, and Optional
- # to this method.
+ # TODO(b/110122868): Add support for custom types and Dataset to this
+ # method.
if isinstance(
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
@@ -147,12 +147,76 @@ class Structure(object):
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
+ for converter_type, converter_fn in (
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
+ if isinstance(value, converter_type):
+ return converter_fn(value)
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a structure for %r" % value)
return TensorStructure.from_value(tensor)
+ @staticmethod
+ def _from_legacy_structure(output_types, output_shapes, output_classes):
+ """Returns a `Structure` that represents the given legacy structure.
+
+ This method provides a way to convert from the existing `Dataset` and
+ `Iterator` structure-related properties to a `Structure` object.
+
+ TODO(b/110122868): Remove this method once `Structure` is used throughout
+ `tf.data`.
+
+ Args:
+ output_types: A nested structure of `tf.DType` objects corresponding to
+ each component of a structured value.
+ output_shapes: A nested structure of `tf.TensorShape` objects
+ corresponding to each component a structured value.
+ output_classes: A nested structure of Python `type` objects corresponding
+ to each component of a structured value.
+
+ Returns:
+ A `Structure`.
+
+ Raises:
+ TypeError: If a structure cannot be built the arguments, because one of
+ the component classes in `output_classes` is not supported.
+ """
+ flat_types = nest.flatten(output_types)
+ flat_shapes = nest.flatten(output_shapes)
+ flat_classes = nest.flatten(output_classes)
+ flat_ret = []
+ for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
+ flat_classes):
+ if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
+ flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
+ elif issubclass(flat_class, ops.Tensor):
+ flat_ret.append(TensorStructure(flat_type, flat_shape))
+ else:
+ # NOTE(mrry): Since legacy structures produced by iterators only
+ # comprise Tensors, SparseTensors, and nests, we do not need to support
+ # all structure types here.
+ raise TypeError(
+ "Could not build a structure for output class %r" % flat_type)
+
+ ret = nest.pack_sequence_as(output_classes, flat_ret)
+ if isinstance(ret, Structure):
+ return ret
+ else:
+ return NestedStructure(ret)
+
+ @staticmethod
+ def _register_custom_converter(type_object, converter_fn):
+ """Registers `converter_fn` for converting values of the given type.
+
+ Args:
+ type_object: A Python `type` object representing the type of values
+ accepted by `converter_fn`.
+ converter_fn: A function that takes one argument (an instance of the
+ type represented by `type_object`) and returns a `Structure`.
+ """
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn
+
# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
@@ -179,16 +243,21 @@ class NestedStructure(Structure):
def _flat_types(self):
return self._flat_types_list
- def is_compatible_with(self, value):
+ def is_compatible_with(self, other):
+ if not isinstance(other, NestedStructure):
+ return False
try:
- nest.assert_shallow_structure(self._nested_structure, value)
+ # pylint: disable=protected-access
+ nest.assert_same_structure(self._nested_structure,
+ other._nested_structure)
except (ValueError, TypeError):
return False
return all(
- s.is_compatible_with(v) for s, v in zip(
+ substructure.is_compatible_with(other_substructure)
+ for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
- nest.flatten_up_to(self._nested_structure, value)))
+ nest.flatten(other._nested_structure)))
def _to_tensor_list(self, value):
ret = []
@@ -201,7 +270,7 @@ class NestedStructure(Structure):
for sub_value, structure in zip(flat_value,
nest.flatten(self._nested_structure)):
- if not structure.is_compatible_with(sub_value):
+ if not structure.is_compatible_with(Structure.from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_tensor_list(sub_value))
@@ -242,17 +311,13 @@ class TensorStructure(Structure):
def _flat_types(self):
return [self._dtype]
- def is_compatible_with(self, value):
- try:
- value = ops.convert_to_tensor(value, dtype=self._dtype)
- except (ValueError, TypeError):
- return False
-
- return (self._dtype.is_compatible_with(value.dtype) and
- self._shape.is_compatible_with(value.shape))
+ def is_compatible_with(self, other):
+ return (isinstance(other, TensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._shape.is_compatible_with(other._shape))
def _to_tensor_list(self, value):
- if not self.is_compatible_with(value):
+ if not self.is_compatible_with(Structure.from_value(value)):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return [value]
@@ -260,7 +325,7 @@ class TensorStructure(Structure):
def _from_tensor_list(self, flat_value):
if len(flat_value) != 1:
raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
- if not self.is_compatible_with(flat_value[0]):
+ if not self.is_compatible_with(Structure.from_value(flat_value[0])):
raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
"%s." % (flat_value[0], self._dtype, self._shape))
return flat_value[0]
@@ -285,16 +350,10 @@ class SparseTensorStructure(Structure):
def _flat_types(self):
return [dtypes.variant]
- def is_compatible_with(self, value):
- try:
- value = sparse_tensor_lib.SparseTensor.from_value(value)
- except TypeError:
- return False
- return (isinstance(value, (sparse_tensor_lib.SparseTensor,
- sparse_tensor_lib.SparseTensorValue)) and
- self._dtype.is_compatible_with(value.dtype) and
- self._dense_shape.is_compatible_with(
- tensor_util.constant_value_as_shape(value.dense_shape)))
+ def is_compatible_with(self, other):
+ return (isinstance(other, SparseTensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._dense_shape.is_compatible_with(other._dense_shape))
def _to_tensor_list(self, value):
return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index d0c7df67ae..2982763181 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -25,7 +25,9 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -106,13 +108,17 @@ class StructureTest(test.TestCase, parameterized.TestCase):
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
)
- def testIsCompatibleWith(self, original_value, compatible_values,
- incompatible_values):
+ def testIsCompatibleWithStructure(self, original_value, compatible_values,
+ incompatible_values):
s = structure.Structure.from_value(original_value)
for compatible_value in compatible_values:
- self.assertTrue(s.is_compatible_with(compatible_value))
+ self.assertTrue(
+ s.is_compatible_with(
+ structure.Structure.from_value(compatible_value)))
for incompatible_value in incompatible_values:
- self.assertFalse(s.is_compatible_with(incompatible_value))
+ self.assertFalse(
+ s.is_compatible_with(
+ structure.Structure.from_value(incompatible_value)))
# NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
# will be executed before the (eager- or graph-mode) test environment has been
@@ -322,6 +328,28 @@ class StructureTest(test.TestCase, parameterized.TestCase):
ValueError, "Expected 3 flat values in NestedStructure but got 2."):
s_2._from_tensor_list(flat_s_1)
+ @parameterized.named_parameters(
+ ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
+ sparse_tensor.SparseTensor,
+ structure.SparseTensorStructure(dtypes.int32, [2, 2])),
+ ("Nest",
+ {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
+ {"a": tensor_shape.scalar(),
+ "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())},
+ {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
+ structure.TensorStructure(dtypes.string, []))})),
+ )
+ def testFromLegacyStructure(self, output_types, output_shapes, output_classes,
+ expected_structure):
+ actual_structure = structure.Structure._from_legacy_structure(
+ output_types, output_shapes, output_classes)
+ self.assertTrue(expected_structure.is_compatible_with(actual_structure))
+ self.assertTrue(actual_structure.is_compatible_with(expected_structure))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 55231954d1..4630bda590 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -57,7 +57,8 @@ def no_rewrite_session_config():
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
index 676097fde9..1f67f8a0d4 100644
--- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
+++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
@@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
min_graph_nodes=-1)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
sess, cond, expected_output=21.0)
def testReconstructGraphWithWhileLoop(self):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
loop_body = lambda i: math_ops.add(i, 2)
loop_cond = lambda i: math_ops.less(i, 16)
i = constant_op.constant(10, name="i")
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index a2686c68a9..d3d997e6df 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -17,7 +17,10 @@ cc_library(
"pywrap_tensor.h",
"pywrap_tfe.h",
],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -46,6 +49,7 @@ py_library(
":backprop",
":context",
":core",
+ ":def_function",
":execute",
":function",
":graph_only_ops",
@@ -380,3 +384,30 @@ cuda_py_test(
"optonly", # The test is too slow in non-opt mode
],
)
+
+py_library(
+ name = "def_function",
+ srcs = ["def_function.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":context",
+ ":function",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/training/checkpointable:base",
+ ],
+)
+
+py_test(
+ name = "def_function_test",
+ srcs = ["def_function_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":def_function",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ ],
+)
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
new file mode 100644
index 0000000000..8dcacd5c99
--- /dev/null
+++ b/tensorflow/python/eager/def_function.py
@@ -0,0 +1,235 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=unidiomatic-typecheck
+"""Prototype decorator for defining graph-mode functions with eager semantics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
+ """Variable which does not lift its initializer out of function context.
+
+ Instances of this variable, when created, build a graph which runs their
+ initializer inside a tf.cond(is_initialized) block.
+
+ This can only be created inside a defun called from (eventually) eager
+ mode. That is, non-function-building graphs are not supported.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ caching_device=None,
+ name=None,
+ dtype=None,
+ constraint=None,
+ **unused_kwargs):
+ """Creates a variable.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, GradientTapes automatically watch uses of this
+ Variable.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device. If not `None`, caches on another device. Typical use is to
+ cache on the device where the Ops using the Variable reside, to
+ deduplicate copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If called outside of a function definition.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable should not be created "
+ "outside of functions.")
+ with ops.init_scope():
+ if not context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable does not support legacy graph mode.")
+ self._in_graph_mode = False
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if constraint is not None and not callable(constraint):
+ raise ValueError("The `constraint` argument must be a callable.")
+
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
+ self._trainable = trainable
+ self._save_slice_info = None
+ self._initial_value = None
+ self._initializer_op = None
+ self._is_initialized_op = None
+ self._graph_element = None
+ self._cached_value = None
+ # Store the graph key so optimizers know how to only retrieve variables from
+ # this graph. Guaranteed to be the same as the eager graph_key.
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ assert context.executing_eagerly()
+ shared_name = ops._name_from_scope_name(name)
+ shared_name = "%s_%d" % (shared_name, ops.uid())
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ with ops.init_scope():
+ self._handle = resource_variable_ops.eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=False)
+ self._shape = initial_value.shape
+ self._unique_id = shared_name
+ self._handle_name = shared_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
+ self._constraint = constraint
+ assert initial_value is not None
+ def assign_fn():
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ initial_value,
+ name=n)
+ # Returning values to keep tf.cond happy.
+ return ops.convert_to_tensor(1)
+ def not_assign_fn():
+ return ops.convert_to_tensor(0)
+ # Note: this cond is always guaranteed to run because we're inside a defun
+ # which will insert automatic control dependencies.
+ control_flow_ops.cond(
+ resource_variable_ops.var_is_initialized_op(self._handle),
+ not_assign_fn, assign_fn)
+
+ # After the handle has been created, set up a way to clean it up when
+ # executing eagerly. We'll hold the only reference to the deleter, so that
+ # when this object is garbage collected the deleter will be too. This
+ # means ResourceVariables can be part of reference cycles without those
+ # cycles being uncollectable.
+ self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._handle, handle_device=self._handle.device)
+ self._cached_shape_as_list = None
+
+
+def _defun_with_scope(scope, fn):
+
+ def wrapped_fn(*args, **kwds):
+ with variable_scope.variable_creator_scope(scope):
+ return fn(*args, **kwds)
+
+ return function.defun(wrapped_fn)
+
+
+def def_function(fn):
+ """Defines a function as per the "functions, not sessions" document."""
+
+ # Wrapping the values in lists to bypass python's lack of way to mutate
+ # symbols from an outer scope.
+ first_call = [True]
+ function_to_call = []
+
+ # TODO(apassos) represent this as an object and not as a closure.
+ def decorated_fn(*args, **kwds):
+ """Graph function for fn."""
+ if not first_call[0]:
+ return function_to_call[0](*args, **kwds)
+
+ first_call[0] = False
+ created_variables = []
+
+ def variable_creator_scope(unused_next_creator, **kwds):
+ """Creates UnliftedInitializerVariables and saves references to them."""
+ v = UnliftedInitializerVariable(**kwds)
+ created_variables.append(v)
+ return v
+
+ first_graph_function = _defun_with_scope(variable_creator_scope, fn)
+
+ # Force the definition of the function for these arguments
+ first_concrete = first_graph_function.get_concrete_function(*args, **kwds)
+
+ def invalid_creator_scope(*unused_args, **unused_kwds):
+ """Disables variable creation."""
+ raise ValueError(
+ "def_function-decorated function tried to create "
+ "variables on second call.")
+
+ second_graph_function = _defun_with_scope(invalid_creator_scope, fn)
+
+ function_to_call.append(second_graph_function)
+ if not created_variables:
+ # Note: this retracing might be unnecessary, but running the function
+ # forever in the scope which disallows variable creation is safer than not
+ # doing so.
+ return second_graph_function(*args, **kwds)
+
+ def fn_with_cond(*inner_args, **inner_kwds):
+ """Conditionally runs initialization if it's needed."""
+ condition = True
+ for variable in created_variables:
+ condition = condition and resource_variable_ops.var_is_initialized_op(
+ variable.handle)
+ # We want to call second_graph_function if possible because it avoids
+ # recomputing potentially expensive initializers.
+ return control_flow_ops.cond(
+ condition,
+ lambda: second_graph_function(*inner_args, **inner_kwds),
+ lambda: first_concrete(*inner_args, **inner_kwds))
+
+ return function.defun(fn_with_cond)(*args, **kwds)
+
+ return decorated_fn
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
new file mode 100644
index 0000000000..804436c4bb
--- /dev/null
+++ b/tensorflow/python/eager/def_function_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DefFunctionTest(test.TestCase):
+
+ def testNoVariables(self):
+
+ @def_function.def_function
+ def fn(x):
+ return 2 * x
+
+ self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnce(self):
+
+ @def_function.def_function
+ def fn(x):
+ return variables.Variable(1.0) + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ state.append(variables.Variable(1.0))
+ return state[-1] + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testCorrectVariableCreation(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+ def testVariableInitializerNotConstant(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0 * x))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+
+if __name__ == '__main__':
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bcb1881264..1f5d479882 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -850,7 +850,7 @@ def _get_defun_inputs_from_args(args):
def func_graph_from_py_func(name,
python_func,
args,
- kwds,
+ kwargs,
signature=None,
func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
@@ -860,11 +860,11 @@ def func_graph_from_py_func(name,
python_func: the Python function to trace.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
- kwds: the keyword args with which the Python function should be called;
+ kwargs: the keyword args with which the Python function should be called;
ignored if a signature is provided.
signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
and dtypes of the arguments. When a signature is provided, `args` and
- `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ `kwargs` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
@@ -885,16 +885,17 @@ def func_graph_from_py_func(name,
if signature is None:
func_args = _get_defun_inputs_from_args(args)
- func_kwds = _get_defun_inputs_from_args(kwds)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
else:
func_args = _get_defun_inputs_from_signature(signature)
- func_kwds = {}
+ func_kwargs = {}
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
- func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+ func_kwargs_before = nest.pack_sequence_as(
+ func_kwargs, nest.flatten(func_kwargs))
def convert(x):
"""Converts an argument to a Tensor."""
@@ -913,7 +914,7 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -933,16 +934,16 @@ def func_graph_from_py_func(name,
raise ValueError(errmsg)
check_mutation(func_args_before, func_args)
- check_mutation(func_kwds_before, func_kwds)
+ check_mutation(func_kwargs_before, func_kwargs)
finally:
tape.pop_tape(this_tape)
- # Variables in `func_args`, `func_kwds` should be explicit inputs
+ # Variables in `func_args`, `func_kwargs` should be explicit inputs
# to the function, not captured inputs.
tape_variables = this_tape.watched_variables()
arg_variables = set()
inputs = []
- for arg in nest.flatten(func_args) + nest.flatten(func_kwds):
+ for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
if isinstance(arg, resource_variable_ops.ResourceVariable):
try:
resource_placeholder = func_graph.captures.pop(arg.handle)
@@ -1073,11 +1074,11 @@ class PolymorphicFunction(object):
if isinstance(python_function, functools.partial):
self._python_function = python_function.func
self._args_to_prepend = python_function.args or tuple()
- self._kwds_to_include = python_function.keywords or {}
+ self._kwargs_to_include = python_function.keywords or {}
else:
self._python_function = python_function
self._args_to_prepend = tuple()
- self._kwds_to_include = {}
+ self._kwargs_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1115,9 +1116,9 @@ class PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- def __call__(self, *args, **kwds):
+ def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ graph_function, inputs = self._maybe_define_function(args, kwargs)
return graph_function(*inputs)
@property
@@ -1135,7 +1136,7 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
- graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
def __get__(self, instance, owner):
@@ -1156,13 +1157,13 @@ class PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds, ctx, graph):
+ def _cache_key(self, args, kwargs, ctx, graph):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
- inputs = (args, kwds) if kwds else args
+ inputs = (args, kwargs) if kwargs else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
- del args, kwds
+ del args, kwargs
cache_key = self._flat_input_signature
# The graph, or whether we're executing eagerly, should be a part of the
@@ -1181,8 +1182,8 @@ class PolymorphicFunction(object):
return cache_key + (execution_context, device_functions, colocation_stack)
- def _canonicalize_function_inputs(self, *args, **kwds):
- """Canonicalizes `args` and `kwds`.
+ def _canonicalize_function_inputs(self, *args, **kwargs):
+ """Canonicalizes `args` and `kwargs`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varags and kwargs that this
@@ -1192,28 +1193,28 @@ class PolymorphicFunction(object):
Args:
*args: The varargs this object was called with.
- **kwds: The keyword args this function was called with.
+ **kwargs: The keyword args this function was called with.
Returns:
A canonicalized ordering of the inputs.
Raises:
- ValueError: If a keyword in `kwds` cannot be matched with a positional
+ ValueError: If a keyword in `kwargs` cannot be matched with a positional
argument when an input signature is specified, or when the inputs
do not conform to the input signature.
"""
args = self._args_to_prepend + args
- kwds = dict(kwds, **self._kwds_to_include)
+ kwargs = dict(kwargs, **self._kwargs_to_include)
# Maps from index of arg to its corresponding value, according to `args`
- # and `kwds`; seeded with the default values for the named args that aren't
- # in `args`.
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
arg_indices_to_values = {
index: default
for index, default in six.iteritems(self._arg_indices_to_default_values)
if index >= len(args)
}
consumed_args = []
- for arg, value in six.iteritems(kwds):
+ for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
arg_indices_to_values[index] = value
@@ -1223,9 +1224,9 @@ class PolymorphicFunction(object):
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
- # After this loop, `kwds` will only contain true keyword arguments, as
+ # After this loop, `kwargs` will only contain true keyword arguments, as
# opposed to named arguments called in a keyword-like fashion.
- kwds.pop(arg)
+ kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
@@ -1239,9 +1240,9 @@ class PolymorphicFunction(object):
inputs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_inputs)
if self._input_signature is None:
- return inputs, kwds
+ return inputs, kwargs
else:
- assert not kwds
+ assert not kwargs
try:
nest.assert_same_structure(self._input_signature, inputs)
except (ValueError, TypeError):
@@ -1260,24 +1261,27 @@ class PolymorphicFunction(object):
(str(inputs), str(self._input_signature)))
return inputs, {}
- def _maybe_define_function(self, *args, **kwds):
+ def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
+ `args` and `kwargs` can be None if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
Args:
- *args: args for the Python function.
- **kwds: keywords for the Python function.
+ args: The varargs for the Python function.
+ kwargs: The keyword args for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
- kwds, as well as the inputs that the object should be called with.
+ kwargs, as well as the inputs that the object should be called with.
Raises:
ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects
"""
-
- args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds, context.context(),
+ if self._input_signature is None or args is not None or kwargs is not None:
+ args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
+ cache_key = self._cache_key(args, kwargs, context.context(),
ops.get_default_graph())
with self._lock:
try:
@@ -1289,11 +1293,11 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature),
+ kwargs, self._input_signature),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
- t for t in nest.flatten((args, kwds))
+ t for t in nest.flatten((args, kwargs))
if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
]
@@ -1933,9 +1937,9 @@ def automatic_control_dependencies(f):
The wrapped function.
"""
- def wrapper(*args, **kwds):
+ def wrapper(*args, **kwargs):
with AutomaticControlDependencies() as a:
- result = f(*args, **kwds)
+ result = f(*args, **kwargs)
result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
return nest.pack_sequence_as(result, result_flat)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 28c5c82d2c..57f7af7635 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -3433,9 +3433,11 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 68b3170dfe..f287289bd0 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -1096,6 +1096,21 @@ def _from_library(lib):
return initialized.values()
+def _get_experimental_kwarg_as_attr(attr_name, value):
+ """Creates an AttrValue for a python object."""
+ if isinstance(value, bool):
+ return attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ return attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ return attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (attr_name, type(value)))
+
+
def _parse_kwargs_as_attrs(func_name, **kwargs):
"""Parses **kwargs into a node's attributes."""
attrs = {}
@@ -1122,7 +1137,7 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
kwargs_keys = list(kwargs.keys())
for key in kwargs_keys:
if key.startswith("experimental_"):
- attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
del kwargs[key]
if kwargs:
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 903768a039..f740e5cfaa 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase):
def testExperimentalAttrs(self):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
- def FunctionWithAttr(i):
+ def FunctionWithStrAttr(i):
return array_ops.identity(i)
- self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
- self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+ @function.Defun(dtypes.int32, experimental_tag=123)
+ def FunctionWithIntAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=123.0)
+ def FunctionWithFloatAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=True)
+ def FunctionWithBoolAttr(i):
+ return array_ops.identity(i)
+
+ self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr)
+ self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s,
b"tag_value")
+ self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr)
+ self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i,
+ 123)
+ self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0)
+ self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr)
+ self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
+ True)
@test_util.with_c_shapes
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index c302072aa1..68b7b323d5 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1934,6 +1934,8 @@ class TensorFlowTestCase(googletest.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
config.graph_options.rewrite_options.arithmetic_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.pin_to_host_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
return config
return ErrorLoggingSession(graph=graph, config=prepare_config(config))
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4a72c4b3f3..ac011a2940 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -7,6 +7,7 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
+load("@pip_deps//:requirements.bzl", "requirement")
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
@@ -62,6 +63,7 @@ py_library(
":backend",
":engine",
":layers",
+ requirement("keras_applications"),
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 17831fa5cb..5183e4d30c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1663,6 +1663,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "extract_volume_patches_op_test",
+ size = "small",
+ srcs = ["extract_volume_patches_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+cuda_py_test(
name = "functional_ops_test",
size = "small",
srcs = ["functional_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
new file mode 100644
index 0000000000..64757a3e07
--- /dev/null
+++ b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
@@ -0,0 +1,131 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for ExtractVolumePatches op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+class ExtractVolumePatches(test.TestCase):
+ """Functional tests for ExtractVolumePatches op."""
+
+ def _VerifyValues(self, image, ksizes, strides, padding, patches):
+ """Tests input-output pairs for the ExtractVolumePatches op.
+
+ Args:
+ image: Input tensor with shape:
+ [batch, in_planes, in_rows, in_cols, depth].
+ ksizes: Patch size specified as: [ksize_planes, ksize_rows, ksize_cols].
+ strides: Output strides, specified as:
+ [stride_planes, stride_rows, stride_cols].
+ padding: Padding type.
+ patches: Expected output.
+
+ Note:
+ rates are not supported as of now.
+ """
+ ksizes = [1] + ksizes + [1]
+ strides = [1] + strides + [1]
+
+ with self.test_session(use_gpu=True):
+ out_tensor = array_ops.extract_volume_patches(
+ constant_op.constant(image),
+ ksizes=ksizes,
+ strides=strides,
+ padding=padding,
+ name="im2col_3d")
+ self.assertAllClose(patches, out_tensor.eval())
+
+ # pylint: disable=bad-whitespace
+ def testKsize1x1x1Stride1x1x1(self):
+ """Verifies that for 1x1x1 kernel the output equals the input."""
+ image = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + 1
+ patches = image
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 1],
+ strides=[1, 1, 1],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1x1Stride2x3x4(self):
+ """Test for 1x1x1 kernel and strides."""
+ image = np.arange(6 * 2 * 4 * 5 * 3).reshape([6, 2, 4, 5, 3]) + 1
+ patches = image[:, ::2, ::3, ::4, :]
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 1],
+ strides=[2, 3, 4],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1x2Stride2x2x3(self):
+ """Test for 1x1x2 kernel and strides."""
+ image = np.arange(45).reshape([1, 3, 3, 5, 1]) + 1
+ patches = np.array([[[[[ 1, 2],
+ [ 4, 5]],
+ [[11, 12],
+ [14, 15]]],
+ [[[31, 32],
+ [34, 35]],
+ [[41, 42],
+ [44, 45]]]]])
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 2],
+ strides=[2, 2, 3],
+ padding=padding,
+ patches=patches)
+
+ def testKsize2x2x2Stride1x1x1Valid(self):
+ """Test for 2x2x2 kernel with VALID padding."""
+ image = np.arange(8).reshape([1, 2, 2, 2, 1]) + 1
+ patches = np.array([[[[[1, 2, 3, 4, 5, 6, 7, 8]]]]])
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2, 2],
+ strides=[1, 1, 1],
+ padding="VALID",
+ patches=patches)
+
+ def testKsize2x2x2Stride1x1x1Same(self):
+ """Test for 2x2x2 kernel with SAME padding."""
+ image = np.arange(8).reshape([1, 2, 2, 2, 1]) + 1
+ patches = np.array([[[[[1, 2, 3, 4, 5, 6, 7, 8],
+ [2, 0, 4, 0, 6, 0, 8, 0]],
+ [[3, 4, 0, 0, 7, 8, 0, 0],
+ [4, 0, 0, 0, 8, 0, 0, 0]]],
+ [[[5, 6, 7, 8, 0, 0, 0, 0],
+ [6, 0, 8, 0, 0, 0, 0, 0]],
+ [[7, 8, 0, 0, 0, 0, 0, 0],
+ [8, 0, 0, 0, 0, 0, 0, 0]]]]])
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2, 2],
+ strides=[1, 1, 1],
+ padding="SAME",
+ patches=patches)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index cf0beba3c3..b24a0d0f9b 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -34,7 +34,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
class LoggingOpsTest(test.TestCase):
@@ -273,39 +272,6 @@ class PrintV2Test(test.TestCase):
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
- def testPrintOneTensorLogInfo(self):
- with self.test_session():
- tensor = math_ops.range(10)
- with self.captureWritesToStream(sys.stderr) as printed:
- print_op = logging_ops.print_v2(
- tensor, output_stream=tf_logging.info)
- self.evaluate(print_op)
- expected = "[0 1 2 ... 7 8 9]"
- self.assertTrue(expected in printed.contents())
-
- @test_util.run_in_graph_and_eager_modes()
- def testPrintOneTensorLogWarning(self):
- with self.test_session():
- tensor = math_ops.range(10)
- with self.captureWritesToStream(sys.stderr) as printed:
- print_op = logging_ops.print_v2(
- tensor, output_stream=tf_logging.warning)
- self.evaluate(print_op)
- expected = "[0 1 2 ... 7 8 9]"
- self.assertTrue(expected in printed.contents())
-
- @test_util.run_in_graph_and_eager_modes()
- def testPrintOneTensorLogError(self):
- with self.test_session():
- tensor = math_ops.range(10)
- with self.captureWritesToStream(sys.stderr) as printed:
- print_op = logging_ops.print_v2(
- tensor, output_stream=tf_logging.error)
- self.evaluate(print_op)
- expected = "[0 1 2 ... 7 8 9]"
- self.assertTrue(expected in printed.contents())
-
- @test_util.run_in_graph_and_eager_modes()
def testInvalidOutputStreamRaisesError(self):
with self.test_session():
tensor = math_ops.range(10)
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 2e7aa30296..9c63385dd0 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -825,10 +825,21 @@ class Bijector(object):
min_event_ndims=self.inverse_min_event_ndims,
event_ndims=event_ndims)):
if not self._is_injective: # No caching for non-injective
- ildjs = self._inverse_log_det_jacobian(y, **kwargs)
- return tuple(self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- for ildj in ildjs)
+ try:
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError as original_exception:
+ try:
+ x = self._inverse(y, **kwargs)
+ fldjs = self._forward_log_det_jacobian(x, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, -fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError:
+ raise original_exception
+
mapping = self._lookup(y=y, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return mapping.ildj_map[event_ndims]
@@ -917,11 +928,21 @@ class Bijector(object):
return -1. * self._constant_ildj_map[event_ndims]
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
- if not self._is_injective:
- fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
- return tuple(self._reduce_jacobian_det_over_event(
- x, fldj, self.forward_min_event_ndims, event_ndims)
- for fldj in fldjs)
+ if not self._is_injective: # No caching for non-injective
+ try:
+ fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError as original_exception:
+ try:
+ y = self._forward(x, **kwargs)
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, -ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError:
+ raise original_exception
mapping = self._lookup(x=x, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return -mapping.ildj_map[event_ndims]
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3e480a79f5..c61efebca0 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -524,6 +524,8 @@ def matrix_diag_transform(matrix, transform=None, name=None):
Example of heteroskedastic 2-D linear regression.
```python
+ tfd = tfp.distributions
+
# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
@@ -533,7 +535,7 @@ def matrix_diag_transform(matrix, transform=None, name=None):
mu = tf.contrib.layers.fully_connected(activations, 2)
# This is a fully trainable multivariate normal!
- dist = tf.contrib.distributions.MVNCholesky(mu, chol)
+ dist = tfd.MultivariateNormalTriL(mu, chol)
# Standard log loss. Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 6263041b8d..60d73a1693 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -550,9 +550,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 325418d5f7..d680c12ac5 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1029,10 +1029,10 @@ def resize_images(images,
scale_factor_width = (math_ops.to_float(new_width_const) /
math_ops.to_float(current_width))
scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
- scaled_height_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_height))
- scaled_width_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_width))
+ scaled_height_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_height)))
+ scaled_width_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_width)))
# NOTE: Reset the size and other constants used later.
size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 795e6bbc3e..da45f6e3e6 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2687,6 +2687,12 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
+ def testPreserveAspectRatioSquare(self):
+ x_shape = [299, 299, 3]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [320, 320], [320, 320, 3])
+
class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index c0e16ca536..94c685274a 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -52,13 +52,19 @@ builder = option_builder.ProfileOptionBuilder
class PrintModelAnalysisTest(test.TestCase):
+ def _no_rewrite_session_config(self):
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ return config_pb2.ConfigProto(graph_options=graph_options)
+
def testDumpToFile(self):
ops.reset_default_graph()
outfile = os.path.join(test.get_temp_dir(), 'dump')
opts = builder(builder.trainable_variables_parameter()
).with_file_output(outfile).build()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
_ = lib.BuildSmallModel()
model_analyzer.profile(sess.graph, options=opts)
@@ -83,7 +89,8 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess, ops.device(dev):
+ with session.Session(
+ config=self._no_rewrite_session_config()) as sess, ops.device(dev):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -149,11 +156,8 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['params', 'float_ops', 'occurrence', 'device', 'op_types',
'input_shapes']).build())
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- config = config_pb2.ConfigProto(graph_options=graph_options)
- with session.Session(config=config) as sess, ops.device('/device:CPU:0'):
+ with session.Session(config=self._no_rewrite_session_config()
+ ) as sess, ops.device('/device:CPU:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -179,7 +183,7 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -213,7 +217,7 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -274,7 +278,7 @@ class PrintModelAnalysisTest(test.TestCase):
.account_displayed_op_only(False)
.select(['bytes', 'params', 'float_ops', 'device']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -302,7 +306,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_timeline_output(outfile)
.with_accounted_types(['.*']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -338,7 +342,7 @@ class PrintModelAnalysisTest(test.TestCase):
'peak_bytes', 'residual_bytes',
'output_bytes', 'occurrence', 'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -384,7 +388,7 @@ class PrintModelAnalysisTest(test.TestCase):
def testAdvisor(self):
ops.reset_default_graph()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -417,7 +421,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_node_names(trim_name_regexes=['ops.py.*'])
.with_pprof_output(outfile).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -484,7 +488,7 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertGreaterEqual(n.output_bytes, mob)
check_min(n.children, mm, mam, mcm, mb, mpb, mrb, mob)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -549,7 +553,7 @@ class PrintModelAnalysisTest(test.TestCase):
for attr in not_selected:
self.assertFalse(s.find(attr) > 0, s)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -582,7 +586,7 @@ class PrintModelAnalysisTest(test.TestCase):
def _trainLoop(self, train_op, train_steps, time_dir, time_step,
memory_dir, memory_step, profile_dir, dump_step):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(variables.global_variables_initializer())
# start from 1 because variable_initializer took one step.
for i in range(1, train_steps + 1):
@@ -655,7 +659,7 @@ class PrintModelAnalysisTest(test.TestCase):
c = a * b
try:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(c, options=config_pb2.RunOptions(
report_tensor_allocations_upon_oom=True))
except Exception as e: # pylint: disable=broad-except
@@ -758,7 +762,7 @@ class PrintModelAnalysisTest(test.TestCase):
grad = gradients.gradients(y, [x1])
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
index 67cfd799ff..ab749f28cd 100644
--- a/tensorflow/python/tools/api/generator/create_python_api.py
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -181,7 +181,6 @@ class _ModuleInitCodeBuilder(object):
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
-__all__.remove('print_function')
''' % underscore_names_str
return module_text_map
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 0e0125a956..82f0e3be52 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -1114,7 +1114,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised while a session was being created. '
'This may be due to a preemption of a connected worker '
'or parameter server. A new session will be created. '
- 'Error: %s', e)
+ 'This error may also occur due to a gRPC failure caused '
+ 'by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
def _check_stop(self):
try:
@@ -1127,7 +1131,11 @@ class _RecoverableSession(_WrappedSession):
'session is complete. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = self._create_session()
# Since we have just recreated the session, the overall computation should
@@ -1150,7 +1158,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = None
@@ -1166,7 +1178,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = None
diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i
index 41e62e0252..1ab600bb22 100644
--- a/tensorflow/python/training/quantize_training.i
+++ b/tensorflow/python/training/quantize_training.i
@@ -55,6 +55,13 @@ PyObject* DoQuantizeTrainingOnGraphDefHelper(
%insert("python") %{
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+@deprecation.deprecated(
+ None,
+ "GraphDef quantized training rewriter is deprecated in the long term")
+@tf_export(v1=["train.do_quantize_training_on_graphdef"])
def do_quantize_training_on_graphdef(input_graph, num_bits):
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 2968ca9c07..653ca525dc 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -118,6 +118,18 @@ flatten = _pywrap_tensorflow.Flatten
_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
+class _DotString(object):
+
+ def __str__(self):
+ return "."
+
+ def __repr__(self):
+ return "."
+
+
+_DOT = _DotString()
+
+
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
@@ -149,7 +161,15 @@ def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ try:
+ _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ except (ValueError, TypeError) as e:
+ str1 = str(map_structure(lambda _: _DOT, nest1))
+ str2 = str(map_structure(lambda _: _DOT, nest2))
+ raise type(e)("%s\n"
+ "Entire first structure:\n%s\n"
+ "Entire second structure:\n%s"
+ % (str(e), str1, str2))
def flatten_dict_items(dictionary):
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index ef503137d1..bfb4c6f910 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -264,7 +264,11 @@ class NestTest(parameterized.TestCase, test.TestCase):
"Second structure:.*\n\n"
"More specifically: Substructure "
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
- 'substructure "type=str str=spam" is not')):
+ 'substructure "type=str str=spam" is not\n'
+ "Entire first structure:\n"
+ r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
+ "Entire second structure:\n"
+ r"\(\., \.\)")):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(
diff --git a/tensorflow/requirements.txt b/tensorflow/requirements.txt
new file mode 100644
index 0000000000..6e111edefc
--- /dev/null
+++ b/tensorflow/requirements.txt
@@ -0,0 +1,2 @@
+keras_applications >= 1.0.5
+keras_preprocessing >= 1.0.3
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 3a77ba769c..ca90c383f9 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
@@ -2406,6 +2407,33 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
}
}
+// Determines whether we can safely perform a winograd non-fused convolution for
+// the given input and output shapes. This works around b/68264959, an integer
+// overflow in cuDNNv5 and cuDNNv6.
+#if CUDNN_VERSION >= 7000
+bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
+ const dnn::BatchDescriptor&) {
+ return true;
+}
+#else
+bool ShouldIncludeWinogradNonfusedAlgo(
+ const dnn::BatchDescriptor& input_desc,
+ const dnn::BatchDescriptor& output_desc) {
+ int64 batch = input_desc.count();
+ int64 in_depths = input_desc.feature_map_count();
+ int64 in_rows = input_desc.height();
+ int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
+ int64 out_depths = output_desc.feature_map_count();
+
+ int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
+ std::max(in_depths, out_depths) * in_cols * in_rows *
+ sizeof(float);
+
+ const int64 threshold = 1L << 31;
+ return total_size < threshold;
+}
+#endif
+
} // namespace
template <class T>
@@ -2484,6 +2512,13 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
@@ -2588,6 +2623,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
+ output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See around b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
cudnn.handle(),
/*alpha1=*/&conv_input_scale,
@@ -3114,6 +3157,13 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
}
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
// zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
@@ -3293,6 +3343,13 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
"This configuration potentially produces incorrect results.");
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Zero out the result buffer for strided conv backward filter for NHWC
// layouts. cuDNN 7.1.4 and 7.2 has non-determinisic bug if the buffer is not
// zeroed.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index fbc58e5933..18fc5836dc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1093,6 +1093,10 @@ tf_module {
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "extract_volume_patches"
+ argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 7eca26be06..61448f887d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -1037,6 +1037,10 @@ tf_module {
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "extract_volume_patches"
+ argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index b21dabbde7..cb6da5088b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -265,10 +265,6 @@ tf_module {
argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "do_quantize_training_on_graphdef"
- argspec: "args=[\'input_graph\', \'num_bits\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "exponential_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index f86cb03995..12354a6ab2 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -62,6 +62,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/autograph:autograph",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
+ "//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py
index 9181c9bd4a..a883ce221f 100644
--- a/tensorflow/tools/test/check_futures_test.py
+++ b/tensorflow/tools/test/check_futures_test.py
@@ -37,6 +37,7 @@ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
FUTURES_PATTERN = re.compile(r'^from __future__ import (\w+)\s*$')
FUTURES_PATTERN_2 = re.compile(
r'^from __future__ import (\w+), (\w+), (\w+)\s*$')
+FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')
REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function'])
WHITELIST = [
@@ -59,6 +60,8 @@ def check_file(path, old_division):
for line in open(path, encoding='utf-8') if six.PY3 else open(path):
count += 1
m = FUTURES_PATTERN.match(line)
+ if not m:
+ m = FUTURES_PATTERN_3.match(line)
if m:
futures.add(m.group(1))
else:
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d0531f8193..d47d15315d 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -179,6 +179,10 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
+ system_build_file = clean_dep("//third_party/systemlibs:google_cloud_cpp.BUILD"),
+ system_link_files = {
+ "//third_party/systemlibs:google_cloud_cpp.google.cloud.bigtable.BUILD": "google/cloud/bigtable/BUILD",
+ },
)
tf_http_archive(
@@ -190,6 +194,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
build_file = clean_dep("//third_party:googleapis.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
)
tf_http_archive(
@@ -319,6 +324,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
strip_prefix = "gast-0.2.0",
build_file = clean_dep("//third_party:gast.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
)
tf_http_archive(
@@ -341,6 +347,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
strip_prefix = "abseil-py-pypi-v0.2.2",
+ system_build_file = clean_dep("//third_party/systemlibs:absl_py.BUILD"),
+ system_link_files = {
+ "//third_party/systemlibs:absl_py.absl.flags.BUILD": "absl/flags/BUILD",
+ "//third_party/systemlibs:absl_py.absl.testing.BUILD": "absl/testing/BUILD",
+ },
)
tf_http_archive(
@@ -531,6 +542,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
+ system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
)
tf_http_archive(
@@ -738,14 +750,16 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
)
- native.new_http_archive(
+ tf_http_archive(
name = "double_conversion",
urls = [
+ "https://mirror.bazel.build/github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
],
sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
build_file = clean_dep("//third_party:double_conversion.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
)
tf_http_archive(
diff --git a/third_party/repo.bzl b/third_party/repo.bzl
index 7d1aa5dce9..6e30618d39 100644
--- a/third_party/repo.bzl
+++ b/third_party/repo.bzl
@@ -119,6 +119,10 @@ def _tf_http_archive(ctx):
"%prefix%": ".." if _repos_are_siblings() else "external",
}, False)
+ if use_syslib:
+ for internal_src, external_dest in ctx.attr.system_link_files.items():
+ ctx.symlink(Label(internal_src), ctx.path(external_dest))
+
tf_http_archive = repository_rule(
implementation = _tf_http_archive,
attrs = {
@@ -130,6 +134,7 @@ tf_http_archive = repository_rule(
"patch_file": attr.label(),
"build_file": attr.label(),
"system_build_file": attr.label(),
+ "system_link_files": attr.string_dict(),
},
environ = [
"TF_SYSTEM_LIBS",
@@ -180,7 +185,16 @@ def _third_party_http_archive(ctx):
_apply_patch(ctx, ctx.attr.patch_file)
ctx.symlink(Label(ctx.attr.build_file), buildfile_path)
+ link_dict = dict()
+ if use_syslib:
+ link_dict.update(ctx.attr.system_link_files)
+
for internal_src, external_dest in ctx.attr.link_files.items():
+ # if syslib and link exists in both, use the system one
+ if external_dest not in link_dict.values():
+ link_dict[internal_src] = external_dest
+
+ for internal_src, external_dest in link_dict.items():
ctx.symlink(Label(internal_src), ctx.path(external_dest))
# Downloads and creates Bazel repos for dependencies.
@@ -201,6 +215,7 @@ third_party_http_archive = repository_rule(
"system_build_file": attr.string(mandatory = False),
"patch_file": attr.label(),
"link_files": attr.string_dict(),
+ "system_link_files": attr.string_dict(),
},
environ = [
"TF_SYSTEM_LIBS",
diff --git a/third_party/systemlibs/absl_py.BUILD b/third_party/systemlibs/absl_py.BUILD
new file mode 100644
index 0000000000..fe756e1be2
--- /dev/null
+++ b/third_party/systemlibs/absl_py.BUILD
@@ -0,0 +1 @@
+licenses(["notice"]) # Apache 2.0
diff --git a/third_party/systemlibs/absl_py.absl.flags.BUILD b/third_party/systemlibs/absl_py.absl.flags.BUILD
new file mode 100644
index 0000000000..95ec92b887
--- /dev/null
+++ b/third_party/systemlibs/absl_py.absl.flags.BUILD
@@ -0,0 +1,11 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+filegroup(
+ name = "LICENSE",
+)
+
+py_library(
+ name = "flags",
+)
diff --git a/third_party/systemlibs/absl_py.absl.testing.BUILD b/third_party/systemlibs/absl_py.absl.testing.BUILD
new file mode 100644
index 0000000000..c1b794c1e9
--- /dev/null
+++ b/third_party/systemlibs/absl_py.absl.testing.BUILD
@@ -0,0 +1,7 @@
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "parameterized",
+ testonly = 1,
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/boringssl.BUILD b/third_party/systemlibs/boringssl.BUILD
new file mode 100644
index 0000000000..bc4c533403
--- /dev/null
+++ b/third_party/systemlibs/boringssl.BUILD
@@ -0,0 +1,21 @@
+licenses(["notice"])
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "crypto",
+ linkopts = ["-lcrypto"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "ssl",
+ linkopts = ["-lssl"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":crypto",
+ ],
+)
diff --git a/third_party/systemlibs/double_conversion.BUILD b/third_party/systemlibs/double_conversion.BUILD
new file mode 100644
index 0000000000..568460181a
--- /dev/null
+++ b/third_party/systemlibs/double_conversion.BUILD
@@ -0,0 +1,12 @@
+licenses(["notice"])
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "double-conversion",
+ linkopts = ["-ldouble-conversion"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/gast.BUILD b/third_party/systemlibs/gast.BUILD
new file mode 100644
index 0000000000..c6e1d0c4e0
--- /dev/null
+++ b/third_party/systemlibs/gast.BUILD
@@ -0,0 +1,12 @@
+licenses(["notice"]) # BSD 3-clause
+
+filegroup(
+ name = "PKG-INFO",
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "gast",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/google_cloud_cpp.BUILD b/third_party/systemlibs/google_cloud_cpp.BUILD
new file mode 100644
index 0000000000..cbe6e10ba5
--- /dev/null
+++ b/third_party/systemlibs/google_cloud_cpp.BUILD
@@ -0,0 +1,6 @@
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD b/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
new file mode 100644
index 0000000000..b59d565390
--- /dev/null
+++ b/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
@@ -0,0 +1,7 @@
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "bigtable_client",
+ linkopts = ["-lbigtable_client"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/googleapis.BUILD b/third_party/systemlibs/googleapis.BUILD
new file mode 100644
index 0000000000..7687745df9
--- /dev/null
+++ b/third_party/systemlibs/googleapis.BUILD
@@ -0,0 +1,12 @@
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "bigtable_protos",
+ linkopts = ["-lbigtable_protos"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/systemlibs/jsoncpp.BUILD b/third_party/systemlibs/jsoncpp.BUILD
index cf91917cfb..526fd0c418 100644
--- a/third_party/systemlibs/jsoncpp.BUILD
+++ b/third_party/systemlibs/jsoncpp.BUILD
@@ -23,7 +23,7 @@ genrule(
cmd = """
for i in $(OUTS); do
i=$${i##*/}
- ln -vsf /usr/include/jsoncpp/json/$$i $(@D)/include/json/$$i
+ ln -sf $(INCLUDEDIR)/jsoncpp/json/$$i $(@D)/include/json/$$i
done
""",
)
diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl
index 8b09c9ac1f..8b0ab39eaf 100644
--- a/third_party/systemlibs/syslibs_configure.bzl
+++ b/third_party/systemlibs/syslibs_configure.bzl
@@ -10,11 +10,17 @@
_TF_SYSTEM_LIBS = "TF_SYSTEM_LIBS"
VALID_LIBS = [
+ "absl_py",
"astor_archive",
+ "boringssl",
+ "com_github_googleapis_googleapis",
+ "com_github_googlecloudplatform_google_cloud_cpp",
"com_googlesource_code_re2",
"curl",
"cython",
+ "double_conversion",
"flatbuffers",
+ "gast_archive",
"gif_archive",
"grpc",
"jemalloc",
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 601e07ffdd..ccf62629d1 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -67,3 +67,8 @@ build -c opt
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
+
+# Default paths for TF_SYSTEM_LIBS
+build --define=PREFIX=/usr
+build --define=LIBDIR=$(PREFIX)/lib
+build --define=INCLUDEDIR=$(PREFIX)/include