aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2017-12-28 16:04:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-28 16:08:58 -0800
commit20765b3e1ae3b718699592c98aa9805cb874b6d1 (patch)
treeb429a74cd0046404644f34cc8fe6ff2cab78bb85
parent2e2715baa84720f786b38d1f9cb6887399020d6f (diff)
Merge changes from github.
PiperOrigin-RevId: 180301735
-rw-r--r--CODE_OF_CONDUCT.md2
-rw-r--r--CONTRIBUTING.md6
-rw-r--r--README.md4
-rw-r--r--tensorflow/BUILD20
-rw-r--r--tensorflow/c/c_api_function_test.cc4
-rw-r--r--tensorflow/cc/gradients/math_grad.cc212
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc17
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.cc11
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.h7
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/windows_compatibility.cc32
-rw-r--r--tensorflow/compiler/xla/service/cpu/windows_compatibility.h31
-rw-r--r--tensorflow/compiler/xla/status_macros.h14
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc7
-rw-r--r--tensorflow/contrib/BUILD4
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py2
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake1
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py5
-rw-r--r--tensorflow/contrib/factorization/examples/BUILD11
-rwxr-xr-xtensorflow/contrib/image/BUILD17
-rwxr-xr-xtensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc24
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py87
-rw-r--r--tensorflow/contrib/keras/api/__init__.py18
-rw-r--r--tensorflow/contrib/kernel_methods/BUILD2
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses.py6
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses_test.py23
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py6
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/__init__.py18
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py18
-rw-r--r--tensorflow/contrib/memory_stats/__init__.py4
-rw-r--r--tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py5
-rw-r--r--tensorflow/contrib/metrics/__init__.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py124
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py208
-rw-r--r--tensorflow/contrib/mpi_collectives/BUILD124
-rw-r--r--tensorflow/contrib/mpi_collectives/__init__.py18
-rw-r--r--tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc1132
-rw-r--r--tensorflow/contrib/mpi_collectives/kernels/ring.cc80
-rw-r--r--tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc117
-rw-r--r--tensorflow/contrib/mpi_collectives/kernels/ring.h327
-rw-r--r--tensorflow/contrib/mpi_collectives/mpi_message.proto2
-rw-r--r--tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc132
-rw-r--r--tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py134
-rw-r--r--tensorflow/contrib/ndlstm/__init__.py18
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py14
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py93
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/__init__.py18
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/__init__.py18
-rw-r--r--tensorflow/contrib/specs/__init__.py18
-rw-r--r--tensorflow/contrib/timeseries/examples/__init__.py18
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/__init__.py18
-rw-r--r--tensorflow/contrib/training/python/__init__.py18
-rw-r--r--tensorflow/contrib/training/python/training/__init__.py18
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc7
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc2
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc2
-rw-r--r--tensorflow/core/kernels/bcast_ops.cc76
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h2
-rw-r--r--tensorflow/core/kernels/example_parsing_ops_test.cc60
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc161
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc8
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc98
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h3
-rw-r--r--tensorflow/core/kernels/mkl_lrn_op.cc659
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc88
-rw-r--r--tensorflow/core/kernels/reverse_op.cc60
-rw-r--r--tensorflow/core/kernels/set_kernels.cc4
-rw-r--r--tensorflow/core/kernels/slice_op.cc1
-rw-r--r--tensorflow/core/kernels/string_to_number_op.cc38
-rw-r--r--tensorflow/core/kernels/where_op.cc4
-rw-r--r--tensorflow/core/lib/random/random_distributions_test.cc1
-rw-r--r--tensorflow/core/lib/strings/numbers.h32
-rw-r--r--tensorflow/core/lib/strings/proto_text_util.h26
-rw-r--r--tensorflow/core/lib/strings/str_util.h2
-rw-r--r--tensorflow/core/ops/image_ops.cc4
-rw-r--r--tensorflow/core/ops/nn_ops.cc8
-rw-r--r--tensorflow/core/platform/cloud/BUILD2
-rw-r--r--tensorflow/core/platform/default/build_config.bzl5
-rw-r--r--tensorflow/core/platform/types.h4
-rw-r--r--tensorflow/core/platform/windows/integral_types.h25
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h2
-rw-r--r--tensorflow/docs_src/programmers_guide/datasets.md2
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py2
-rw-r--r--tensorflow/go/session.go45
-rw-r--r--tensorflow/go/session_test.go16
-rw-r--r--tensorflow/python/__init__.py2
-rw-r--r--tensorflow/python/client/session_clusterspec_prop_test.py3
-rw-r--r--tensorflow/python/client/tf_session.i3
-rw-r--r--tensorflow/python/debug/README.md2
-rw-r--r--tensorflow/python/framework/function_test.py9
-rw-r--r--tensorflow/python/framework/versions.py4
-rw-r--r--tensorflow/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/atrous_convolution_test.py13
-rw-r--r--tensorflow/python/kernel_tests/bcast_ops_test.py15
-rw-r--r--tensorflow/python/kernel_tests/record_input_test.py1
-rw-r--r--tensorflow/python/ops/image_ops_impl.py2
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py7
-rw-r--r--tensorflow/python/ops/metrics_impl.py12
-rw-r--r--tensorflow/python/ops/nn_ops.py12
-rw-r--r--tensorflow/python/platform/sysconfig.py6
-rw-r--r--tensorflow/python/pywrap_tensorflow.py1
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc155
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.h10
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc9
-rw-r--r--tensorflow/tensorflow.bzl4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh1
-rwxr-xr-xtensorflow/tools/git/gen_git_source.py16
-rw-r--r--tensorflow/tools/graph_transforms/BUILD11
114 files changed, 4688 insertions, 383 deletions
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index ff11d13140..5fff9d05a1 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -67,4 +67,4 @@ If the Project Stewards receive a report alleging a violation of the Code of Con
## Attribution
-This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at http://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct.
+This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index dc96bc2e3d..de4fded6ae 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -8,8 +8,8 @@ We'd love to accept your patches! Before we can take them, we have to jump a cou
Please fill out either the individual or corporate Contributor License Agreement (CLA).
- * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
- * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
+ * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html).
+ * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html).
Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
@@ -117,7 +117,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py
* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
-* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html)
+* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html)
#### Running sanity check
diff --git a/README.md b/README.md
index aff3427bdd..2c4afb0b55 100644
--- a/README.md
+++ b/README.md
@@ -49,8 +49,8 @@ packages on Linux, Mac, and Windows.
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 5d639e7f8c..fb62e322a4 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -446,11 +446,13 @@ filegroup(
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/decision_trees/proto:all_files",
+ "//tensorflow/contrib/deprecated:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/eager/proto:all_files",
"//tensorflow/contrib/eager/python:all_files",
"//tensorflow/contrib/estimator:all_files",
"//tensorflow/contrib/factorization:all_files",
+ "//tensorflow/contrib/factorization/examples:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",
"//tensorflow/contrib/ffmpeg:all_files",
"//tensorflow/contrib/ffmpeg/default:all_files",
@@ -461,6 +463,7 @@ filegroup(
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/hooks:all_files",
+ "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files",
"//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files",
"//tensorflow/contrib/image:all_files",
"//tensorflow/contrib/input_pipeline:all_files",
@@ -478,6 +481,7 @@ filegroup(
"//tensorflow/contrib/layers/kernels:all_files",
"//tensorflow/contrib/learn:all_files",
"//tensorflow/contrib/learn/python/learn/datasets:all_files",
+ "//tensorflow/contrib/legacy_seq2seq:all_files",
"//tensorflow/contrib/libsvm:all_files",
"//tensorflow/contrib/linalg:all_files",
"//tensorflow/contrib/linear_optimizer:all_files",
@@ -503,15 +507,19 @@ filegroup(
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
"//tensorflow/contrib/makefile:all_files",
+ "//tensorflow/contrib/memory_stats:all_files",
"//tensorflow/contrib/meta_graph_transform:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/model_pruning:all_files",
- "//tensorflow/contrib/mpi_collectives:all_files",
+ "//tensorflow/contrib/model_pruning/examples/cifar10:all_files",
+ "//tensorflow/contrib/nccl:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nearest_neighbor:all_files",
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
+ "//tensorflow/contrib/periodic_resample:all_files",
"//tensorflow/contrib/predictor:all_files",
+ "//tensorflow/contrib/quantization:all_files",
"//tensorflow/contrib/quantize:all_files",
"//tensorflow/contrib/receptive_field:all_files",
"//tensorflow/contrib/reduce_slice_ops:all_files",
@@ -580,6 +588,7 @@ filegroup(
"//tensorflow/core/profiler/internal/advisor:all_files",
"//tensorflow/core/util/ctc:all_files",
"//tensorflow/core/util/tensor_bundle:all_files",
+ "//tensorflow/examples/adding_an_op:all_files",
"//tensorflow/examples/android:all_files",
"//tensorflow/examples/benchmark:all_files",
"//tensorflow/examples/get_started/regression:all_files",
@@ -587,10 +596,13 @@ filegroup(
"//tensorflow/examples/image_retraining:all_files",
"//tensorflow/examples/label_image:all_files",
"//tensorflow/examples/learn:all_files",
+ "//tensorflow/examples/multibox_detector:all_files",
"//tensorflow/examples/saved_model:all_files",
"//tensorflow/examples/speech_commands:all_files",
"//tensorflow/examples/tutorials/estimators:all_files",
+ "//tensorflow/examples/tutorials/layers:all_files",
"//tensorflow/examples/tutorials/mnist:all_files",
+ "//tensorflow/examples/tutorials/monitors:all_files",
"//tensorflow/examples/tutorials/word2vec:all_files",
"//tensorflow/examples/wav_to_spectrogram:all_files",
"//tensorflow/go:all_files",
@@ -613,6 +625,7 @@ filegroup(
"//tensorflow/python/kernel_tests/random:all_files",
"//tensorflow/python/ops/distributions:all_files",
"//tensorflow/python/ops/linalg:all_files",
+ "//tensorflow/python/ops/losses:all_files",
"//tensorflow/python/profiler:all_files",
"//tensorflow/python/profiler/internal:all_files",
"//tensorflow/python/saved_model:all_files",
@@ -623,6 +636,7 @@ filegroup(
"//tensorflow/tools/api/tests:all_files",
"//tensorflow/tools/benchmark:all_files",
"//tensorflow/tools/build_info:all_files",
+ "//tensorflow/tools/ci_build/gpu_build:all_files",
"//tensorflow/tools/common:all_files",
"//tensorflow/tools/compatibility:all_files",
"//tensorflow/tools/dist_test/server:all_files",
@@ -630,17 +644,17 @@ filegroup(
"//tensorflow/tools/docker/notebooks:all_files",
"//tensorflow/tools/docs:all_files",
"//tensorflow/tools/git:all_files",
+ "//tensorflow/tools/graph_transforms:all_files",
"//tensorflow/tools/mlpbtxt:all_files",
"//tensorflow/tools/proto_text:all_files",
"//tensorflow/tools/quantization:all_files",
"//tensorflow/tools/test:all_files",
"//tensorflow/user_ops:all_files",
"//third_party/hadoop:all_files",
- "//third_party/mpi:all_files",
"//third_party/sycl:all_files",
"//third_party/sycl/sycl:all_files",
],
- visibility = [":__subpackages__"],
+ visibility = ["//visibility:public"],
)
load(
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 2e2293ca85..6234372fe3 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -1462,7 +1462,11 @@ TEST_F(CApiFunctionTest, AppendHash) {
/*append_hash=*/true);
tensorflow::FunctionDef fdef;
ASSERT_TRUE(GetFunctionDef(func_, &fdef));
+#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+ ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name());
+#else
ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
+#endif
}
TEST_F(CApiFunctionTest, GetOpDef) {
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index ebc0c77828..afd92fbf48 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -473,6 +473,41 @@ Status AddNGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);
+Status PowGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto x = ConjugateHelper(scope, op.input(0));
+ auto y = ConjugateHelper(scope, op.input(1));
+ auto z = ConjugateHelper(scope, op.output(0));
+ auto grad = grad_inputs[0];
+ // grad * y * pow(x, y - 1)
+ auto one = Cast(scope, Const(scope, 1.0), y.type());
+ auto gx_1 = Mul(scope,
+ Mul(scope, grad, y),
+ Pow(scope, x, Sub(scope, y, one)));
+ // Avoid false singularity at x = 0
+ DataType x_dtype = x.type();
+ auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
+ if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
+ // real(x) < 0 is fine for the complex case
+ auto log_x = Where3(scope,
+ NotEqual(scope, x, zero),
+ Log(scope, x),
+ ZerosLike(scope, x));
+ auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
+ return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
+ } else {
+ // There's no sensible real value to return if x < 0, so return 0
+ auto log_x = Where3(scope,
+ Greater(scope, x, zero),
+ Log(scope, x),
+ ZerosLike(scope, x));
+ auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
+ return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
+ }
+}
+REGISTER_GRADIENT_OP("Pow", PowGrad);
+
// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
@@ -812,6 +847,183 @@ Status MinOrMaxGrad(const Scope& scope, const Operation& op,
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
+Status ProdGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ auto zero = Const(scope, 0);
+ auto one = Const(scope, 1);
+
+ // The gradient can be expressed by dividing the product by each entry of
+ // the input tensor. If our input is
+ // [
+ // [3, 4],
+ // [5, 6],
+ // [7, 8]
+ // ]
+ // and we do a Prod operation on the axis 1, we will obtain [[105, 192]].
+ // The gradient will have the same shape as the input
+ // [
+ // [105/3, 192/4],
+ // dz * [105/5, 192/6],
+ // [105/7, 192/6]
+ // ]
+ // If the input contains a zero, the division is impossible but
+ // if we take the calculation that gave the first gradient
+ // (3 * 5 * 6)/3 is equal to 5 * 6
+ // the trick will be to cumprod the elements on the axis without
+ // the element at the current position (3 in the example above).
+ // We will take as example:
+ // [
+ // [
+ // [3.0, 4.0],
+ // [5.0, 6.0],
+ // [7.0, 8.0]
+ // ],
+ // [
+ // [3.0, 5.0],
+ // [0.0, 6.0],
+ // [5.0, 6.0]
+ // ]
+ // ]
+
+ // [2, 3, 2]
+ auto input_shape = Shape(scope, op.input(0));
+
+ // The Reshape with -1 flattens the reduction indices.
+ // [1]
+ auto reduction_indices = Reshape(scope, op.input(1), {-1});
+
+ // [2, 1, 2]
+ auto output_shape_kept_dims =
+ ReducedShapeHelper(scope, input_shape, reduction_indices);
+
+ // [1, 3, 1]
+ auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);
+
+ // [[[105, 192]], [[0, 180]]]
+ auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);
+
+ // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
+ auto grad_tiled = Tile(scope, grad, tile_scaling);
+
+ Scope cpu_scope = scope.WithDevice("/cpu:0");
+
+ // [3]
+ auto rank = Rank(cpu_scope, op.input(0));
+
+
+ // Normalize any negative indices in the reduction_axes to positive values.
+ auto reduction_indices_pos = Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);
+
+ // [1]
+ auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);
+
+ // [0, 1, 2]
+ auto idx = Range(cpu_scope, zero, rank, one);
+
+ // [0, 2]
+ auto other = SetDiff1D(cpu_scope, idx, reduced).out;
+
+ // [1, 0, 2]
+ auto perm =
+ Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);
+
+ // 3 => [3]
+ auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);
+
+ // 2 * 2 => [2]
+ auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);
+
+ // [
+ // [
+ // [ 3., 4.],
+ // [ 3., 5.]
+ // ],
+ // [
+ // [ 5., 6.],
+ // [ 0., 6.]
+ // ],
+ // [
+ // [ 7., 8.],
+ // [ 5., 6.]
+ // ]
+ // ]
+ auto permuted = Transpose(scope, op.input(0), perm);
+
+ // [3, 2, 2]
+ auto permuted_shape = Shape(scope, permuted);
+
+ // [
+ // [ 3., 4., 3., 5.],
+ // [ 5., 6., 0., 6.],
+ // [ 7., 8., 5., 6.]
+ // ]
+ auto reshaped = Reshape(
+ scope, permuted,
+ Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));
+
+ // [
+ // [ 1., 1., 1., 1.],
+ // [ 3., 4., 3., 5.],
+ // [ 15., 24., 0., 30.]
+ // ]
+ auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));
+
+ // [
+ // [ 35., 48., 0., 36.],
+ // [ 7., 8., 5., 6.],
+ // [ 1., 1., 1., 1.]
+ // ]
+ auto right =
+ Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));
+
+ // left * right =
+ // [
+ // [ 35., 48., 0., 36.],
+ // [ 21., 32., 15., 30.],
+ // [ 15., 24., 0., 30.]
+ // ]
+ // y =
+ // [
+ // [
+ // [ 35., 48.],
+ // [ 0., 36.]
+ // ],
+ // [
+ // [ 21., 32.],
+ // [ 15., 30.]
+ // ],
+ // [
+ // [ 15., 24.],
+ // [ 0., 30.]
+ // ]
+ // ]
+ auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);
+
+ // out =
+ // [
+ // [
+ // [ 35., 48.],
+ // [ 21., 32.],
+ // [ 15., 24.]
+ // ],
+ // [
+ // [ 0., 36.],
+ // [ 15., 30.],
+ // [ 0., 30.]
+ // ]
+ // ]
+ auto out =
+ Mul(scope, grad_tiled, Transpose(scope, y, InvertPermutation(scope, perm)));
+
+ grad_outputs->push_back(Reshape(scope, out, input_shape));
+
+ // stop propagation along reduction_indices
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Prod", ProdGrad);
+
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 29def3c3ea..b94d797711 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -843,6 +843,14 @@ TEST_F(NaryGradTest, SquaredDifference) {
RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape});
}
+TEST_F(NaryGradTest, Pow) {
+ TensorShape shape({3});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ // fix exponent to avoid overflow
+ auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f}));
+ RunTest({x}, {shape}, {y}, {shape});
+}
+
TEST_F(NaryGradTest, Maximum) {
TensorShape shape({3, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
@@ -865,6 +873,15 @@ TEST_F(NaryGradTest, Minimum) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NaryGradTest, Prod) {
+ TensorShape x_shape({2, 3, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = Prod(scope_, x, {1});
+ // y's shape is the result of reducing x along axes 1
+ TensorShape y_shape({2, 1, 2});
+ RunTest({x}, {x_shape}, {y}, {y_shape});
+}
+
TEST_F(NaryGradTest, Select) {
TensorShape shape({3, 4});
auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index c9fbeae77c..e35d947525 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -148,7 +148,11 @@ cc_library(
cc_library(
name = "simple_orc_jit",
- srcs = ["simple_orc_jit.cc"],
+ srcs = [
+ "simple_orc_jit.cc",
+ "windows_compatibility.cc",
+ "windows_compatibility.h",
+ ],
hdrs = ["simple_orc_jit.h"],
deps = [
":compiler_functor",
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
index c9f8e55849..7a97021dda 100644
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
+++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
@@ -33,13 +33,10 @@ void ExternalConstantPool::Insert(string name, const Literal& literal,
CHECK(entries_.find(name) == entries_.end());
int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
- void* raw_pointer;
- CHECK_EQ(
- posix_memalign(&raw_pointer, std::max<size_t>(alignment, sizeof(void*)),
- literal_size),
- 0)
- << "failed to allocate " << literal_size << " bytes with alignment of "
- << alignment;
+ void* raw_pointer = tensorflow::port::AlignedMalloc(
+ literal_size, std::max<size_t>(alignment, sizeof(void*)));
+ CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
+ << " bytes with alignment of " << alignment;
std::memcpy(raw_pointer, literal.InternalData(), literal_size);
entries_.emplace(std::move(name), static_cast<uint8*>(raw_pointer));
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
index ade28cbcbc..9c00d476b1 100644
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
+++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/mem.h"
namespace xla {
namespace cpu {
@@ -49,10 +50,10 @@ class ExternalConstantPool {
const uint8* Find(const string& name);
private:
- // We need to `free()` pointers allocated into `entries_` since we allocate
- // them with `posix_memalign`.
+ // We need to `AlignedFree` pointers allocated into `entries_` since we
+ // allocate them with `AlignedMalloc`.
struct FreeDeleter {
- void operator()(void* ptr) { free(ptr); }
+ void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); }
};
tensorflow::gtl::FlatMap<string, std::unique_ptr<uint8, FreeDeleter>>
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index c942cd6bf1..65da61805a 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
-#include <dlfcn.h>
#include <stdint.h>
#include <algorithm>
#include <list>
@@ -38,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc b/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc
new file mode 100644
index 0000000000..ab308ee6cb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
+
+#ifdef _MSC_VER
+
+#include <math.h>
+
+void sincos(double x, double *sinv, double *cosv) {
+ *sinv = sin(x);
+ *cosv = cos(x);
+}
+
+void sincosf(float x, float *sinv, float *cosv) {
+ *sinv = sinf(x);
+ *cosv = cosf(x);
+}
+
+#endif // _MSC_VER
diff --git a/tensorflow/compiler/xla/service/cpu/windows_compatibility.h b/tensorflow/compiler/xla/service/cpu/windows_compatibility.h
new file mode 100644
index 0000000000..262f379d8b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/windows_compatibility.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_
+
+#ifdef _MSC_VER
+
+extern "C" {
+
+// MSVC does not have sincos[f].
+void sincos(double x, double *sinv, double *cosv);
+void sincosf(float x, float *sinv, float *cosv);
+
+}
+
+#endif // _MSC_VER
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_
diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h
index 5e5550563d..e51dd64e2a 100644
--- a/tensorflow/compiler/xla/status_macros.h
+++ b/tensorflow/compiler/xla/status_macros.h
@@ -196,18 +196,8 @@ class StatusAdaptorForMacros {
#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y)
#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
-#define TF_ASSIGN_OR_RETURN(...) \
- TF_STATUS_MACRO_GET_VARIADIC_IMPL(__VA_ARGS__, TF_ASSIGN_OR_RETURN_IMPL_3, \
- TF_ASSIGN_OR_RETURN_IMPL_2) \
- (__VA_ARGS__)
-
-#define TF_STATUS_MACRO_GET_VARIADIC_IMPL(_1, _2, _3, NAME, ...) NAME
-
-#define TF_ASSIGN_OR_RETURN_IMPL_2(lhs, rexpr) \
- TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr)
-
-#define TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) \
- TF_ASSIGN_OR_RETURN_IMPL( \
+#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \
+ TF_ASSIGN_OR_RETURN_IMPL( \
TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 8853ed9e57..92c2956f87 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -36,7 +36,7 @@ namespace {
class ClientTest : public ClientLibraryTestBase {};
-TEST_F(ClientTest, ExecuteWithLayout) {
+XLA_TEST_F(ClientTest, ExecuteWithLayout) {
ComputationBuilder b(client_, TestName());
std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
@@ -68,7 +68,7 @@ TEST_F(ClientTest, ExecuteWithLayout) {
}
}
-TEST_F(ClientTest, ExecuteWithTupleLayout) {
+XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
ComputationBuilder b(client_, TestName());
b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
@@ -107,7 +107,8 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0})));
}
-TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
+XLA_TEST_F(ClientTest,
+ DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 6e2320bd0d..cabd5ed1e5 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -101,7 +101,7 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
- ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]),
+ ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
)
cc_library(
@@ -122,7 +122,7 @@ cc_library(
"//tensorflow/contrib/tensor_forest:stats_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
- ],
+ ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
)
cc_library(
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
index ee3719232d..fdc12e3b21 100644
--- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
@@ -43,7 +43,7 @@ def custom_gradient(fx, gx, x, axis=(),
h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))
```
- is such that `h(x) = stop(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).`
+ is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).`
In addition to scalar-domain/scalar-range functions, this function also
supports tensor-domain/scalar-range functions. However, in the latter case it
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 94ca4b0017..8fb19c055e 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -372,7 +372,6 @@ if (tensorflow_BUILD_CC_TESTS)
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc"
- "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc"
)
if (NOT tensorflow_ENABLE_GPU)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index 30bb3c8ad3..b8388f93a4 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -22,6 +22,11 @@ Usage: python ./rnn_ptb.py --data-path=<path_to_dataset>
Penn Treebank (PTB) dataset from:
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import argparse
import os
import sys
diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD
index 363baa121a..bbe842bd5c 100644
--- a/tensorflow/contrib/factorization/examples/BUILD
+++ b/tensorflow/contrib/factorization/examples/BUILD
@@ -21,3 +21,14 @@ tf_py_test(
],
tags = ["notsan"],
)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index 54502cfc6e..ce2b279e51 100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -233,6 +233,23 @@ py_library(
],
)
+cuda_py_test(
+ name = "single_image_random_dot_stereograms_ops_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/single_image_random_dot_stereograms_ops_test.py"],
+ additional_deps = [
+ ":distort_image_py",
+ ":image_py",
+ ":single_image_random_dot_stereograms_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc
index f8b56ab1c5..1f41f243f2 100755
--- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc
+++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc
@@ -19,6 +19,10 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
REGISTER_OP("SingleImageRandomDotStereograms")
.Attr("T: {double,float,int64,int32}")
.Input("depth_values: T")
@@ -37,6 +41,26 @@ REGISTER_OP("SingleImageRandomDotStereograms")
"output_image_shape: shape = { dim {size:1024} dim {size: 768} dim "
"{size: 1}}")
.Attr("output_data_window: shape = { dim {size:1022} dim {size: 757}}")
+ .SetShapeFn([](InferenceContext* c) {
+ // Validate that the output_image_shape attr is correct.
+ // NOTE: The output_image_shape is [X, Y, C]
+ // while the output data is [Y, X, C] (or [H, W, C]).
+ // As a result, by default the output_image_shape has the value
+ // of [1024, 768, 1] but the output data will be [768, 1024, 1].
+ PartialTensorShape shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_image_shape", &shape));
+ ShapeHandle output_image_shape;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromPartialTensorShape(shape, &output_image_shape));
+ DimensionHandle x_dim = c->Dim(output_image_shape, 0);
+ DimensionHandle y_dim = c->Dim(output_image_shape, 1);
+
+ int colors;
+ TF_RETURN_IF_ERROR(c->GetAttr("number_colors", &colors));
+
+ c->set_output(0, c->MakeShape({y_dim, x_dim, colors > 256? c->MakeDim(3) : c->MakeDim(1)}));
+ return Status::OK();
+ })
.Doc(R"doc(
Outputs a single image random dot stereogram for export via encode_PNG/JPG OP.
diff --git a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py
new file mode 100644
index 0000000000..bf0c97245f
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for python single_image_random_dot_stereograms_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms \
+ import single_image_random_dot_stereograms
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+class SingleImageRandomDotStereogramsTest(test_util.TensorFlowTestCase):
+
+ def test_shape_function_default(self):
+ """
+ NOTE: The output_image_shape is [X, Y, C]
+ while the output data is [Y, X, C] (or [H, W, C]).
+ As a result, by default the output_image_shape has the value
+ of [1024, 768, 1], but the output data will be [768, 1024, 1].
+ """
+ x_np = [[1, 2, 3, 3, 2, 1],
+ [1, 2, 3, 4, 5, 2],
+ [1, 2, 3, 4, 5, 3],
+ [1, 2, 3, 4, 5, 4],
+ [6, 5, 4, 4, 5, 5]]
+ x_tf = constant_op.constant(x_np)
+ # By default [1024, 768, 1] => [768, 1024, 1].
+ sirds_1 = single_image_random_dot_stereograms(
+ x_tf,
+ convergence_dots_size=8,
+ number_colors=256,
+ normalize=True)
+ shape_1 = sirds_1.get_shape().as_list()
+ self.assertEqual(shape_1, [768, 1024, 1])
+ with self.test_session():
+ r_tf_1 = sirds_1.eval()
+ self.assertAllEqual(shape_1, r_tf_1.shape)
+
+ # If color > 256 then [1024, 768, 3] => [768, 1024, 3].
+ sirds_2 = single_image_random_dot_stereograms(
+ x_tf,
+ convergence_dots_size=8,
+ number_colors=512,
+ normalize=True)
+ shape_2 = sirds_2.get_shape().as_list()
+ self.assertEqual(shape_2, [768, 1024, 3])
+ with self.test_session():
+ r_tf_2 = sirds_2.eval()
+ self.assertAllEqual(shape_2, r_tf_2.shape)
+
+ # If explicitly set output_image_shape to [1200, 800, 1],
+ # then the output data should be [800, 1200, 1].
+ sirds_3 = single_image_random_dot_stereograms(
+ x_tf,
+ convergence_dots_size=8,
+ number_colors=256,
+ normalize=True,
+ output_image_shape=[1200, 800, 1])
+ shape_3 = sirds_3.get_shape().as_list()
+ self.assertEqual(shape_3, [800, 1200, 1])
+ with self.test_session():
+ r_tf_3 = sirds_3.eval()
+ self.assertAllEqual(shape_3, r_tf_3.shape)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/keras/api/__init__.py b/tensorflow/contrib/keras/api/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/keras/api/__init__.py
+++ b/tensorflow/contrib/keras/api/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD
index a2f320ab11..eff7dfeb4c 100644
--- a/tensorflow/contrib/kernel_methods/BUILD
+++ b/tensorflow/contrib/kernel_methods/BUILD
@@ -83,9 +83,11 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":kernel_methods",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py
index 208b0e1c9d..f182fef067 100644
--- a/tensorflow/contrib/kernel_methods/python/losses.py
+++ b/tensorflow/contrib/kernel_methods/python/losses.py
@@ -73,13 +73,13 @@ def sparse_multiclass_hinge_loss(
labels)) as scope:
# Check logits Tensor has valid rank.
- logits_shape = logits.get_shape()
- logits_rank = logits_shape.ndims
+ logits_rank = logits.get_shape().ndims
if logits_rank != 2:
raise ValueError(
'logits should have rank 2 ([batch_size, num_classes]). Given rank is'
' {}'.format(logits_rank))
- batch_size, num_classes = logits_shape[0].value, logits_shape[1].value
+ logits_shape = array_ops.shape(logits)
+ batch_size, num_classes = logits_shape[0], logits_shape[1]
logits = math_ops.to_float(logits)
# Check labels have valid type.
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 8a1a5ffe56..d38d8041ce 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -18,10 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.kernel_methods.python import losses
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -114,6 +117,26 @@ class SparseMulticlassHingeLossTest(test.TestCase):
loss = losses.sparse_multiclass_hinge_loss(labels, logits)
self.assertAlmostEqual(loss.eval(), 0.0, 3)
+ def testUnknownShape(self):
+ """Result keeps same with `testZeroLossInt32Labels`"""
+ logits_np = np.array([[1.2, -1.4, -1.0],
+ [1.4, 1.8, 4.0],
+ [0.5, 1.8, -1.0]])
+ labels_np = np.array([0, 2, 1], dtype=np.int32)
+
+ logits_shapes = [[3, 3], # batch_size, num_classes
+ [None, 3],
+ [3, None],
+ [None, None]]
+
+ for batch_size, num_classes in logits_shapes:
+ with self.test_session():
+ logits = array_ops.placeholder(dtypes.float32, shape=(batch_size, num_classes))
+ labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ result = loss.eval(feed_dict={logits: logits_np, labels: labels_np})
+ self.assertAlmostEqual(result, 0.0, 3)
+
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
with self.test_session():
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index ae64b75d93..1150328b7a 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1747,6 +1747,12 @@ class BatchNormTest(test.TestCase):
expected_var *= correction_factor
return expected_var, correction_factor
+ def testBatchNormCenterFalse(self):
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=(10, 10, 10, 10))
+ # Test that center=False builds a valid graph.
+ _layers.batch_norm(a, center=False, data_format='NCHW',
+ zero_debias_moving_mean=True)
+
def testUnknownShape(self):
with ops.Graph().as_default() as g, self.test_session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
diff --git a/tensorflow/contrib/legacy_seq2seq/python/__init__.py b/tensorflow/contrib/legacy_seq2seq/python/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/__init__.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/memory_stats/__init__.py b/tensorflow/contrib/memory_stats/__init__.py
index a32302c854..2ce849ca66 100644
--- a/tensorflow/contrib/memory_stats/__init__.py
+++ b/tensorflow/contrib/memory_stats/__init__.py
@@ -19,6 +19,10 @@
@@MaxBytesInUse
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesInUse
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesLimit
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import MaxBytesInUse
diff --git a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
index 5e52ef3647..02c2ac06fb 100644
--- a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
+++ b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
@@ -76,9 +76,10 @@ class MemoryStatsOpsTest(test_util.TensorFlowTestCase):
with ops.control_dependencies([a]):
bytes_in_use_op = memory_stats_ops.BytesInUse()
with ops.control_dependencies([bytes_in_use_op]):
- b = math_ops.add(a, a)
+ b = random_ops.random_uniform(matrix_shape, dtype=dtype)
+ c = math_ops.matmul(a, b)
- _, bytes_in_use, max_bytes_in_use = sess.run([b, bytes_in_use_op,
+ _, bytes_in_use, max_bytes_in_use = sess.run([c, bytes_in_use_op,
max_bytes_in_use_op])
# intermediate result allocates 1 matrix, max usage is at least 2
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 27dad5379a..d3dce46bfb 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -66,6 +66,7 @@ See the @{$python/contrib.metrics} guide.
@@set_intersection
@@set_size
@@set_union
+@@cohen_kappa
@@count
@@precision_recall_at_equal_thresholds
@@recall_at_precision
@@ -82,6 +83,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
+from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa
from tensorflow.contrib.metrics.python.ops.metric_ops import count
from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 2f27985634..c3de1c4c62 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -24,10 +24,12 @@ from __future__ import print_function
import collections as collections_lib
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@@ -3297,9 +3299,131 @@ def count(values,
return count_, update_op
+def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
+ metrics_collections=None, updates_collections=None, name=None):
+ """Calculates Cohen's kappa.
+
+ [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
+ that measures inter-annotator agreement.
+
+ The `cohen_kappa` function calculates the confusion matrix, and creates three
+ local variables to compute the Cohen's kappa: `po`, `pe_row`, and `pe_col`,
+ which refer to the diagonal part, rows and columns totals of the confusion
+ matrix, respectively. This value is ultimately returned as `kappa`, an
+ idempotent operation that is calculated by
+
+ pe = (pe_row * pe_col) / N
+ k = (sum(po) - sum(pe)) / (N - sum(pe))
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `kappa`. `update_op` weights each prediction by the corresponding value in
+ `weights`.
+
+ Class labels are expected to start at 0. E.g., if `num_classes`
+ was three, then the possible labels would be [0, 1, 2].
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but the method
+ doesn't support weighted matrix yet.
+
+ Args:
+ labels: 1-D `Tensor` of real labels for the classification task. Must be
+ one of the following types: int16, int32, int64.
+ predictions_idx: 1-D `Tensor` of predicted class indices for a given
+ classification. Must have the same type as `labels`.
+ num_classes: The possible number of labels.
+ weights: Optional `Tensor` whose shape matches `predictions`.
+ metrics_collections: An optional list of collections that `kappa` should
+ be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ kappa: Scalar float `Tensor` representing the current Cohen's kappa.
+ update_op: `Operation` that increments `po`, `pe_row` and `pe_col`
+ variables appropriately and whose value matches `kappa`.
+
+ Raises:
+ ValueError: If `num_classes` is less than 2, or `predictions` and `labels`
+ have mismatched shapes, or if `weights` is not `None` and its shape
+ doesn't match `predictions`, or if either `metrics_collections` or
+ `updates_collections` are not a list or tuple.
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.in_eager_mode():
+ raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported'
+ 'when eager execution is enabled.')
+ if num_classes < 2:
+ raise ValueError('`num_classes` must be >= 2.'
+ 'Found: {}'.format(num_classes))
+ with variable_scope.variable_scope(name, 'cohen_kappa',
+ (labels, predictions_idx, weights)):
+ # Convert 2-dim (num, 1) to 1-dim (num,)
+ labels.get_shape().with_rank_at_most(2)
+ if labels.get_shape().ndims == 2:
+ labels = array_ops.squeeze(labels, axis=[-1])
+ predictions_idx, labels, weights = (
+ metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions_idx, labels=labels, weights=weights))
+ predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())
+
+ stat_dtype = (dtypes.int64
+ if weights is None or weights.dtype.is_integer
+ else dtypes.float32)
+ po = metrics_impl.metric_variable(
+ (num_classes,), stat_dtype, name='po')
+ pe_row = metrics_impl.metric_variable(
+ (num_classes,), stat_dtype, name='pe_row')
+ pe_col = metrics_impl.metric_variable(
+ (num_classes,), stat_dtype, name='pe_col')
+
+ # Table of the counts of agreement:
+ counts_in_table = confusion_matrix.confusion_matrix(
+ labels, predictions_idx,
+ num_classes=num_classes, weights=weights,
+ dtype=stat_dtype, name="counts_in_table")
+
+ po_t = array_ops.diag_part(counts_in_table)
+ pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
+ pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1)
+ update_po = state_ops.assign_add(po, po_t)
+ update_pe_row = state_ops.assign_add(pe_row, pe_row_t)
+ update_pe_col = state_ops.assign_add(pe_col, pe_col_t)
+
+ def _calculate_k(po, pe_row, pe_col, name):
+ po_sum = math_ops.reduce_sum(po)
+ total = math_ops.reduce_sum(pe_row)
+ pe_sum = math_ops.reduce_sum(
+ metrics_impl._safe_div( # pylint: disable=protected-access
+ pe_row * pe_col, total, None))
+ po_sum, pe_sum, total = (math_ops.to_double(po_sum),
+ math_ops.to_double(pe_sum),
+ math_ops.to_double(total))
+ # kappa = (po - pe) / (N - pe)
+ k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access
+ po_sum - pe_sum, total - pe_sum, name=name)
+ return k
+
+ kappa = _calculate_k(po, pe_row, pe_col, name='value')
+ update_op = _calculate_k(update_po, update_pe_row, update_pe_col,
+ name='update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, kappa)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return kappa, update_op
+
+
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
+ 'cohen_kappa',
'count',
'precision_recall_at_equal_thresholds',
'recall_at_precision',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f05ae394e6..89aa29f711 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -6660,5 +6660,213 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(4.1, result.eval(), 5)
+class CohenKappaTest(test.TestCase):
+
+ def _confusion_matrix_to_samples(self, confusion_matrix):
+ x, y = confusion_matrix.shape
+ pairs = []
+ for label in range(x):
+ for feature in range(y):
+ pairs += [label, feature] * confusion_matrix[label, feature]
+ pairs = np.array(pairs).reshape((-1, 2))
+ return pairs[:, 0], pairs[:, 1]
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.cohen_kappa(
+ predictions_idx=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_classes=2)
+ _assert_metric_variables(self, (
+ 'cohen_kappa/po:0',
+ 'cohen_kappa/pe_row:0',
+ 'cohen_kappa/pe_col:0',))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ kappa, _ = metrics.cohen_kappa(
+ predictions_idx=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_classes=2,
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [kappa])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.cohen_kappa(
+ predictions_idx=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_classes=2,
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(update_op)
+
+ # Then verify idempotency.
+ initial_kappa = kappa.eval()
+ for _ in range(10):
+ self.assertAlmostEqual(initial_kappa, kappa.eval(), 5)
+
+ def testBasic(self):
+ confusion_matrix = np.array([
+ [9, 3, 1],
+ [4, 8, 2],
+ [2, 1, 6]])
+ # overall total = 36
+ # po = [9, 8, 6], sum(po) = 23
+ # pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25]
+ # finally, kappa = (sum(po) - sum(pe)) / (N - sum(pe))
+ # = (23 - 12.34) / (36 - 12.34)
+ # = 0.45
+ # see: http://psych.unl.edu/psycrs/handcomp/hckappa.PDF
+ expect = 0.45
+ labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+
+ dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64]
+ shapes = [(len(labels,)), # 1-dim
+ (len(labels), 1)] # 2-dim
+ weights = [None, np.ones_like(labels)]
+
+ for dtype in dtypes:
+ for shape in shapes:
+ for weight in weights:
+ with self.test_session() as sess:
+ predictions_tensor = constant_op.constant(
+ np.reshape(predictions, shape), dtype=dtype)
+ labels_tensor = constant_op.constant(
+ np.reshape(labels, shape), dtype=dtype)
+ kappa, update_op = metrics.cohen_kappa(
+ labels_tensor, predictions_tensor, 3, weights=weight)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 2)
+ self.assertAlmostEqual(expect, kappa.eval(), 2)
+
+ def testAllCorrect(self):
+ inputs = np.arange(0, 100) % 4
+ # confusion matrix
+ # [[25, 0, 0],
+ # [0, 25, 0],
+ # [0, 0, 25]]
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
+ expect = 1.0
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(inputs)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 5)
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testAllIncorrect(self):
+ labels = np.arange(0, 100) % 4
+ predictions = (labels + 1) % 4
+ # confusion matrix
+ # [[0, 25, 0],
+ # [0, 0, 25],
+ # [25, 0, 0]]
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
+ expect = -0.333333333333
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 5)
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testWeighted(self):
+ confusion_matrix = np.array([
+ [9, 3, 1],
+ [4, 8, 2],
+ [2, 1, 6]])
+ labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+ num_samples = np.sum(confusion_matrix, dtype=np.int32)
+ weights = (np.arange(0, num_samples) % 5) / 5.0
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+ # labels, predictions, sample_weight=weights)
+ expect = 0.453466583385
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 4,
+ weights=weights)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 5)
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testWithMultipleUpdates(self):
+ confusion_matrix = np.array([
+ [90, 30, 10, 20],
+ [40, 80, 20, 30],
+ [20, 10, 60, 35],
+ [15, 25, 30, 25]])
+ labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+ num_samples = np.sum(confusion_matrix, dtype=np.int32)
+ weights = (np.arange(0, num_samples) % 5) / 5.0
+ num_classes = confusion_matrix.shape[0]
+
+ batch_size = num_samples // 10
+ predictions_t = array_ops.placeholder(dtypes_lib.float32,
+ shape=(batch_size,))
+ labels_t = array_ops.placeholder(dtypes_lib.int32,
+ shape=(batch_size,))
+ weights_t = array_ops.placeholder(dtypes_lib.float32,
+ shape=(batch_size,))
+ kappa, update_op = metrics.cohen_kappa(
+ labels_t, predictions_t, num_classes, weights=weights_t)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ for idx in range(0, num_samples, batch_size):
+ batch_start, batch_end = idx, idx + batch_size
+ sess.run(update_op,
+ feed_dict={labels_t: labels[batch_start:batch_end],
+ predictions_t: predictions[batch_start:batch_end],
+ weights_t: weights[batch_start:batch_end]})
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+ # labels_np, predictions_np, sample_weight=weights_np)
+ expect = 0.289965397924
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testInvalidNumClasses(self):
+ predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
+ labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
+ with self.assertRaisesRegexp(ValueError, 'num_classes'):
+ metrics.cohen_kappa(labels, predictions, 1)
+
+ def testInvalidDimension(self):
+ predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
+ invalid_labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 2))
+ with self.assertRaises(ValueError):
+ metrics.cohen_kappa(invalid_labels, predictions, 3)
+
+ invalid_predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 2))
+ labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
+ with self.assertRaises(ValueError):
+ metrics.cohen_kappa(labels, invalid_predictions, 3)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD
index 11c5d6e776..9f9802b8fe 100644
--- a/tensorflow/contrib/mpi_collectives/BUILD
+++ b/tensorflow/contrib/mpi_collectives/BUILD
@@ -6,20 +6,9 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
load(
"//tensorflow/core:platform/default/build_config.bzl",
+ "tf_additional_mpi_lib_defines",
"tf_proto_library_cc",
)
@@ -33,26 +22,98 @@ tf_proto_library_cc(
],
)
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
+cc_library(
+ name = "mpi_defines",
+ defines = tf_additional_mpi_lib_defines(),
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_py_library",
+ "tf_custom_op_library",
+ "tf_gen_op_wrapper_py",
+ "tf_gen_op_libs",
+ "tf_kernel_library",
+ "tf_py_test",
+)
tf_custom_op_library(
- name = "mpi_collectives.so",
+ name = "python/ops/_mpi_ops.so",
srcs = [
- "mpi_ops.cc",
- "ring.cc",
- "ring.h",
+ "kernels/mpi_ops.cc",
+ "kernels/ring.cc",
+ "kernels/ring.h",
+ "ops/mpi_ops.cc",
],
gpu_srcs = [
- "ring.cu.cc",
- "ring.h",
+ "kernels/ring.cu.cc",
+ "kernels/ring.h",
],
deps = [
+ ":mpi_defines",
":mpi_message_proto_cc",
"//third_party/mpi",
],
)
+tf_kernel_library(
+ name = "mpi_ops_kernels",
+ srcs = [
+ "kernels/mpi_ops.cc",
+ "kernels/ring.cc",
+ ],
+ hdrs = [
+ "kernels/ring.h",
+ ],
+ gpu_srcs = [
+ "kernels/ring.cu.cc",
+ ],
+ deps = [
+ ":mpi_defines",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:stream_executor",
+ ],
+ # TODO: Include? alwayslink = 1,
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["mpi_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "mpi_ops",
+ deps = [":mpi_ops_op_lib"],
+)
+
+tf_custom_op_py_library(
+ name = "mpi_collectives_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/mpi_ops.py",
+ ],
+ dso = [
+ ":python/ops/_mpi_ops.so",
+ ],
+ kernels = [
+ ":mpi_ops_kernels",
+ ":mpi_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":mpi_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:device",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+)
+
tf_py_test(
name = "mpi_ops_test",
srcs = ["mpi_ops_test.py"],
@@ -61,20 +122,19 @@ tf_py_test(
"//tensorflow/python:platform",
],
data = [
- ":mpi_collectives.so",
+ ":python/ops/_mpi_ops.so",
],
tags = ["manual"],
)
-py_library(
- name = "mpi_ops_py",
- srcs = [
- "__init__.py",
- "mpi_ops.py",
- ],
- data = [
- ":mpi_collectives.so",
- ],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
)
diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py
index 9ed16a6f07..52029cbc36 100644
--- a/tensorflow/contrib/mpi_collectives/__init__.py
+++ b/tensorflow/contrib/mpi_collectives/__init__.py
@@ -37,7 +37,7 @@ for detecting the running MPI configuration.
Example:
```python
-from tensorflow.contrib import mpi
+import tensorflow.contrib.mpi_collectives as mpi
# Use `mpi.Session` instead of `tf.Session`
with mpi.Session() as session:
@@ -48,8 +48,10 @@ with mpi.Session() as session:
print("MPI Size:", session.run(mpi.size()))
```
-@@rank
+@@init
@@size
+@@rank
+@@local_rank
### Ring Allreduce and Allgather
@@ -123,12 +125,12 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib.mpi_collectives.mpi_ops import size
-from tensorflow.contrib.mpi_collectives.mpi_ops import rank
-from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank
-from tensorflow.contrib.mpi_collectives.mpi_ops import allgather
-from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce
-from tensorflow.contrib.mpi_collectives.mpi_ops import init
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import init
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import size
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import rank
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import local_rank
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import allgather
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import _allreduce
def allreduce(tensor, average=True):
diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
new file mode 100644
index 0000000000..2d5b98022c
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
@@ -0,0 +1,1132 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include <cuda_runtime.h>
+#include "tensorflow/stream_executor/stream.h"
+#endif
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
+#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
+
+/*
+ * MPI Allreduce and Allgather Ops for TensorFlow.
+ *
+ * TensorFlow natively provides inter-device communication through send and
+ * receive ops and inter-node communication through Distributed TensorFlow,
+ * based on the same send and receive abstractions. These end up being
+ * insufficient for synchronous data-parallel training on HPC clusters where
+ * Infiniband or other high-speed interconnects are available. This module
+ * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
+ * gathers and reductions and can take advantage of hardware-optimized
+ * communication libraries through the MPI implementation.
+ *
+ * The primary logic of the allreduce and allgather are in RingAllgather() and
+ * RingAllreduce(). The background thread which facilitates MPI operations is
+ * run in BackgroundThreadLoop(). The provided MPI ops are:
+ * – MPIInit:
+ * Initialize MPI on a given device (CPU or GPU).
+ * Should only be run on a single device in every process.
+ * – MPISize:
+ * Get the number of MPI processes in the global communicator.
+ * – MPIRank:
+ * Get the rank of the current MPI process in the global communicator.
+ * – MPILocalRank:
+ * Get the local rank of the current MPI process within its node.
+ * – MPIAllreduce:
+ * Perform an allreduce on a Tensor, returning the sum
+ * across all MPI processes in the global communicator.
+ * – MPIAllgather:
+ * Perform an allgather on a Tensor, returning the concatenation of
+ * the tensor on the first dimension across all MPI processes in the
+ * global communicator.
+ *
+ */
+
+template <class T>
+using StatusOr = perftools::gputools::port::StatusOr<T>;
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi_collectives {
+
+// Make sure template specializations are generated in the ring.cu.cc and the
+// ring.cc file, not in this file.
+extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<GPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<CPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+
+namespace {
+
+// Return true if the templated type is GPUDevice, otherwise false.
+template <typename T>
+bool IsGPUDevice();
+template <>
+bool IsGPUDevice<GPUDevice>() {
+ return true;
+};
+template <>
+bool IsGPUDevice<CPUDevice>() {
+ return false;
+};
+
+// A callback to call after the MPI communication completes. Since the
+// allreduce and allgather ops are asynchronous, this callback is what resumes
+// computation after the reduction is completed.
+typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
+
+struct CollectiveOpRecord {
+ // The rank performing this piece of the op
+ int rank;
+
+ // The name of the op/tensor to be reduced
+ std::string name;
+
+ // The op's kernel context
+ OpKernelContext* context;
+
+ // Data type of the op
+ DataType dtype;
+
+ // The input tensor
+ const Tensor* in_t;
+
+ // Allgather: Vector of per-rank first-dimension sizes
+ std::vector<size_t> sizes_vec;
+
+ // The temp tensor for intermediate results
+ Tensor temp_t;
+
+ // The output tensor
+ Tensor* out_t;
+
+ // Whether to run this op on the gpu
+ bool on_gpu;
+
+ // The callback to call after the op has completed
+ CommunicationDoneCallback callback;
+};
+
+// Table storing Tensors to be reduced, keyed by unique name.
+// This table contains everything necessary to do the reduction
+typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
+
+// Table for storing Tensor metadata on rank zero. This is used for error
+// checking and size calculations, as well as determining when a reduction is
+// ready to be done (when all nodes are ready to do it).
+typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
+
+// The global state required for the MPI ops.
+//
+// MPI is a library that stores a lot of global per-program state and often
+// requires running on a single thread. As a result, we have to have a single
+// background thread responsible for all MPI operations, and communicate with
+// that background thread through global state.
+struct MPIGlobalState {
+ // An atomic boolean which is set to true when MPI is initialized.
+ // This ensures that MPI_Init is never called twice.
+ std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
+
+ // Condition variable to wait for initialization
+ condition_variable cv;
+
+ // Whether MPI_Init has been completed on the background thread.
+ bool initialization_done = false;
+
+ // Whether MPI_Init succeeded on the background thread.
+ Status init_status;
+
+ // A mutex that needs to be used whenever MPI operations touch
+ // shared structures.
+ mutex mu;
+
+ // Tensors waiting to be allreduced or allgathered.
+ TensorTable tensor_table;
+
+ // Queue of MPI requests waiting to be sent to the coordinator node.
+ std::queue<MPIRequest> message_queue;
+
+ // Background thread running MPI communication.
+ std::thread background_thread;
+
+ // Whether the background thread should shutdown.
+ bool shut_down = false;
+
+ // Only exists on the coordinator node (rank zero). Maintains a count of
+ // how many nodes are ready to allreduce every tensor (keyed by tensor
+ // name).
+ std::unique_ptr<MessageTable> message_table;
+
+ // The MPI rank, local rank, and size.
+ int rank = 0;
+ int local_rank = 0;
+ int size = 1;
+
+ // The device that MPI was initialized on. (-1 for no GPU)
+ int device = -1;
+
+ // The CUDA stream used for data transfers and within-allreduce operations.
+ // A naive implementation would use the TensorFlow StreamExecutor CUDA
+ // stream. However, the allreduce and allgather require doing memory copies
+ // and kernel executions (for accumulation of values on the GPU). However,
+ // the subsequent operations must wait for those operations to complete,
+ // otherwise MPI (which uses its own stream internally) will begin the data
+ // transfers before the CUDA calls are complete. In order to wait for those
+ // CUDA operations, if we were using the TensorFlow stream, we would have
+ // to synchronize that stream; however, other TensorFlow threads may be
+ // submitting more work to that stream, so synchronizing on it can cause
+ // the allreduce to be delayed, waiting for compute totally unrelated to it
+ // in other parts of the graph. Overlaying memory transfers and compute
+ // during backpropagation is crucial for good performance, so we cannot use
+ // the TensorFlow stream, and must use our own stream.
+#if GOOGLE_CUDA
+ cudaStream_t stream;
+ std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
+#endif
+
+ ~MPIGlobalState() {
+ // Make sure that the destructor of the background thread is safe to
+ // call. If a thread is still joinable (not detached or complete) its
+ // destructor cannot be called.
+ if (background_thread.joinable()) {
+ shut_down = true;
+ background_thread.join();
+ }
+ }
+};
+
+// All the MPI state that must be stored globally per-process.
+static MPIGlobalState mpi_global;
+
+// For clarify in argument lists.
+#define RANK_ZERO 0
+
+// A tag used for all coordinator messaging.
+#define TAG_NOTIFY 1
+
+// Store the MPIRequest for a name, and return whether the total count of
+// MPIRequests for that tensor is now equal to the MPI size (and thus we are
+// ready to reduce the tensor).
+bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
+ MPIRequest msg, int mpi_size) {
+ auto name = msg.tensor_name();
+ auto table_iter = message_table->find(name);
+ if (table_iter == message_table->end()) {
+ message_table->emplace(name, std::vector<MPIRequest>({msg}));
+ table_iter = message_table->find(name);
+ } else {
+ table_iter->second.push_back(msg);
+ }
+
+ int count = table_iter->second.size();
+ return count == mpi_size;
+}
+
+// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
+// instructing all ranks to start the reduction to all ranks. The MPIResponse
+// also contains error messages in case the submitted MPIRequests were not
+// valid (for example, contained mismatched shapes or types).
+//
+// Constructing the MPIResponse, thus, requires a whole lot of error checking.
+MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
+ std::string name) {
+ bool error = false;
+ auto it = message_table->find(name);
+ assert(it != message_table->end());
+
+ std::vector<MPIRequest> requests = it->second;
+ assert(requests.size() > 0);
+
+ std::ostringstream error_message_stream;
+
+ // Check that all data types being reduced or gathered are identical
+ auto data_type = requests[0].tensor_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ auto request_type = requests[i].tensor_type();
+ if (data_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched data types: One rank had type "
+ << DataType_Name(data_type)
+ << ", but another rank had type "
+ << DataType_Name(request_type) << ".";
+ break;
+ }
+ }
+
+ // Check that all requested operations are the same
+ auto message_type = requests[0].request_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ auto request_type = requests[i].request_type();
+ if (message_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched MPI operations: One rank did an "
+ << message_type << ", but another rank did an "
+ << request_type << ".";
+ break;
+ }
+ }
+
+ // If we are doing an allreduce, check that all tensor shapes
+ // are identical
+ if (message_type == MPIRequest::ALLREDUCE) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape != request_shape) {
+ error = true;
+ error_message_stream << "Mismatched allreduce tensor shapes: "
+ << "One rank reduced a tensor of shape "
+ << tensor_shape.DebugString()
+ << ", but another rank sent a tensor of shape "
+ << request_shape.DebugString() << ".";
+ break;
+ }
+ }
+ }
+
+ // If we are doing an allgather, make sure all but the first dimension are
+ // the same. The first dimension may be different and the output tensor is
+ // the sum of the first dimension. Collect the sizes by rank.
+ if (message_type == MPIRequest::ALLGATHER) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+
+ if (tensor_shape.dims() == 0) {
+ error = true;
+ error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
+ }
+
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape.dims() != request_shape.dims()) {
+ error = true;
+ error_message_stream << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor of rank "
+ << tensor_shape.dims()
+ << ", but another rank sent a tensor of rank "
+ << request_shape.dims() << ".";
+ break;
+ }
+
+ for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
+ if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
+ error = true;
+ error_message_stream
+ << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor with dimension " << dim
+ << " equal to " << tensor_shape.dim_size(dim)
+ << ", but another rank sent a tensor with dimension " << dim
+ << " equal to " << request_shape.dim_size(dim) << ".";
+ break;
+ }
+ }
+ }
+ }
+
+ MPIResponse response;
+ response.set_tensor_name(name);
+ if (error) {
+ std::string error_message = error_message_stream.str();
+ response.set_response_type(MPIResponse::ERROR);
+ response.set_error_message(error_message);
+ } else {
+ auto response_type = MPIResponse::ERROR;
+ if (message_type == MPIRequest::ALLREDUCE) {
+ response_type = MPIResponse::ALLREDUCE;
+ } else {
+ response_type = MPIResponse::ALLGATHER;
+ }
+ response.set_response_type(response_type);
+ }
+
+ // Clear all queued up requests for this name. They are now taken care of
+ // by the constructed MPI response.
+ message_table->erase(it);
+
+ return response;
+}
+
+// Process an MPIResponse by doing a reduction, a gather, or raising an error.
+void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
+ OpKernelContext* context;
+ const Tensor* input_tensor;
+ std::vector<size_t> sizes_vec;
+ Tensor temp_tensor;
+ Tensor* output_tensor;
+ CommunicationDoneCallback callback;
+ bool on_gpu;
+ {
+ // Lock on the tensor table.
+ mutex_lock guard(mpi_global.mu);
+
+ // We should never fail at finding this key in the tensor table.
+ auto name = response.tensor_name();
+ auto iter = tensor_table.find(name);
+ assert(iter != tensor_table.end());
+
+ assert(response.response_type() == MPIResponse::ALLREDUCE ||
+ response.response_type() == MPIResponse::ALLGATHER ||
+ response.response_type() == MPIResponse::ERROR);
+
+ CollectiveOpRecord record = iter->second;
+ context = record.context;
+ input_tensor = record.in_t;
+ sizes_vec = record.sizes_vec;
+ temp_tensor = record.temp_t;
+ output_tensor = record.out_t;
+ on_gpu = record.on_gpu;
+ callback = record.callback;
+
+ // Clear the tensor table of this tensor and its callbacks; the rest of
+ // this function takes care of it.
+ tensor_table.erase(iter);
+ }
+
+ // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
+ // link to non-existent symbols.
+#if GOOGLE_CUDA
+#define GPU_DEVICE_IF_CUDA GPUDevice
+#else
+#define GPU_DEVICE_IF_CUDA CPUDevice
+#endif
+
+ Status status;
+ auto dtype = input_tensor->dtype();
+ if (response.response_type() == MPIResponse::ALLGATHER) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, float>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, int>(context, input_tensor,
+ sizes_vec, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, long long>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allgather.");
+ }
+ } else if (response.response_type() == MPIResponse::ALLREDUCE) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, float>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, int>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, long long>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allreduce.");
+ }
+ } else if (response.response_type() == MPIResponse::ERROR) {
+ status = errors::FailedPrecondition(response.error_message());
+ }
+
+ if (status.ok()) {
+ callback(StatusOr<Tensor>(*output_tensor));
+ } else {
+ callback(StatusOr<Tensor>(status));
+ }
+}
+
+// The MPI background thread loop coordinates all the MPI processes and the
+// tensor reductions. The design of the communicator mechanism is limited by a
+// few considerations:
+//
+// 1. Some MPI implementations require all MPI calls to happen from a
+// single thread. Since TensorFlow may use several threads for graph
+// processing, this means we must have our own dedicated thread for
+// dealing with MPI.
+// 2. We want to gracefully handle errors, when MPI processes do not
+// properly agree upon what should happen (such as mismatched types or
+// shapes). To do so requires the MPI processes to know about the shapes
+// and types of the relevant tensors on the other processes.
+// 3. The MPI reductions and gathers should be able to happen in parallel
+// with other ongoing operations. Since MPI uses an internal
+// (inaccessible) GPU stream separate from the TF GPUDevice streams, we
+// cannot explicitly synchronize memcpys or kernels with it. As a result,
+// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
+// ordering of memcpys and kernels with respect to TF streams.
+// 4. NOTE: We cannot guarantee that all the MPI processes reduce their
+// tensors in the same order. Thus, there must be a way to ensure the
+// reduction memcpys and kernels occur for correct tensors across all
+// ranks at the same time. We choose to use a coordinator (rank ID 0) to
+// gather and trigger the reduction operations that are ready to execute.
+//
+// The coordinator currently follows a master-worker paradigm. Rank zero acts
+// as the master (the "coordinator"), whereas all other ranks are simply
+// workers. Each rank runs its own background thread which progresses in ticks.
+// In each tick, the following actions happen:
+//
+// a) The workers send any available MPIRequests to the coordinator. These
+// MPIRequests indicate what the worker would like to do (i.e. which
+// tensor they would like to gather or reduce, as well as their shape and
+// type). They repeat this for every tensor that they would like to
+// operate on after that tensor's collective op has executed ComputeAsync.
+//
+// b) The workers send an empty "DONE" message to the coordinator to
+// indicate that there are no more tensors they wish to operate on.
+//
+// c) The coordinator receives the MPIRequests from the workers, as well
+// as from its own TensorFlow ops, and stores them in a request table. The
+// coordinator continues to receive MPIRequest messages until it has
+// received MPI_SIZE number of empty "DONE" messages.
+//
+// d) The coordinator finds all tensors that are ready to be reduced,
+// gathered, or all operations that result in an error. For each of those,
+// it sends an MPIResponse to all the workers. When no more MPIResponses
+// are available, it sends a "DONE" response to the workers. If the
+// process is being shutdown, it instead sends a "SHUTDOWN" response.
+//
+// e) The workers listen for MPIResponse messages, processing each one by
+// doing the required reduce or gather, until they receive a "DONE"
+// response from the coordinator. At that point, the tick ends.
+// If instead of "DONE" they receive "SHUTDOWN", they exit their
+// background loop.
+// TODO: Use the global mpi_global state variable instead of a local one
+void BackgroundThreadLoop() {
+#if GOOGLE_CUDA
+ // Set the device, so that this thread uses the same GPU context as the
+ // calling thread.
+ // TODO: Ensure that this is operating correctly. The background thread
+ // needs to be able to control all GPUs that the rank has access to, and
+ // might be more than 1 GPU. Tensors could be resident in any of the
+ // GPUs, so the background thread's accumulate and copy kernels might need
+ // to correctly set the device and it might be necessary for the background
+ // thread to manage multiple streams.
+ cudaSetDevice(mpi_global.device);
+ cudaStreamCreate(&mpi_global.stream);
+#endif
+
+ // Initialize MPI. This must happen on the background thread, since not all
+ // MPI implementations support being called from multiple threads.
+ auto init_result = MPI_Init(NULL, NULL);
+ if (init_result != MPI_SUCCESS) {
+ mpi_global.init_status =
+ errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
+ mpi_global.initialization_done = true;
+ mpi_global.cv.notify_all();
+ return;
+ } else {
+ mpi_global.init_status = Status::OK();
+ }
+
+ // Get MPI rank to determine if we are rank zero.
+ int rank;
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+ bool is_coordinator = rank == 0;
+
+ // Get MPI size to determine how many tensors to wait for before reducing.
+ int size;
+ MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+ // Determine local rank by querying the local communicator.
+ MPI_Comm local_comm;
+ MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
+ &local_comm);
+ int local_rank;
+ MPI_Comm_rank(local_comm, &local_rank);
+
+ mpi_global.rank = rank;
+ mpi_global.local_rank = local_rank;
+ mpi_global.size = size;
+ mpi_global.initialization_done = true;
+
+ // Notify calling thread that initialization is complete
+ mpi_global.cv.notify_all();
+
+ // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
+ // Initialize the tensor count table. No tensors are available yet.
+ if (is_coordinator) {
+ mpi_global.message_table =
+ std::unique_ptr<MessageTable>(new MessageTable());
+ }
+
+ // The coordinator sends a SHUTDOWN message to trigger shutdown.
+ bool should_shut_down = false;
+ do {
+ // TODO: Eliminate the need for thread sleep by making all activity
+ // depend on other activity (e.g. condition or MPI waits).
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+ // Copy the data structures from global state under this lock.
+ // However, don't keep the lock for the rest of the loop, so that
+ // enqueued stream callbacks can continue.
+ std::queue<MPIRequest> message_queue;
+ {
+ mutex_lock guard(mpi_global.mu);
+ while (!mpi_global.message_queue.empty()) {
+ MPIRequest message = mpi_global.message_queue.front();
+ mpi_global.message_queue.pop();
+ message_queue.push(message);
+ }
+ }
+
+ // Collect all tensors that are ready to be reduced. Record them in the
+ // tensor count table (rank zero) or send them to rank zero to be
+ // recorded (everyone else).
+ std::vector<std::string> ready_to_reduce;
+ while (!message_queue.empty()) {
+ // Pop the first available message message
+ MPIRequest message = message_queue.front();
+ message_queue.pop();
+
+ if (is_coordinator) {
+ bool reduce =
+ IncrementTensorCount(mpi_global.message_table, message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(message.tensor_name());
+ }
+ } else {
+ std::string encoded_message;
+ message.SerializeToString(&encoded_message);
+ MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
+ MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ }
+
+ // Rank zero has put all its own tensors in the tensor count table.
+ // Now, it should count all the tensors that are coming from other
+ // ranks at this tick. It should keep getting tensors until it gets a
+ // DONE message from all the other ranks.
+ if (is_coordinator) {
+ // Count of DONE messages. Keep receiving messages until the number
+ // of messages is equal to the number of processes. Initialize to
+ // one since the coordinator is effectively done.
+ int completed_ranks = 1;
+ while (completed_ranks != size) {
+ MPI_Status status;
+ MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int source_rank = status.MPI_SOURCE;
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // If the length is zero, this is a DONE message.
+ if (msg_length == 0) {
+ completed_ranks++;
+ MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ continue;
+ }
+
+ // Get tensor name from MPI into an std::string.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
+ MPI_COMM_WORLD, &status);
+ std::string received_data(buffer);
+ delete[] buffer;
+
+ MPIRequest received_message;
+ received_message.ParseFromString(received_data);
+ auto received_name = received_message.tensor_name();
+
+ bool reduce = IncrementTensorCount(mpi_global.message_table,
+ received_message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(received_name);
+ }
+ }
+
+ // At this point, rank zero should have a fully updated tensor
+ // count table and should know all the tensors that need to be
+ // reduced or gathered, and everyone else should have sent all
+ // their information to rank zero. We can now do reductions and
+ // gathers; rank zero will choose which ones and in what order,
+ // and will notify the other ranks before doing each reduction.
+ for (int i = 0; i < ready_to_reduce.size(); i++) {
+ // Notify all nodes which tensor we'd like to reduce now
+ auto name = ready_to_reduce[i];
+ MPIResponse response =
+ ConstructMPIResponse(mpi_global.message_table, name);
+
+ std::string encoded_response;
+ response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+
+ // Perform the reduction. All nodes should end up performing
+ // the same reduction.
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+
+ // Notify all nodes that we are done with the reductions for this
+ // tick.
+ MPIResponse done_response;
+ should_shut_down = mpi_global.shut_down;
+ done_response.set_response_type(
+ mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
+ std::string encoded_response;
+ done_response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ } else {
+ // Notify the coordinator that this node is done sending messages.
+ // A DONE message is encoded as a zero-length message.
+ MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+ // Receive names for tensors to reduce from rank zero. Once we
+ // receive a empty DONE message, stop waiting for more names.
+ while (true) {
+ MPI_Status status;
+ MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // Get tensor name from MPI into an std::string.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ std::string received_message(buffer);
+ delete[] buffer;
+
+ MPIResponse response;
+ response.ParseFromString(received_message);
+ if (response.response_type() == MPIResponse::DONE) {
+ // No more messages this tick
+ break;
+ } else if (response.response_type() == MPIResponse::SHUTDOWN) {
+ // No more messages this tick, and the background thread
+ // should shut down
+ should_shut_down = true;
+ break;
+ } else {
+ // Process the current message
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+ }
+ }
+ } while (!should_shut_down);
+
+ MPI_Finalize();
+}
+
+// Initialize MPI and start the MPI background thread. Ensure that this is
+// only done once no matter how many times this function is called.
+Status InitializeMPIOnce(bool gpu) {
+ // Ensure MPI is only initialized once.
+ if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
+
+ mpi_global.device = -1;
+#if GOOGLE_CUDA
+ if (gpu) {
+ cudaGetDevice(&mpi_global.device);
+ }
+#endif
+
+ // Start the MPI background thread, which assumes MPI is initialized
+ // TODO: Change this to a Tensorflow thread
+ mpi_global.background_thread = std::thread(BackgroundThreadLoop);
+
+ // Wait to ensure that the background thread has finished initializing MPI
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.cv.wait(guard);
+ if (!mpi_global.initialization_done) {
+ mpi_global.init_status =
+ errors::Unknown("Failed to wait for MPI initialization.");
+ }
+
+ return mpi_global.init_status;
+}
+
+// Check that MPI is initialized.
+Status IsMPIInitialized() {
+ if (!mpi_global.initialization_done) {
+ return errors::FailedPrecondition(
+ "MPI has not been initialized; use tf.contrib.mpi.Session.");
+ }
+ return Status::OK();
+}
+
+// This function (called from the callback set up in MPIAll*Op::ComputeAsync)
+// only adds the op's record into the local op queue (to track the op's
+// progress), and sends a message to the coordinator indicating that this rank
+// is ready to begin. The MPI background thread will handle the MPI message.
+void EnqueueTensorCollective(CollectiveOpRecord record,
+ MPIRequest::RequestType rtype) {
+ const Tensor* input_tensor = record.in_t;
+ MPIRequest message;
+ message.set_request_rank(record.rank);
+ message.set_tensor_name(record.name);
+ message.set_tensor_type(record.dtype);
+ message.set_request_type(rtype);
+ input_tensor->shape().AsProto(message.mutable_tensor_shape());
+
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.tensor_table.emplace(record.name, record);
+ mpi_global.message_queue.push(message);
+}
+
+} // namespace
+
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
+#endif
+
+// Op to initialize MPI in the current process. The settings used in the
+// configuration are the same that must be used for all future MPI ops.
+template <typename Device>
+class MPIInitOp : public OpKernel {
+ public:
+ explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ bool on_gpu = IsGPUDevice<Device>();
+ OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
+ MPIInitOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
+ MPIInitOp<GPUDevice>);
+#endif
+
+// Op to get the current MPI Size.
+template <typename Device>
+class MPISizeOp : public OpKernel {
+ public:
+ explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.size;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
+ MPISizeOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
+ MPISizeOp<GPUDevice>);
+#endif
+
+// Op to get the current MPI Rank.
+template <typename Device>
+class MPIRankOp : public OpKernel {
+ public:
+ explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
+ MPIRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPIRankOp<GPUDevice>);
+#endif
+
+// Op to get the current local MPI Rank.
+template <typename Device>
+class MPILocalRankOp : public OpKernel {
+ public:
+ explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.local_rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
+ MPILocalRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPILocalRankOp<GPUDevice>);
+#endif
+
+template <typename Device>
+class MPIAllreduceOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllreduceOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(0, input_tensor->shape(), &output_tensor),
+ done);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.out_t = output_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+ record.dtype = input_tensor->dtype();
+
+ const size_t temp_size =
+ (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
+ TensorShape temp_shape;
+ temp_shape.AddDim(temp_size);
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_temp(input_tensor->dtype(),
+ temp_shape, &record.temp_t),
+ done);
+
+ auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allreduce_done_callback;
+
+ auto allreduce_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allreduce; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allreduce_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allreduce_launch_callback);
+ }
+#else
+ allreduce_launch_callback();
+#endif
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
+ MPIAllreduceOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
+ MPIAllreduceOp<GPUDevice>);
+#endif
+
+template <typename Device>
+class MPIAllgatherOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllgatherOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ const Tensor* sizing_tensor = &context->input(1);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+
+ // Construct the output size from the sizing tensor
+ size_t output_first_dim = 0;
+ if (sizing_tensor->shape().dims() == 0) {
+ // 0-dim sizing_tensor implies that the op is just gathering
+ // a single element from each rank
+ output_first_dim = mpi_global.size;
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(1);
+ }
+ } else {
+ // Collect the total output tensor sizing from the sizing tensor
+ // NOTE: The sizing tensor is forced to be placed on the CPU by
+ // declaring the input as HostMemory, so it is valid to read it here.
+ const int64* sizing_array =
+ (const int64*)sizing_tensor->tensor_data().data();
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(sizing_array[i]);
+ output_first_dim += sizing_array[i];
+ }
+ }
+
+ TensorShape output_shape;
+ output_shape.AddDim(output_first_dim);
+ for (int i = 1; i < input_tensor->shape().dims(); i++) {
+ output_shape.AddDim(input_tensor->shape().dim_size(i));
+ }
+
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_output(0, output_shape, &output_tensor),
+ done);
+
+ record.out_t = output_tensor;
+ record.dtype = input_tensor->dtype();
+
+ auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allgather_done_callback;
+
+ auto allgather_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allgather; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allgather_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allgather_launch_callback);
+ }
+#else
+ allgather_launch_callback();
+#endif
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
+ MPIAllgatherOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
+ MPIAllgatherOp<GPUDevice>);
+#endif
+
+} // namespace mpi_collectives
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cc
new file mode 100644
index 0000000000..8970ceb1a2
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi_collectives {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+extern template MPI_Datatype MPIType<float>();
+extern template MPI_Datatype MPIType<int>();
+extern template MPI_Datatype MPIType<long long>();
+extern template DataType TensorFlowDataType<float>();
+extern template DataType TensorFlowDataType<int>();
+extern template DataType TensorFlowDataType<long long>();
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Copy data on a CPU using a straight-forward memcpy.
+template <>
+void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
+ std::memcpy(dst, src, size);
+};
+
+// Accumulate values on a CPU.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<CPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ for (unsigned int i = 0; i < size; i++) { \
+ dst[i] += src[i]; \
+ } \
+ };
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi_collectives
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc
new file mode 100644
index 0000000000..b04abde469
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc
@@ -0,0 +1,117 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi_collectives {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <>
+MPI_Datatype MPIType<float>() {
+ return MPI_FLOAT;
+};
+template <>
+MPI_Datatype MPIType<int>() {
+ return MPI_INT;
+};
+template <>
+MPI_Datatype MPIType<long long>() {
+ return MPI_LONG_LONG;
+};
+
+template <>
+DataType TensorFlowDataType<float>() {
+ return DT_FLOAT;
+};
+template <>
+DataType TensorFlowDataType<int>() {
+ return DT_INT32;
+};
+template <>
+DataType TensorFlowDataType<long long>() {
+ return DT_INT64;
+};
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Synchronously copy data on the GPU, using a different stream than the default
+// and than TensorFlow to avoid synchronizing on operations unrelated to the
+// allreduce.
+template <>
+void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
+ auto stream = CudaStreamForMPI();
+ cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
+ cudaStreamSynchronize(stream);
+};
+
+// Elementwise accumulation kernel for GPU.
+template <typename T>
+__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+ i += blockDim.x * gridDim.x) {
+ out[i] += in[i];
+ }
+}
+
+// Synchronously accumulate tensors on the GPU, using a different stream than
+// the default and than TensorFlow to avoid synchronizing on operations
+// unrelated to the allreduce.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ auto stream = CudaStreamForMPI(); \
+ elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size); \
+ cudaStreamSynchronize(stream); \
+ };
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi_collectives
+} // namespace contrib
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.h b/tensorflow/contrib/mpi_collectives/kernels/ring.h
new file mode 100644
index 0000000000..1d56d588bc
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/kernels/ring.h
@@ -0,0 +1,327 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_MPI_H_
+#define TENSORFLOW_CONTRIB_MPI_H_
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#if GOOGLE_CUDA
+#include "cuda_runtime.h"
+#endif
+
+// Needed to avoid header issues with C++-supporting MPI implementations
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+
+#define TAG_TENSOR 12
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi_collectives {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+// Convert from templated types to values we can pass to MPI.
+template <typename T>
+MPI_Datatype MPIType();
+
+// Convert from templated types to TensorFlow data types.
+template <typename T>
+DataType TensorFlowDataType();
+
+#define MPI_REQUIRES_OK(MPI_STATUS) \
+ if ((MPI_STATUS) != MPI_SUCCESS) { \
+ return errors::Unknown("MPI operation failed unexpectedly."); \
+ }
+
+// Copy data from one tensor to another tensor.
+// This uses a custom CUDA stream on GPU, which is necessary to overlay the
+// backpropagation computations with the allreduce.
+template <typename Device>
+void CopyTensorData(void* destination, void* source, size_t size);
+
+// Add a tensor into another tensor, accumulating in place.
+// This uses a custom CUDA stream on GPU, which is necessary to overlay the
+// backpropagation computations with the allreduce.
+template <typename Device, typename T>
+void AccumulateTensorData(T* destination, T* source, size_t size);
+
+// We need to get the right stream for doing CUDA memory transfers and
+// operations, which is possibly different from the standard TensorFlow stream.
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI();
+#endif
+
+/* Perform a ring allreduce on the data. Allocate the necessary output tensor
+ * and store it in the output parameter.
+ *
+ * Assumes that all MPI processes are doing an allreduce of the same tensor,
+ * with the same dimensions.
+ *
+ * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
+ * allreduce, the nodes involved are arranged in a ring:
+ *
+ * .--0--.
+ * / \
+ * 3 1
+ * \ /
+ * *--2--*
+ *
+ * Each node always sends to the next clockwise node in the ring, and receives
+ * from the previous one.
+ *
+ * The allreduce is done in two parts: a scatter-reduce and an allgather. In
+ * the scatter reduce, a reduction is done, so that each node ends up with a
+ * chunk of the final output tensor which has contributions from all other
+ * nodes. In the allgather, those chunks are distributed among all the nodes,
+ * so that all nodes have the entire output tensor.
+ *
+ * Both of these operations are done by dividing the input tensor into N
+ * evenly sized chunks (where N is the number of nodes in the ring).
+ *
+ * The scatter-reduce is done in N-1 steps. In the ith step, node j will send
+ * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to
+ * its existing data for that chunk. For example, in the first iteration with
+ * the ring depicted above, you will have the following transfers:
+ *
+ * Segment 0: Node 0 --> Node 1
+ * Segment 1: Node 1 --> Node 2
+ * Segment 2: Node 2 --> Node 3
+ * Segment 3: Node 3 --> Node 0
+ *
+ * In the second iteration, you'll have the following transfers:
+ *
+ * Segment 0: Node 1 --> Node 2
+ * Segment 1: Node 2 --> Node 3
+ * Segment 2: Node 3 --> Node 0
+ * Segment 3: Node 0 --> Node 1
+ *
+ * After this iteration, Node 2 has 3 of the four contributions to Segment 0.
+ * The last iteration has the following transfers:
+ *
+ * Segment 0: Node 2 --> Node 3
+ * Segment 1: Node 3 --> Node 0
+ * Segment 2: Node 0 --> Node 1
+ * Segment 3: Node 1 --> Node 2
+ *
+ * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
+ * has the fully accumulated Segment 1; and so on. The scatter-reduce is
+ * complete.
+ *
+ * Next, the allgather distributes these fully accumululated chunks across all
+ * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
+ * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
+ * For example, at the first iteration, the following transfers will occur:
+ *
+ * Segment 0: Node 3 --> Node 0
+ * Segment 1: Node 0 --> Node 1
+ * Segment 2: Node 1 --> Node 2
+ * Segment 3: Node 2 --> Node 3
+ *
+ * After the first iteration, Node 0 will have a fully accumulated Segment 0
+ * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
+ * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
+ * After this has continued for N - 1 iterations, all nodes will have a the
+ * fully accumulated tensor.
+ *
+ * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the
+ * allgather. Each send will contain K / N bytes, if there are K bytes in the
+ * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N
+ * bytes of data, and the performance of the allreduce (assuming no latency in
+ * connections) is constrained by the slowest interconnect between the nodes.
+ *
+ */
+template <typename Device, typename T>
+Status RingAllreduce(OpKernelContext* context, const Tensor* input,
+ Tensor* temp, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ T* buffer = (T*)output->tensor_data().data();
+
+ CopyTensorData<Device>((void*)buffer, (void*)input->tensor_data().data(),
+ output->tensor_data().size());
+
+ // Calculate segment sizes and segment ends
+ const size_t elements_to_reduce = input->NumElements();
+ const size_t segment_size = elements_to_reduce / n;
+ std::vector<size_t> segment_sizes(n, segment_size);
+
+ const size_t residual = elements_to_reduce % n;
+ for (size_t i = 0; i < residual; ++i) {
+ segment_sizes[i]++;
+ }
+
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (size_t i = 1; i < segment_starts.size(); ++i) {
+ segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1];
+ }
+
+ assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce);
+
+ T* segment_recv = (T*)temp->tensor_data().data();
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ MPI_Status recv_status;
+ MPI_Request recv_req;
+
+ // Now start ring. At every step, for every rank, we iterate through
+ // segments with wraparound and send and recv from our neighbors and reduce
+ // locally. At the i'th iteration, rank r, sends segment (r-i) and receives
+ // segment (r-i-1).
+ for (int i = 0; i < n - 1; i++) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
+ MPIType<T>(), recv_from, TAG_TENSOR,
+ MPI_COMM_WORLD, &recv_req));
+
+ MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
+ MPIType<T>(), send_to, TAG_TENSOR,
+ MPI_COMM_WORLD));
+
+ T* segment_update = &(buffer[segment_starts[recv_seg_id]]);
+
+ // Wait for recv to complete before reduction
+ MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));
+
+ const size_t recv_seg_size = segment_sizes[recv_seg_id];
+ AccumulateTensorData<Device, T>(segment_update, segment_recv,
+ recv_seg_size);
+ }
+
+ // Now start pipelined ring allgather. At every step, for every rank, we
+ // iterate through segments with wraparound and send and recv from our
+ // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and
+ // receives segment (r-i).
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i + 1) + n) % n;
+ const size_t recv_seg_id = ((r - i) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i+1)
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ // Segment to recv - at every iteration we receive segment (r-i)
+ T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, segment_sizes[send_seg_id], MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
+
+// Perform a ring allgather on a Tensor. Other ranks may allgather with a
+// tensor which differs in the first dimension only; all other dimensions must
+// be the same.
+//
+// For more information on the ring allgather, read the documentation for the
+// ring allreduce, which includes a ring allgather.
+template <typename Device, typename T>
+Status RingAllgather(OpKernelContext* context, const Tensor* input,
+ const std::vector<size_t>& sizes, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ assert(sizes.size() == n);
+ assert(input->dim_size(0) == sizes[r]);
+
+ // Compute number of elements in every "row". We can't compute number of
+ // elements in every chunks, because those chunks are variable length.
+ size_t elements_per_row = 1;
+ for (int i = 1; i < input->shape().dims(); i++) {
+ elements_per_row *= input->dim_size(i);
+ }
+
+ // Copy data from input tensor to correct place in output tensor.
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (int i = 1; i < n; i++) {
+ segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
+ }
+ size_t offset = segment_starts[r];
+
+ // Copy data to the right offset for this rank.
+ T* buffer = (T*)output->tensor_data().data();
+ CopyTensorData<Device>((void*)(buffer + offset),
+ (void*)input->tensor_data().data(),
+ elements_per_row * sizes[r] * sizeof(T));
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ // Perform a ring allgather. At every step, for every rank, we iterate
+ // through segments with wraparound and send and recv from our neighbors.
+ // At the i'th iteration, rank r, sends segment (r-i) and receives segment
+ // (r-1-i).
+ MPI_Status recv_status;
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i)
+ size_t offset_send = segment_starts[send_seg_id];
+ size_t rows_send = sizes[send_seg_id];
+ T* segment_send = &(buffer[offset_send]);
+
+ // Segment to recv - at every iteration we receive segment (r-1-i)
+ size_t offset_recv = segment_starts[recv_seg_id];
+ size_t rows_recv = sizes[recv_seg_id];
+ T* segment_recv = &(buffer[offset_recv]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, elements_per_row * rows_send, MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
+
+} // namespace mpi_collectives
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
+
+#undef TENSORFLOW_CONTRIB_MPI_H_
+#endif // TENSORFLOW_CONTRIB_MPI_H_
diff --git a/tensorflow/contrib/mpi_collectives/mpi_message.proto b/tensorflow/contrib/mpi_collectives/mpi_message.proto
index 7fa5e20301..afbce981ae 100644
--- a/tensorflow/contrib/mpi_collectives/mpi_message.proto
+++ b/tensorflow/contrib/mpi_collectives/mpi_message.proto
@@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto3";
-package tensorflow.contrib.mpi;
+package tensorflow.contrib.mpi_collectives;
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
diff --git a/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc
new file mode 100644
index 0000000000..18e6bb61cf
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc
@@ -0,0 +1,132 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi_collectives {
+
+REGISTER_OP("MPIInit").Doc(R"doc(
+Initialize MPI for the current process.
+
+If this is run on a GPU, then that GPU must be used for all future MPI
+operations. If it is run on CPU, then all future MPI operations must also
+run on CPU.
+)doc");
+
+REGISTER_OP("MPISize")
+ .Output("size: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the number of running MPI processes.
+
+More precisely, returns the number of MPI processes in the group associated
+with the MPI_COMM_WORLD communicator.
+
+size: Size of the MPI group.
+)doc");
+
+REGISTER_OP("MPIRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the MPI group.
+
+More precisely, returns the rank of the calling process in the MPI_COMM_WORLD
+communicator.
+
+rank: Rank of the calling process.
+)doc");
+
+REGISTER_OP("MPILocalRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the node it is on.
+
+More precisely, returns the rank of the calling process in communicator that
+only spans the MPI processes running on that node.
+
+rank: Rank of the calling process on the node it is on.
+)doc");
+
+REGISTER_OP("MPIAllreduce")
+ .Attr("T: {int32, int64, float32}")
+ .Input("tensor: T")
+ .Output("sum: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allreduce on a tensor. All other processes that do a reduction
+on a tensor with the same name must have the same dimension for that tensor.
+Tensors are reduced with other tensors that have the same node name for the
+allreduce.
+
+Arguments
+ tensor: A tensor to reduce.
+
+Output
+ sum: A tensor with the same shape as `tensor`, summed across all
+ MPI processes.
+)doc");
+
+REGISTER_OP("MPIAllgather")
+ .Attr("T: {int32, int64, float32}")
+ .Attr("S: {int64}")
+ .Input("tensor: T")
+ .Input("sizes: S")
+ .Output("gathered: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle output;
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
+ c->set_output(0, output);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allgather on a tensor. All other processes that do a gather on a
+tensor with the same name must have the same rank for that tensor, and have the
+same dimension on all but the first dimension.
+
+Arguments
+ tensor: A tensor to gather.
+ sizes: A tensor containing the first-dimension sizes of tensors to be
+ gathered from other ranks
+
+Output
+ gathered: A tensor with the same shape as `tensor` except for the first
+ dimension, which is the sum of dimensions in `sizes`.
+)doc");
+
+} // namespace mpi_collectives
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py
new file mode 100644
index 0000000000..f0a116239d
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py
@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Inter-process communication using MPI."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+
+_mpi_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_mpi_ops.so"))
+
+def size(name=None):
+ """An op which returns the number of MPI processes.
+
+ This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
+ size of the global communicator.
+
+ Returns:
+ An integer scalar containing the number of MPI processes.
+ """
+ return gen_mpi_ops.mpi_size(name=name)
+
+
+ops.NotDifferentiable('MPISize')
+
+
+def rank(name=None):
+ """An op which returns the MPI rank of the calling process.
+
+ This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
+ rank of the current process in the global communicator.
+
+ Returns:
+ An integer scalar with the MPI rank of the calling process.
+ """
+ return gen_mpi_ops.mpi_rank(name=name)
+
+
+ops.NotDifferentiable('MPIRank')
+
+
+def init(name=None):
+ """An op which initializes MPI on the device on which it is run.
+
+ All future MPI ops must be run on the same device that the `init` op was run
+ on.
+ """
+ return gen_mpi_ops.mpi_init(name=name)
+
+
+ops.NotDifferentiable('MPIInit')
+
+
+def local_rank(name=None):
+ """An op which returns the local MPI rank of the calling process, within the
+ node that it is running on. For example, if there are seven processes running
+ on a node, their local ranks will be zero through six, inclusive.
+
+ This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
+ which only includes processes on the same node.
+
+ Returns:
+ An integer scalar with the local MPI rank of the calling process.
+ """
+ return gen_mpi_ops.mpi_local_rank(name=name)
+
+
+ops.NotDifferentiable('MPILocalRank')
+
+
+def _allreduce(tensor, name=None):
+ """An op which sums an input tensor over all the MPI processes.
+
+ The reduction operation is keyed by the name of the op. The tensor type and
+ shape must be the same on all MPI processes for a given name. The reduction
+ will not start until all processes are ready to send and receive the tensor.
+
+ Returns:
+ A tensor of the same shape and type as `tensor`, summed across all
+ processes.
+ """
+ return gen_mpi_ops.mpi_allreduce(tensor, name=name)
+
+
+ops.NotDifferentiable('MPIAllreduce')
+
+
+def allgather(tensor, name=None):
+ """An op which concatenates the input tensor with the same input tensor on
+ all other MPI processes.
+
+ The concatenation is done on the first dimension, so the input tensors on the
+ different processes must have the same rank and shape, except for the first
+ dimension, which is allowed to be different.
+
+ Returns:
+ A tensor of the same type as `tensor`, concatenated on dimension zero
+ across all processes. The shape is identical to the input shape, except for
+ the first dimension, which may be greater and is the sum of all first
+ dimensions of the tensors in different MPI processes.
+ """
+ # Specify that first allgather is to collect the tensor gather sizes,
+ # indicated by passing in a scalar (0-D tensor) of value 0
+ sizes_flag = tf.constant(0, dtype=tf.int64, name="size_flag_const")
+ my_size = tf.slice(tf.shape(tensor, out_type=tf.int64), [0], [1], name="size_slice")
+ if name is None:
+ name = "allgather"
+ sizing_name = "{}_sizing".format(name)
+ sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name)
+ return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name)
+
+
+ops.NotDifferentiable('MPIAllgather')
+
+
diff --git a/tensorflow/contrib/ndlstm/__init__.py b/tensorflow/contrib/ndlstm/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/ndlstm/__init__.py
+++ b/tensorflow/contrib/ndlstm/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 63155faf1e..b5d81b7caa 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -140,6 +140,20 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
+ def testSRUCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g], {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])})
+ # Smoke test
+ self.assertAllClose(res[0], [[0.509682, 0.509682]])
+
def testBasicLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index c6b1316043..e4667828cd 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
@@ -2630,3 +2631,95 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
new_state = (rnn_cell_impl.LSTMStateTuple(c, m))
return m, new_state
+
+
+class SRUCell(rnn_cell_impl._LayerRNNCell):
+ """SRU, Simple Recurrent Unit
+ Implementation based on
+ Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
+
+ This variation of RNN cell is characterized by the simplified data dependence
+ between hidden states of two consecutive time steps. Traditionally, hidden
+ states from a cell at time step t-1 needs to be multiplied with a matrix
+ W_hh before being fed into the ensuing cell at time step t.
+ This flavor of RNN replaces the matrix multiplication between h_{t-1}
+ and W_hh with a pointwise multiplication, resulting in performance
+ gain.
+
+ Args:
+ num_units: int, The number of units in the SRU cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: (optional) String, the name of the layer. Layers with the same name
+ will share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ """
+ def __init__(self, num_units,
+ activation=None, reuse=None, name=None):
+ super(SRUCell, self).__init__(_reuse=reuse, name=name)
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+
+ # Restrict inputs to be 2-dimensional matrices
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+ % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+
+ # Here the contributor believes that the following constraints
+ # are implied. The reasoning is explained here with reference to
+ # the paper https://arxiv.org/pdf/1709.02755.pdf upon which this
+ # implementation is based.
+ # In section 2.1 Equation 5, specifically:
+ # h_t = r_t \odot g(c_t) + (1 - r_t) \odot x_t
+ # the pointwise operation between r_t and x_t means they have
+ # the same shape (since we are implementing an RNN cell, braodcasting
+ # does not happen to input of a single timestep); by the same
+ # reasons, x_t has the same shape as h_t, essentially mandating that
+ # input_depth = unit_num.
+ if input_depth != self._num_units:
+ raise ValueError("SRU requires input_depth == num_units, got "
+ "input_depth = %s, num_units = %s" % (input_depth,
+ self._num_units))
+
+ self._kernel = self.add_variable(
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 3 * self._num_units])
+
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[2 * self._num_units],
+ initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))
+
+ self._built = True
+
+ def call(self, inputs, state):
+ """Simple recurrent unit (SRU) with num_units cells."""
+
+ U = math_ops.matmul(inputs, self._kernel)
+ x_bar, f_intermediate, r_intermediate = array_ops.split(value=U,
+ num_or_size_splits=3,
+ axis=1)
+
+ f_r = math_ops.sigmoid(nn_ops.bias_add(array_ops.concat(
+ [f_intermediate, r_intermediate], 1), self._bias))
+ f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
+
+ c = f * state + (1.0 - f) * x_bar
+ h = r * self._activation(c) + (1.0 - r) * inputs
+
+ return h, c
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/__init__.py b/tensorflow/contrib/seq2seq/python/kernel_tests/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/__init__.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/seq2seq/python/ops/__init__.py b/tensorflow/contrib/seq2seq/python/ops/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/seq2seq/python/ops/__init__.py
+++ b/tensorflow/contrib/seq2seq/python/ops/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/specs/__init__.py b/tensorflow/contrib/specs/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/specs/__init__.py
+++ b/tensorflow/contrib/specs/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/timeseries/examples/__init__.py b/tensorflow/contrib/timeseries/examples/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/timeseries/examples/__init__.py
+++ b/tensorflow/contrib/timeseries/examples/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/__init__.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/training/python/__init__.py b/tensorflow/contrib/training/python/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/training/python/__init__.py
+++ b/tensorflow/contrib/training/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/training/python/training/__init__.py b/tensorflow/contrib/training/python/training/__init__.py
index e69de29bb2..52e83069cb 100644
--- a/tensorflow/contrib/training/python/training/__init__.py
+++ b/tensorflow/contrib/training/python/training/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ae38025942..b8ff44bc4b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -274,7 +274,7 @@ cc_library(
"platform/platform.h",
"platform/protobuf.h",
"platform/types.h",
- ] + glob(tf_additional_proto_hdrs()) + glob(tf_env_time_hdrs()),
+ ] + tf_additional_proto_hdrs() + glob(tf_env_time_hdrs()),
copts = tf_copts(),
deps = tf_lib_proto_parsing_deps(),
)
diff --git a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt
index 0716b26114..6f1121dd37 100644
--- a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt
@@ -117,7 +117,7 @@ For example,
# Draw the bounding box in an image summary.
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox_for_draw)
- tf.image_summary('images_with_box', image_with_box)
+ tf.summary.image('images_with_box', image_with_box)
# Employ the bounding box to distort the image.
distorted_image = tf.slice(image, begin, size)
diff --git a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt
index e991260972..473aec50aa 100644
--- a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt
@@ -117,7 +117,7 @@ For example,
# Draw the bounding box in an image summary.
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox_for_draw)
- tf.image_summary('images_with_box', image_with_box)
+ tf.summary.image('images_with_box', image_with_box)
# Employ the bounding box to distort the image.
distorted_image = tf.slice(image, begin, size)
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 6e243c4b7c..a10c176127 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -259,9 +259,10 @@ DirectSession::DirectSession(const SessionOptions& options,
factory_(factory),
cancellation_manager_(new CancellationManager()),
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
- if (options_.config.session_inter_op_thread_pool_size() > 0) {
- for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
- ++i) {
+ const int thread_pool_size =
+ options_.config.session_inter_op_thread_pool_size();
+ if (thread_pool_size > 0) {
+ for (int i = 0; i < thread_pool_size; ++i) {
thread::ThreadPool* pool = nullptr;
bool owned = false;
init_error_.Update(NewThreadPoolFromThreadPoolOptions(
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 14f5fdc5d3..df9cf0c91f 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -142,7 +142,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelWarmup) {
DirectSession* ds = static_cast<DirectSession*>(session.get());
CostModelManager::CostModelMap cost_models;
ds->ExportCostModels(&cost_models);
- CHECK_EQ(cost_models.size(), 1);
+ CHECK_GE(cost_models.size(), 1);
const CostModel* cm = (*cost_models.begin()).second;
EXPECT_EQ(measure_steps, cm->GetUpdateTimes());
}
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 3beca1e5d2..0ffdc42852 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2495,14 +2495,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsDataType, AlwaysRewrite});
- /*
rinfo_.push_back({csinfo_.lrn,
mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, AlwaysRewrite});
- */
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc
index 2ad2c41636..7fc4b1762d 100644
--- a/tensorflow/core/kernels/bcast_ops.cc
+++ b/tensorflow/core/kernels/bcast_ops.cc
@@ -22,11 +22,10 @@ limitations under the License.
namespace tensorflow {
// Given shapes of two tensors, computes the broadcast shape.
+template <typename T>
class BCastArgsOp : public OpKernel {
public:
- explicit BCastArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32}));
- }
+ explicit BCastArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
OP_REQUIRES(
@@ -40,7 +39,7 @@ class BCastArgsOp : public OpKernel {
in.shape().DebugString()));
BCast::Vec vec;
for (int64 i = 0; i < in.NumElements(); ++i) {
- vec.push_back(in.vec<int32>()(i));
+ vec.push_back(in.vec<T>()(i));
}
shapes.push_back(vec);
}
@@ -60,7 +59,7 @@ class BCastArgsOp : public OpKernel {
Tensor* o = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
for (int64 i = 0; i < len; ++i) {
- o->flat<int32>()(i) = static_cast<int32>(v[i]);
+ o->flat<T>()(i) = static_cast<T>(v[i]);
}
}
@@ -72,12 +71,10 @@ class BCastArgsOp : public OpKernel {
//
// TODO(zhifengc):
// 1. Adds support for n-ary (n >= 2).
+template <typename T>
class BCastGradArgsOp : public OpKernel {
public:
- explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(
- ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
- }
+ explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
OP_REQUIRES(
@@ -91,7 +88,7 @@ class BCastGradArgsOp : public OpKernel {
in.shape().DebugString()));
BCast::Vec vec;
for (int64 i = 0; i < in.NumElements(); ++i) {
- vec.push_back(in.vec<int32>()(i));
+ vec.push_back(in.vec<T>()(i));
}
shapes.push_back(vec);
}
@@ -112,7 +109,7 @@ class BCastGradArgsOp : public OpKernel {
Tensor* o = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
for (int64 i = 0; i < len; ++i) {
- o->flat<int32>()(i) = static_cast<int32>(v[i]);
+ o->flat<T>()(i) = static_cast<T>(v[i]);
}
}
@@ -125,14 +122,28 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0"),
- BCastArgsOp);
+ BCastArgsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<int64>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0"),
+ BCastArgsOp<int64>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0"),
- BCastArgsOp);
+ BCastArgsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0"),
+ BCastArgsOp<int64>);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
@@ -141,7 +152,14 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0"),
- BCastArgsOp);
+ BCastArgsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int64>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0"),
+ BCastArgsOp<int32>);
#endif
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
@@ -151,7 +169,15 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
- BCastGradArgsOp);
+ BCastGradArgsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<int64>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0")
+ .HostMemory("r1"),
+ BCastGradArgsOp<int64>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
@@ -159,7 +185,15 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
- BCastGradArgsOp);
+ BCastGradArgsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0")
+ .HostMemory("r1"),
+ BCastGradArgsOp<int64>);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
@@ -169,6 +203,14 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
- BCastGradArgsOp);
+ BCastGradArgsOp<int32>);
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int64>("T")
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0")
+ .HostMemory("r1"),
+ BCastGradArgsOp<int64>);
#endif
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 6f82698596..57e196c67c 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -146,7 +146,7 @@ class ConvParameters {
int64 total_size = 16 * std::ceil(batch_ / 16.0) *
std::max(in_depths_, out_depths_) * in_[0] * in_[1] *
sizeof(T);
- int64 threshold = 1L << 31;
+ int64 threshold = 1LL << 31;
if (total_size >= threshold) {
return false;
} else {
diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc
index 0a64a6c154..5d06eda79e 100644
--- a/tensorflow/core/kernels/example_parsing_ops_test.cc
+++ b/tensorflow/core/kernels/example_parsing_ops_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <mutex>
#include <unordered_map>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
@@ -80,6 +81,26 @@ class FloatFiller {
template <typename T>
struct ExampleStore {
+ private:
+ static ExampleTensorMap serialized_example;
+ static std::once_flag flags_init;
+
+ public:
+ static ExampleTensorMap& GetSerializedExample() {
+ std::call_once(flags_init, [] {
+ AddExample(&serialized_example, 10, 1, 1);
+ AddExample(&serialized_example, 100, 1, 1);
+ AddExample(&serialized_example, 1000, 1, 1);
+ AddExample(&serialized_example, 10, 128, 1);
+ AddExample(&serialized_example, 100, 128, 1);
+ AddExample(&serialized_example, 1000, 128, 1);
+ AddExample(&serialized_example, 10, 512, 1);
+ AddExample(&serialized_example, 100, 512, 1);
+ AddExample(&serialized_example, 1000, 512, 1);
+ AddExample(&serialized_example, 1, 1, 1000000);
+ });
+ return serialized_example;
+ }
typedef T Filler;
static void AddExample(ExampleTensorMap* examples, int num_keys,
int batch_size, int feature_size) {
@@ -101,34 +122,15 @@ struct ExampleStore {
(*examples)[std::make_tuple(batch_size, num_keys, feature_size)] =
record_string;
}
- static ExampleTensorMap GetSerializedExamples() {
- ExampleTensorMap examples;
- AddExample(&examples, 10, 1, 1);
- AddExample(&examples, 100, 1, 1);
- AddExample(&examples, 1000, 1, 1);
- AddExample(&examples, 10, 128, 1);
- AddExample(&examples, 100, 128, 1);
- AddExample(&examples, 1000, 128, 1);
- AddExample(&examples, 10, 512, 1);
- AddExample(&examples, 100, 512, 1);
- AddExample(&examples, 1000, 512, 1);
- AddExample(&examples, 1, 1, 1000000);
- return examples;
- }
- static ExampleTensorMap serialized_example;
};
+template <typename T>
+ExampleTensorMap ExampleStore<T>::serialized_example;
+template <typename T>
+std::once_flag ExampleStore<T>::flags_init;
-template <>
-ExampleTensorMap ExampleStore<BytesFiller>::serialized_example =
- ExampleStore<BytesFiller>::GetSerializedExamples();
-
-template <>
-ExampleTensorMap ExampleStore<Int64Filler>::serialized_example =
- ExampleStore<Int64Filler>::GetSerializedExamples();
-
-template <>
-ExampleTensorMap ExampleStore<FloatFiller>::serialized_example =
- ExampleStore<FloatFiller>::GetSerializedExamples();
+template class ExampleStore<BytesFiller>;
+template class ExampleStore<Int64Filler>;
+template class ExampleStore<FloatFiller>;
enum BenchmarkType { kDense, kSparse, kVarLenDense };
@@ -142,7 +144,7 @@ struct BenchmarkOptions {
template <typename Options>
static Graph* ParseExample(int batch_size, int num_keys, int feature_size) {
Graph* g = new Graph(OpRegistry::Global());
- Tensor& serialized = Options::Store::serialized_example[std::make_tuple(
+ Tensor& serialized = Options::Store::GetSerializedExample()[std::make_tuple(
batch_size, num_keys, feature_size)];
Tensor names(DT_STRING, TensorShape({batch_size}));
@@ -193,8 +195,8 @@ template <typename Options>
static Graph* ParseSingleExample(int num_keys, int feature_size) {
Graph* g = new Graph(OpRegistry::Global());
Tensor& serialized_batch_1 =
- Options::Store::serialized_example[std::make_tuple(1, num_keys,
- feature_size)];
+ Options::Store::GetSerializedExample()[std::make_tuple(1, num_keys,
+ feature_size)];
Tensor serialized(DT_STRING, TensorShape());
serialized.scalar<string>()() = serialized_batch_1.vec<string>()(0);
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 9aabbbdb6b..44b94be3a0 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -294,7 +294,7 @@ class MklAddNOp : public OpKernel {
try {
auto cpu_engine = engine(engine::cpu, 0);
- size_t src1_idx = 0, src2_idx = 1;
+ size_t src1_idx = 0, src2_idx = 1, output_idx = 0;
const Tensor& src1_tensor = MklGetInput(ctx, src1_idx);
const Tensor& src2_tensor = MklGetInput(ctx, src2_idx);
@@ -312,7 +312,7 @@ class MklAddNOp : public OpKernel {
Tensor* dst_tensor = nullptr;
MklShape mkl_shape_dst;
mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, src1_idx, &dst_tensor,
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
src1_tensor.shape(), mkl_shape_dst);
float user_i1 = (src1_tensor.scalar<T>()());
float user_i2 = (src2_tensor.scalar<T>()());
@@ -327,13 +327,12 @@ class MklAddNOp : public OpKernel {
Tensor* dst_tensor = nullptr;
MklShape mkl_shape_dst;
mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, src1_idx, &dst_tensor,
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
src1_tensor.shape(), mkl_shape_dst);
return;
}
}
- // element-wise add operator for tensor input1 and tensor input2
std::vector<double> coeff(2, 1.0);
MklDnnData<T> src1(&cpu_engine);
MklDnnData<T> src2(&cpu_engine);
@@ -345,70 +344,124 @@ class MklAddNOp : public OpKernel {
memory::desc md1({}, memory::data_undef, memory::format_undef);
memory::desc md2({}, memory::data_undef, memory::format_undef);
- if ( input1_in_mkl_format || input2_in_mkl_format ) {
- if ( input1_in_mkl_format ) {
- md1 = src1_mkl_shape.GetMklLayout();
- md2 = md1;
- dst.SetUsrMem(md1);
- } else {
- md2 = src2_mkl_shape.GetMklLayout();
- md1 = md2;
- dst.SetUsrMem(md2);
- }
+ // For creating Sum primitive, we need to ensure that all inputs are in
+ // same format. What that means is if we have a mixed input case - where
+ // one input is in Tensorflow format and one input is in MKL format -,
+ // then we need to ensure that all inputs are in same format for
+ // primitive construction. For performance reason, we say that all inputs
+ // are in MKL format in such case, and insert reorder for input that is
+ // in Tensorflow format into MKL format. On the other hand, if both the
+ // inputs are in MKL format or both are in Tensorflow format, then we
+ // dont need reorder.
+ if (!input1_in_mkl_format && !input2_in_mkl_format) {
+ // If both the inputs are in Tensorflow format, we create blocked memory
+ // descriptor.
+ dims = TFShapeToMklDnnDims(src1_tensor.shape());
+ strides = CalculateTFStrides(dims);
+ md1 = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
+ md2 = md1;
+ } else if (input1_in_mkl_format && !input2_in_mkl_format) {
+ // If one input is in MKL format and other is in Tensorflow, then
+ // create respective descriptors describing the actual case. For input
+ // in Mkl format, we just get Mkl layout from MklDnnShape. For input in
+ // Tensorflow format, we create memory descriptor using data format.
+ md1 = src1_mkl_shape.GetMklLayout();
+
+ memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
+ auto src1_tf_data_format = MklDnnDataFormatToTFDataFormat(
+ src1_mkl_data_format);
+ auto src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ md2 = memory::desc(src2_dims, MklDnnType<T>(),
+ src1_mkl_data_format);
+ } else if (input2_in_mkl_format && !input1_in_mkl_format) {
+ // Same comment as above.
+ memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
+ auto src2_tf_data_format = MklDnnDataFormatToTFDataFormat(
+ src2_mkl_data_format);
+ auto src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ md1 = memory::desc(src1_dims, MklDnnType<T>(),
+ src2_mkl_data_format);
+
+ md2 = src2_mkl_shape.GetMklLayout();
} else {
- dims = TFShapeToMklDnnDims(src1_tensor.shape());
- strides = CalculateTFStrides(dims);
- md1 = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
- md2 = md1;
- dst.SetUsrMem(dims, strides);
+ // If both the inputs are in MKL format, we use Mkl layout of the input
+ // tensors.
+ md1 = src1_mkl_shape.GetMklLayout();
+ md2 = src2_mkl_shape.GetMklLayout();
}
-
- std::vector<memory::primitive_desc> srcs_pd;
-
src1.SetUsrMem(md1, &src1_tensor);
- auto mpd1 = src1.GetUsrMemPrimDesc();
- srcs_pd.push_back(mpd1);
-
src2.SetUsrMem(md2, &src2_tensor);
- auto mpd2 = src2.GetUsrMemPrimDesc();
- srcs_pd.push_back(mpd2);
+ // As per comment above, we tell MKLDNN that both the inputs are in same
+ // format. So we set common memory descriptor in MKL format, if any of the
+ // inputs are in MKL format. Let's get memory descriptor that we will use
+ // for both the inputs.
+ // We set output memory descriptor in MKL format, if any of the
+ // inputs are in MKL format.
+ memory::desc common_md({}, memory::data_undef, memory::format_undef);
+ if (input1_in_mkl_format || input2_in_mkl_format) {
+ common_md = input1_in_mkl_format ? md1 : md2;
+ dst.SetUsrMem(common_md);
+ } else {
+ // Since both the inputs are in Tensorflow format, and have
+ // same shape, we can get memory descriptor from any input.
+ common_md = md1;
+ dst.SetUsrMem(common_md);
+ }
+
+ std::vector<memory::primitive_desc> srcs_pd;
+ // Memory descriptor for 1st input
+ srcs_pd.push_back(memory::primitive_desc(common_md, cpu_engine));
+ // Memory descriptor for 2nd input
+ srcs_pd.push_back(memory::primitive_desc(common_md, cpu_engine));
+ auto sum_pd = sum::primitive_desc(dst.GetUsrMemDesc(), coeff, srcs_pd);
+
+ // Now we setup resources for primitive execution.
+ // First, we need to check if any of the inputs need to be reordered as
+ // per the logic described above. Since output will be in MKL format if
+ // atleast one input is in MKL format, we choose output descriptor for
+ // reorder.
std::vector<primitive::at> inputs;
+ std::vector<primitive> net;
+ // Check if actual input format of the tensor is different than common_pd
+ // we told MKLDNN. In that case, we will need reorder.
+ src1.CheckReorderToOpMem(srcs_pd[0], &net);
+ src2.CheckReorderToOpMem(srcs_pd[1], &net);
inputs.push_back(src1.GetOpMem());
inputs.push_back(src2.GetOpMem());
- auto output_pd = dst.GetUsrMemPrimDesc();
+
+ // Allocate output tensor now.
Tensor* dst_tensor = nullptr;
- auto sum_pd = sum::primitive_desc(dst.GetUsrMemDesc(), coeff, srcs_pd);
- auto sum_op = sum(sum_pd, inputs, dst.GetOpMem());
- if ( input2_in_mkl_format || input1_in_mkl_format ) {
- MklDnnShape output_mkl_shape;
- output_mkl_shape.SetMklTensor(true);
- output_mkl_shape.SetMklLayout(&output_pd);
- output_mkl_shape.SetElemType(MklDnnType<T>());
- if ( input1_in_mkl_format ) {
+ MklDnnShape output_mkl_shape;
+ TensorShape output_tf_shape;
+
+ if (input2_in_mkl_format || input1_in_mkl_format) {
+ output_mkl_shape.SetMklTensor(true);
+ auto output_pd = dst.GetUsrMemPrimDesc();
+ output_mkl_shape.SetMklLayout(&output_pd);
+ output_mkl_shape.SetElemType(MklDnnType<T>());
+ if (input1_in_mkl_format) {
output_mkl_shape.SetTfLayout(src1_dims_size,
- src1_mkl_shape.GetSizesAsMklDnnDims(),
- src1_mkl_shape.GetTfDataFormat());
- } else {
+ src1_mkl_shape.GetSizesAsMklDnnDims(),
+ src1_mkl_shape.GetTfDataFormat());
+ } else {
output_mkl_shape.SetTfLayout(src2_dims_size,
- src2_mkl_shape.GetSizesAsMklDnnDims(),
- src2_mkl_shape.GetTfDataFormat());
- }
- TensorShape output_tf_shape;
- output_tf_shape.AddDim((output_pd.get_size() / sizeof(T))
- + (output_pd.get_size()%sizeof(T) == 0 ? 0 : 1));
- AllocateOutputSetMklShape(ctx, src1_idx, &dst_tensor, output_tf_shape,
- output_mkl_shape);
+ src2_mkl_shape.GetSizesAsMklDnnDims(),
+ src2_mkl_shape.GetTfDataFormat());
+ }
+ output_tf_shape.AddDim((output_pd.get_size() / sizeof(T)));
} else {
- MklShape mkl_shape_dst;
- mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, src1_idx,
- &dst_tensor, src1_tensor.shape(), mkl_shape_dst);
+ output_mkl_shape.SetMklTensor(false);
+ output_tf_shape = src1_tensor.shape();
}
-
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+ output_tf_shape, output_mkl_shape);
dst.SetUsrMemDataHandle(dst_tensor);
- std::vector<primitive> net;
- net.push_back(sum_op);
+
+ // Create Sum op, and submit net for execution.
+ net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index df51df9638..db9e97e7ca 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -367,6 +367,9 @@ class MklConv2DCustomBackpropInputOp :
~MklConv2DCustomBackpropInputOp() {}
private:
+ const int kInputIndex_Filter = 1,
+ kInputIndex_InputSizes = 0,
+ kInputIndex_OutBackProp = 2;
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -377,7 +380,7 @@ class MklConv2DCustomBackpropInputOp :
<< "Conv2DBackpropInput: input should not be in MKL Layout";
}
- size_t GetInputTensorIndexWithSizes() { return 0; /* input index */ }
+ size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; }
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
@@ -390,8 +393,7 @@ class MklConv2DCustomBackpropInputOp :
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
- size_t filter_idx = 1;
- return GetTfShape(context, filter_idx);
+ return GetTfShape(context, kInputIndex_Filter);
}
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 04268f23bb..a4e139bb54 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -510,15 +510,15 @@ class MklConv2DOp : public OpKernel {
auto cpu_engine = engine(engine::cpu, 0);
// Input tensors
- size_t src_idx = 0, filter_idx = 1;
- const Tensor& src_tensor = MklGetInput(context, src_idx);
- const Tensor& filter_tensor = MklGetInput(context, filter_idx);
+ const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
+ const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
MklDnnShape src_mkl_shape, filter_mkl_shape;
- GetMklShape(context, src_idx, &src_mkl_shape);
- GetMklShape(context, filter_idx, &filter_mkl_shape);
- CHECK(!filter_mkl_shape.IsMklTensor())
- << "Conv2D filter should not be in MKL Layout";
+ GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
+ GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
+ OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false,
+ errors::InvalidArgument("Filter should not be in "
+ "Mkl Layout"));
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> filter(&cpu_engine);
@@ -529,8 +529,8 @@ class MklConv2DOp : public OpKernel {
// Get shapes of input tensors in MKL-DNN order
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- auto src_tf_shape = GetTfShape(context, src_idx);
- auto filter_tf_shape = GetTfShape(context, filter_idx);
+ auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
+ auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
conv_utl.GetConvFwdSizesInMklOrder(src_tf_shape, filter_tf_shape,
&src_dims, &filter_dims, &strides,
&output_dims_tf_order,
@@ -541,9 +541,6 @@ class MklConv2DOp : public OpKernel {
// Check for corner case - if there is nothing to compute, return.
TensorShape output_tf_shape = MklDnnDimsToTFShape(output_dims_tf_order);
- // Forward filter in TF format from input at index 1 to output at index 1.
- ForwardTfTensorInToOut(context, 1, 1);
-
// Corner cases: output with 0 elements and 0 batch size.
Tensor* output_tensor = nullptr;
if (output_tf_shape.num_elements() == 0 ||
@@ -552,8 +549,8 @@ class MklConv2DOp : public OpKernel {
// Need semantics for Null MKL tensor
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
- AllocateOutputSetMklShape(context, 0, &output_tensor, src_tf_shape,
- output_mkl_shape);
+ AllocateOutputSetMklShape(context, kOutputIndex_Dst, &output_tensor,
+ src_tf_shape, output_mkl_shape);
return;
}
@@ -571,10 +568,11 @@ class MklConv2DOp : public OpKernel {
src.SetUsrMem(src_md, &src_tensor);
// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO).
- auto filter_md = filter_mkl_shape.IsMklTensor()
+ auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true
? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(), memory::format::hwio);
filter.SetUsrMem(filter_md, &filter_tensor);
+
// Set output shape (output_dims) required in MKL-DNN order.
// Currently, we set output layout as Tensorflow's layout (NHWC or NCHW
// depending on data format). But later we propagate Mkl layout of the
@@ -590,8 +588,8 @@ class MklConv2DOp : public OpKernel {
if (biasEnabled) {
MklDnnData<T> bias(&cpu_engine);
memory::dims bias_size;
- conv_utl.GetBiasSizeInMklOrder(2 /* bias idx */, &bias_size);
- const Tensor& bias_tensor = MklGetInput(context, 2);
+ conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_size);
+ const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
bias.SetUsrMem(bias_size, memory::format::x, &bias_tensor);
bias.SetOpMemDesc(bias_size, memory::format::any);
@@ -607,7 +605,14 @@ class MklConv2DOp : public OpKernel {
output_dims_mkl_order, tf_fmt, &output_tensor);
// Set data handle for output.
output.SetUsrMemDataHandle(output_tensor);
- PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output);
+
+ Tensor* filter_out_tensor = nullptr;
+ AllocateFilterOutputTensor(context, conv_prim_desc,
+ TFShapeToMklDnnDims(filter_tf_shape),
+ &filter_out_tensor);
+
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter,
+ &bias, &output, filter_out_tensor);
} else {
// Create convolution primitive without Bias.
auto conv_desc = convolution_forward::desc(prop_kind::forward,
@@ -621,7 +626,13 @@ class MklConv2DOp : public OpKernel {
tf_fmt, &output_tensor);
// Set data handle for output.
output.SetUsrMemDataHandle(output_tensor);
- PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output);
+
+ Tensor* filter_out_tensor = nullptr;
+ AllocateFilterOutputTensor(context, conv_prim_desc,
+ TFShapeToMklDnnDims(filter_tf_shape),
+ &filter_out_tensor);
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter,
+ nullptr, &output, filter_out_tensor);
}
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
@@ -637,6 +648,10 @@ class MklConv2DOp : public OpKernel {
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
+ const int kInputIndex_Src = 0,
+ kInputIndex_Filter = 1,
+ kInputIndex_Bias = 2;
+ const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
// Allocate output tensor.
void AllocateOutputTensor(
@@ -653,28 +668,63 @@ class MklConv2DOp : public OpKernel {
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
- output_dims_mkl_order, output_tf_format);
+ output_dims_mkl_order, output_tf_format);
// Allocate shape of TF tensor.
TensorShape output_tf_shape;
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
- const int kOutputSlotIdx = 0;
- AllocateOutputSetMklShape(context, kOutputSlotIdx, output_tensor,
+ AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
output_tf_shape, output_mkl_shape);
}
+ // Allocate output tensor.
+ void AllocateFilterOutputTensor(
+ OpKernelContext* context,
+ const convolution_forward::primitive_desc& conv_prim_desc,
+ const memory::dims& filter_dims_tf_order,
+ Tensor** filter_tensor) {
+ CHECK_NOTNULL(filter_tensor);
+ auto filter_pd = conv_prim_desc.weights_primitive_desc();
+
+ // Allocate shape of Mkl tensor.
+ MklDnnShape filter_mkl_shape;
+ filter_mkl_shape.SetMklTensor(true);
+ filter_mkl_shape.SetMklLayout(&filter_pd);
+ filter_mkl_shape.SetElemType(MklDnnType<T>());
+
+ // The format of the filter is actually OIhw8i8o, but TF doesn't support
+ // this format. Just use format::blocked for now because the layout
+ // is stored in the MKL data.
+ filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
+ filter_dims_tf_order, memory::format::blocked);
+
+ // Allocate the data space for the filter to propagate as TF tensor.
+ TensorShape filter_tf_shape;
+ filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(T)));
+
+ AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
+ filter_tf_shape, filter_mkl_shape);
+ }
+
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecuteNet(
const convolution_forward::primitive_desc& conv_prim_desc,
MklDnnData<T>* src, MklDnnData<T>* filter,
- MklDnnData<T>* bias, MklDnnData<T>* output) {
+ MklDnnData<T>* bias, MklDnnData<T>* output,
+ Tensor* filter_out_tensor) {
+ CHECK_NOTNULL(filter_out_tensor);
+
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution. No need to check for output
// reorder as we propagate output layout to the next layer.
std::vector<primitive> net;
src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
- filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), &net);
+
+ // rather than re-order to a temp buffer, reorder directly to the
+ // filter output tensor
+ filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
+ filter->GetTensorBuffer(filter_out_tensor), &net);
// Create convolution primitive and add it to net.
if (bias) {
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 47a9b4bfc7..b6883dbaa2 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include <limits>
+#include <string>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -288,7 +289,7 @@ class MklDnnConvUtil {
OP_REQUIRES(context_, input_tf_shape.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
- input_tf_shape.DebugString()));
+ input_tf_shape.DebugString()));
GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
strides, output_dims_tf_order,
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 227765e46d..a8f28202f4 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc. This opkernel uses MKL library, create MKL
// layout and primitives, use MKL dnn primitives to compute local
// response normalization
-#undef INTEL_MKL
+
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
@@ -38,6 +38,15 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#endif
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+using mkldnn::lrn_forward;
+using mkldnn::lrn_backward;
+using mkldnn::prop_kind;
+using mkldnn::algorithm::lrn_across_channels;
+using mkldnn::stream;
+#endif
+
namespace tensorflow {
namespace {
@@ -58,6 +67,8 @@ void GetBandMatrix(int depth, int depth_radius,
} // namespace
+#ifndef INTEL_MKL_DNN
+
template <typename T>
class MklLRNOp : public OpKernel {
public:
@@ -328,6 +339,7 @@ class MklLRNOp : public OpKernel {
float beta_;
};
+
template <typename T>
class MklLRNGradOp : public OpKernel {
public:
@@ -648,6 +660,7 @@ class MklLRNGradOp : public OpKernel {
const auto nodes = cols * rows;
auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
+
auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
@@ -717,6 +730,649 @@ class MklLRNGradOp : public OpKernel {
float beta_;
};
+#else
+
+template <typename T>
+class MklLRNOp : public OpKernel {
+ public:
+ ~MklLRNOp() {}
+
+ explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+ int64 depth_radius64;
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
+ depth_radius_ = static_cast<size_t>(depth_radius64);
+
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ workspace_enabled_ = false;
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ SanityCheckInputs(context);
+ if (!context->status().ok()) return;
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ const Tensor& src_tensor = MklGetInput(context, kIdxInput);
+ MklDnnShape src_dnn_shape;
+ GetMklShape(context, kIdxInput, &src_dnn_shape);
+
+ // MKL-DNN has a notion of kernel_size and not depth_radius.
+ int kernel_size = 2 * depth_radius_ + 1;
+ float new_alpha = alpha_ * kernel_size;
+
+ // if the input tensor is not an MKL Tensor, or if the last
+ // dimension is not channel, then just use Eigen.
+ // MKL only support normalization over the channel dimension.
+ if (!src_dnn_shape.IsMklTensor()) {
+ MklDefaultToEigen(context, src_tensor);
+ return;
+ } else if (!src_dnn_shape.IsMklChannelDim(
+ src_dnn_shape.GetDimension() - 1) ) {
+ Tensor converted_tensor =
+ ConvertMklToTF<T>(context, src_tensor, src_dnn_shape);
+ MklDefaultToEigen(context, converted_tensor);
+ return;
+ }
+ // At this point, we can assume that the src is an MklTensor
+ // and we can enable the workspace
+ workspace_enabled_ = true;
+
+ MklDnnData<T> src_dnn_data(&cpu_engine);
+ MklDnnData<T> dst_dnn_data(&cpu_engine);
+ MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+
+ TensorShape tf_output_shape = src_tensor.shape();
+
+ memory::desc src_md = src_dnn_shape.GetCurLayout();
+ memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();
+
+ // Create memory for user input.
+ // Since Tensorflow always performs normalization over last dimension,
+ // and MKL-DNN performs normalization over Channel, we tell MKL-DNN
+ // that input is in NHWC layout with Channel being the last dimension.
+ src_dnn_data.SetUsrMem(src_md, &src_tensor);
+ src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+
+ // output_dnn_data and workspace both have the same shape as input
+ dst_dnn_data.SetUsrMem(src_md);
+ dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+
+ // Create LRN primitive descriptor.
+ // Tensorflow's normalization semantics is across channels.
+ // MKL-DNN also supports normalization within channel.
+ auto lrn_desc = lrn_forward::desc(prop_kind::forward,
+ lrn_across_channels,
+ src_dnn_data.GetUsrMemDesc(),
+ kernel_size,
+ new_alpha, beta_, bias_);
+ auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
+
+ // Allocate output_dnn_data tensor.
+ Tensor* output_tensor = nullptr;
+ memory::format input_format = src_dnn_shape.GetTfDataFormat();
+ AllocateOutputTensor(context, lrn_prim_desc, input_dims,
+ input_format, &output_tensor);
+ OP_REQUIRES_OK(context, context->status());
+ CHECK_NOTNULL(output_tensor);
+ dst_dnn_data.SetUsrMemDataHandle(output_tensor);
+
+ // Handle workspace required for MKL-DNN.
+ AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
+ OP_REQUIRES_OK(context, context->status());
+
+ PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data,
+ &dst_dnn_data, &workspace_dnn_data);
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
+ }
+ }
+
+ private:
+ void PrepareAndExecuteNet(
+ const lrn_forward::primitive_desc& lrn_fwd_desc,
+ MklDnnData<T>* src_dnn_data,
+ MklDnnData<T>* dst_dnn_data,
+ MklDnnData<uint8>* wksp_dnn_data = nullptr) {
+ std::vector<primitive> net;
+
+ // Check for input reorder
+ src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
+
+ // Create pooling primitive and add it to net
+ if (wksp_dnn_data != nullptr) {
+ net.push_back(lrn_forward(lrn_fwd_desc,
+ src_dnn_data->GetOpMem(),
+ wksp_dnn_data->GetOpMem(),
+ dst_dnn_data->GetOpMem()));
+ } else {
+ net.push_back(lrn_forward(lrn_fwd_desc,
+ src_dnn_data->GetOpMem(),
+ dst_dnn_data->GetOpMem()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+ }
+
+ void AllocateOutputTensor(OpKernelContext* context,
+ const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
+ const memory::dims output_dims_mkl_order,
+ const memory::format& output_tf_format,
+ Tensor** output_tensor) {
+ CHECK_NOTNULL(output_tensor);
+ memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
+
+ MklDnnShape output_mkl_shape;
+ // We only handle the case when the inputs and output are in Mkl format
+ // Any other case is handled by Eigen
+ output_mkl_shape.SetMklTensor(true);
+ output_mkl_shape.SetMklLayout(&dst_pd);
+ output_mkl_shape.SetElemType(MklDnnType<T>());
+ output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
+ output_dims_mkl_order,
+ output_tf_format);
+ TensorShape output_tf_shape;
+ // only allocate enough space for the elements we need.
+ size_t num_bytes = dst_pd.get_size();
+ CHECK_EQ(num_bytes % sizeof(T), 0);
+ output_tf_shape.AddDim(num_bytes / sizeof(T));
+ AllocateOutputSetMklShape(context, kIdxOutput,
+ output_tensor,
+ output_tf_shape, output_mkl_shape);
+ }
+
+ // Fallback implementation - Taken from lrn_op.cc
+ // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
+ // copy.
+ void MklDefaultToEigen(OpKernelContext* context,
+ const Tensor& input) {
+ const int batch = static_cast<int>(input.dim_size(0));
+ const int rows = static_cast<int>(input.dim_size(1));
+ const int cols = static_cast<int>(input.dim_size(2));
+ const int depth = static_cast<int>(input.dim_size(3));
+ const int nodes = cols * rows;
+
+ auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
+ // Multiplying the input with the band matrix has the effect of reducing
+ // the
+ // correct patch along the depth.
+ Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
+ GetBandMatrix<T>(depth, depth_radius_, &multiplier);
+
+ Tensor *output_dnn_data, *workspace;
+ MklDnnShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ mkl_output_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
+ input.shape(), mkl_output_mkl_shape);
+
+ mkl_workspace_mkl_shape.SetMklTensor(false);
+ mkl_workspace_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace,
+ input.shape(), mkl_workspace_mkl_shape);
+
+ auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
+ Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+ auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
+ if (beta_ == T(1)) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.inverse();
+ } else if (beta_ == T(0.5)) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.rsqrt();
+ } else {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * (tmp.log() * -beta_).exp();
+ }
+ }
+
+ void AllocateWorkspaceTensor(OpKernelContext* context,
+ const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
+ MklDnnData<uint8>* dnn_data_wksp) {
+ CHECK_NOTNULL(dnn_data_wksp);
+ Tensor* workspace_tensor = nullptr;
+ memory::primitive_desc workspace_pd
+ = lrn_fwd_prim_desc.workspace_primitive_desc();
+ size_t workspace_bytes = workspace_pd.get_size();
+ MklDnnShape workspace_mkl_shape;
+ // the workspace tensor is a uint8 tensor that has
+ // exactly the number of bytes necessary
+ workspace_mkl_shape.SetMklTensor(false);
+ TensorShape workspace_tf_shape;
+ workspace_tf_shape.AddDim(workspace_bytes);
+ AllocateOutputSetMklShape(context, kIdxWorkspace,
+ &workspace_tensor,
+ workspace_tf_shape, workspace_mkl_shape);
+ CHECK_NOTNULL(workspace_tensor);
+ dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
+ }
+
+ void SanityCheckInputs(OpKernelContext* context) {
+ const Tensor& src_tensor = MklGetInput(context, kIdxInput);
+ MklDnnShape src_dnn_shape;
+ GetMklShape(context, kIdxInput, &src_dnn_shape);
+ if (src_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("input must be 4-dimensional"));
+ OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
+ } else {
+ OP_REQUIRES(context, src_tensor.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional"));
+ OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
+ }
+ }
+ const int kIdxInput = 0,
+ kIdxOutput = 0,
+ kIdxWorkspace = 1;
+
+ typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+ bool workspace_enabled_;
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+
+template <typename T>
+class MklLRNGradOp : public OpKernel {
+ public:
+ explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ int64 depth_radius64;
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
+ depth_radius_ = static_cast<int>(depth_radius64);
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ workspace_enabled_ = false;
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ SanityCheckInputs(context);
+ if (!context->status().ok()) return;
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> input_grad_dnn_data(&cpu_engine);
+ MklDnnData<T> orig_input_dnn_data(&cpu_engine);
+ MklDnnData<T> orig_output_dnn_data(&cpu_engine);
+ MklDnnData<T> output_dnn_data(&cpu_engine);
+
+ MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
+ orig_output_dnn_shape;
+ GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
+ GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
+ GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
+
+ // We only use MKLDNN if all of the necessary inputs are present
+ // in mkldnn format, and Channel is the last dimension
+ bool can_use_mkldnn = workspace_enabled_ &&
+ input_grad_dnn_shape.IsMklTensor() &&
+ orig_input_dnn_shape.IsMklTensor() &&
+ orig_output_dnn_shape.IsMklTensor() &&
+ input_grad_dnn_shape.IsMklChannelDim(
+ input_grad_dnn_shape.GetDimension() - 1) &&
+ orig_input_dnn_shape.IsMklChannelDim(
+ orig_input_dnn_shape.GetDimension() - 1) &&
+ orig_output_dnn_shape.IsMklChannelDim(
+ orig_output_dnn_shape.GetDimension() - 1);
+
+ if (!can_use_mkldnn) {
+ // Fallback to eigen
+ MklDefaultToEigen(context);
+ return;
+ }
+ // At this point, we have the all clear to use MklDnn constructs
+ // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
+ const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
+ const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+ const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+
+ // Get input sizes in MKL-DNN required NCHW format.
+ // LRN does not have data_format attribute. But by default it has
+ // NHWC format.
+ memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
+ memory::desc target_diff_dst_md = ConfigureInputGradient(
+ input_grad_tensor,
+ input_grad_dnn_shape,
+ &input_grad_dnn_data);
+
+ memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
+ memory::dims orig_input_dims =
+ orig_input_dnn_shape.GetSizesAsMklDnnDims();
+ orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
+ orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+
+ // output_dnn_data has the same shape as original input
+ output_dnn_data.SetUsrMem(orig_input_md);
+ output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+
+ // MKL-DNN has a notion of kernel_size and not depth_radius.
+ int kernel_size = 2 * depth_radius_ + 1;
+ float new_alpha = alpha_ * kernel_size;
+
+ // Create LRN backward primitive descriptor. It requires LRN forward
+ // primitive descriptor also.
+ auto lrn_fwd_desc = lrn_forward::desc(prop_kind::forward,
+ lrn_across_channels,
+ orig_input_md,
+ kernel_size,
+ new_alpha, beta_, bias_);
+ auto lrn_fwd_prim_desc = lrn_forward::primitive_desc(lrn_fwd_desc,
+ cpu_engine);
+ auto lrn_bwd_desc = lrn_backward::desc(lrn_across_channels,
+ original_output_md,
+ target_diff_dst_md,
+ kernel_size,
+ new_alpha, beta_, bias_);
+ auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(lrn_bwd_desc,
+ cpu_engine,
+ lrn_fwd_prim_desc);
+
+ Tensor* output_tensor = nullptr;
+ memory::format orig_input_format
+ = orig_input_dnn_shape.GetTfDataFormat();
+ AllocateOutputTensor(context, lrn_bwd_prim_desc,
+ orig_input_dims, orig_input_format, &output_tensor);
+ OP_REQUIRES_OK(context, context->status());
+ CHECK_NOTNULL(output_tensor);
+ output_dnn_data.SetUsrMemDataHandle(output_tensor);
+
+ // Create LRN primitive and add it to the net
+ // At this point, workspace is enabled, so we don't need
+ // to check. Pass input workspace to LRN backward primitive.
+ const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
+ MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+ ConfigureWorkspace(workspace_tensor,
+ lrn_fwd_prim_desc.workspace_primitive_desc(),
+ &workspace_dnn_data);
+
+ PrepareAndExecuteNet(lrn_bwd_prim_desc,
+ lrn_fwd_prim_desc,
+ &orig_input_dnn_data,
+ &input_grad_dnn_data,
+ &output_dnn_data,
+ memory::primitive_desc(target_diff_dst_md, cpu_engine),
+ &workspace_dnn_data);
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
+ }
+ }
+
+ void AllocateOutputTensor(OpKernelContext* context,
+ const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
+ const memory::dims output_dims_mkl_order,
+ const memory::format& output_tf_format,
+ Tensor** output_tensor) {
+ CHECK_NOTNULL(output_tensor);
+ memory::primitive_desc dst_pd
+ = lrn_bkwd_prim_desc.diff_src_primitive_desc();
+ MklDnnShape output_mkl_shape;
+
+ // We assume that all outputs at this point are MKL Tensors
+ output_mkl_shape.SetMklTensor(true);
+ output_mkl_shape.SetMklLayout(&dst_pd);
+ output_mkl_shape.SetElemType(MklDnnType<T>());
+ output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
+ output_dims_mkl_order,
+ output_tf_format);
+
+ TensorShape output_tf_shape;
+ size_t num_bytes = dst_pd.get_size();
+ CHECK_EQ(num_bytes % sizeof(T), 0);
+ output_tf_shape.AddDim(num_bytes / sizeof(T));
+ AllocateOutputSetMklShape(context, kIdxOutput,
+ output_tensor,
+ output_tf_shape, output_mkl_shape);
+ }
+
+ memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
+ const MklDnnShape& input_grad_dnn_shape,
+ MklDnnData<T> *input_grad_dnn_data) {
+ CHECK_NOTNULL(input_grad_dnn_data);
+ // This shouldn't be necessary at this point, but just in case
+ CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
+
+ memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
+ memory::dims orig_input_dims =
+ input_grad_dnn_shape.GetSizesAsMklDnnDims();
+ input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
+ input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+ return input_grad_md;
+ }
+
+ void PrepareAndExecuteNet(
+ const lrn_backward::primitive_desc& lrn_bkwd_desc,
+ const lrn_forward::primitive_desc& lrn_fwd_desc,
+ MklDnnData<T>* src_dnn_data,
+ MklDnnData<T>* input_gradient_diff_dst,
+ MklDnnData<T>* output_diff_src,
+ const memory::primitive_desc& target_diff_dst_pd,
+ const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
+ std::vector<primitive> net;
+
+ // Check for input reordering on the diff dst input
+ input_gradient_diff_dst->CheckReorderToOpMem(
+ lrn_bkwd_desc.diff_dst_primitive_desc(), &net);
+
+ // Check for input reordering on the original input
+ src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(),
+ &net);
+ // Create pooling primitive and add it to net
+ if (nullptr == workspace_dnn_data) {
+ net.push_back(lrn_backward(lrn_bkwd_desc,
+ src_dnn_data->GetOpMem(),
+ input_gradient_diff_dst->GetOpMem(),
+ output_diff_src->GetOpMem()));
+ } else {
+ net.push_back(lrn_backward(lrn_bkwd_desc,
+ src_dnn_data->GetOpMem(),
+ input_gradient_diff_dst->GetOpMem(),
+ workspace_dnn_data->GetOpMem(),
+ output_diff_src->GetOpMem()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+ }
+
+ void ConfigureWorkspace(const Tensor& workspace_tensor,
+ memory::primitive_desc workspace_pd,
+ MklDnnData<uint8> *workspace_dnn_data) {
+ CHECK_NOTNULL(workspace_dnn_data);
+
+ workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
+ }
+
+ // Fallback implementation - Taken from lrn_op.cc
+ // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+ // copy.
+ void MklDefaultToEigen(OpKernelContext* context) {
+ Tensor input_gradient_tensor;
+ Tensor orig_input_tensor;
+ Tensor orig_output_tensor;
+
+ MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
+ orig_output_dnn_shape;
+ GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
+ GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
+ GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
+
+ if (input_grad_dnn_shape.IsMklTensor()) {
+ input_gradient_tensor =
+ ConvertMklToTF<T>(context,
+ MklGetInput(context, kIdxGradient),
+ input_grad_dnn_shape);
+ } else {
+ input_gradient_tensor = MklGetInput(context, kIdxGradient);
+ }
+
+ if (orig_input_dnn_shape.IsMklTensor()) {
+ orig_input_tensor =
+ ConvertMklToTF<T>(context,
+ MklGetInput(context, kIdxOrigInput),
+ orig_input_dnn_shape);
+ } else {
+ orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+ }
+
+ if (orig_output_dnn_shape.IsMklTensor()) {
+ orig_output_tensor =
+ ConvertMklToTF<T>(context,
+ MklGetInput(context, kIdxOrigOutput),
+ orig_output_dnn_shape);
+ } else {
+ orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+ }
+
+ const int64 batch = static_cast<int64>(input_gradient_tensor.dim_size(0));
+ const int64 rows = static_cast<int64>(input_gradient_tensor.dim_size(1));
+ const int64 cols = static_cast<int64>(input_gradient_tensor.dim_size(2));
+ const int64 depth = static_cast<int64>(input_gradient_tensor.dim_size(3));
+ const auto nodes = cols * rows;
+
+ auto grads_shaped =
+ input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});
+
+ auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
+ auto activations =
+ orig_output_tensor.shaped<T, 2>({nodes * batch, depth});
+
+ Tensor* output_dnn_data;
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ mkl_output_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, kIdxOutput,
+ &output_dnn_data,
+ input_gradient_tensor.shape(),
+ mkl_output_mkl_shape);
+
+ auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
+ out_shaped.setZero();
+ auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
+ depth](int64 begin, int64 end) {
+ for (int64 i = begin; i < end; ++i) {
+ for (int64 j = 0; j < depth; ++j) {
+ int64 depth_begin = std::max<int64>(0, j - depth_radius_);
+ int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
+
+ T norm(0);
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ norm += in_shaped(i, k) * in_shaped(i, k);
+ }
+ norm = alpha_ * norm + bias_;
+ DCHECK_GT(norm, T(1e-6));
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
+ activations(i, j) / norm;
+ if (k == j) {
+ dyi += Eigen::numext::pow(norm, -beta_);
+ }
+ dyi *= grads_shaped(i, j);
+ const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
+ dyi;
+ }
+ }
+ }
+ };
+ auto worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
+ depth * depth, shard);
+ }
+
+ void SanityCheckInputs(OpKernelContext* context) {
+ const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
+ const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+ const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+ const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
+ MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
+ workspace_dnn_shape;
+ GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
+ GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
+ GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
+ GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
+ if (in_grads_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("Input gradient must be "
+ "4-dimensional"));
+ } else {
+ OP_REQUIRES(context, input_gradient_tensor.dims() == 4,
+ errors::InvalidArgument("input gradient must be 4-dimensional"));
+ }
+
+ if (in_image_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("input images must be "
+ "4-dimensional"));
+ } else {
+ OP_REQUIRES(context, orig_input_tensor.dims() == 4,
+ errors::InvalidArgument("input images must be "
+ "4-dimensional"));
+ }
+
+ if (out_image_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("Output image must be "
+ "4-dimensional"));
+ } else {
+ OP_REQUIRES(context, orig_output_tensor.dims() == 4,
+ errors::InvalidArgument("Output image must be 4-dimensional"));
+ }
+
+ if (workspace_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false,
+ errors::InvalidArgument("Workspace should not be MKL Tensor."));
+ } else {
+ OP_REQUIRES(context, workspace_tensor.dims() == 1,
+ errors::InvalidArgument("Workspace must be 1-dimensional"));
+ }
+ }
+
+// Input("input_grads: T")
+// Input("input_image: T")
+// Input("output_image: T")
+// Input("workspace: uint8")
+ const int kIdxGradient = 0,
+ kIdxOrigInput = 1,
+ kIdxOrigOutput = 2,
+ kIdxWorkspace = 3,
+ kIdxOutput = 0;
+
+ typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+ bool workspace_enabled_;
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+#endif // INTEL_MKL_DNN
+
#define REGISTER_MKL_LRN_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklLRN") \
.Device(DEVICE_CPU) \
@@ -729,6 +1385,7 @@ class MklLRNGradOp : public OpKernel {
.Label(mkl_op_registry::kMklOpLabel), \
MklLRNGradOp<T>);
+
TF_CALL_float(REGISTER_MKL_LRN_CPU);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 45bdd0ad5c..dc899d8c7e 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -500,30 +500,81 @@ class MklReluGradOpBase : public OpKernel {
// Set DNN primitives for src & diff_dst
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
- if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
- if (dnn_shape_diff_dst.IsMklTensor()) {
- diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
- src_md = diff_dst_md;
- } else {
- src_md = dnn_shape_src.GetMklLayout();
- diff_dst_md = src_md;
- }
- } else {
+
+ // For creating Sum primitive, we need to ensure that all inputs are in
+ // same format. What that means is if we have a mixed input case - where
+ // one input is in Tensorflow format and one input is in MKL format -,
+ // then we need to ensure that all inputs are in same format for
+ // primitive construction. For performance reason, we say that all inputs
+ // are in MKL format in such case, and insert reorder for input that is
+ // in Tensorflow format into MKL format. On the other hand, if both the
+ // inputs are in MKL format or both are in Tensorflow format, then we
+ // dont need reorder.
+ if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
+ // If both the inputs are in Tensorflow format, we create blocked memory
+ // descriptor.
auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
diff_dst_md = src_md;
+ } else if (dnn_shape_src.IsMklTensor() &&
+ !dnn_shape_diff_dst.IsMklTensor()) {
+ // If one input is in MKL format and other is in Tensorflow, then
+ // create respective descriptors describing the actual case. For input
+ // in Mkl format, we just get Mkl layout from MklDnnShape. For input in
+ // Tensorflow format, we create memory descriptor using data format.
+ src_md = dnn_shape_src.GetMklLayout();
+
+ memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
+ auto src_tf_data_format = MklDnnDataFormatToTFDataFormat(
+ src_mkl_data_format);
+ auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
+ src_tf_data_format);
+ diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
+ src_mkl_data_format);
+ } else if (!dnn_shape_src.IsMklTensor() &&
+ dnn_shape_diff_dst.IsMklTensor()) {
+ // Same comment as above.
+ diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+
+ memory::format diff_dst_mkl_data_format =
+ dnn_shape_diff_dst.GetTfDataFormat();
+ auto diff_dst_tf_data_format = MklDnnDataFormatToTFDataFormat(
+ diff_dst_mkl_data_format);
+ auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
+ diff_dst_tf_data_format);
+ src_md = memory::desc(src_dims, MklDnnType<T>(),
+ diff_dst_mkl_data_format);
+ } else {
+ // If both the inputs are in MKL format, we use Mkl layout of the input
+ // tensors.
+ src_md = dnn_shape_src.GetMklLayout();
+ diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
}
+
src.SetUsrMem(src_md, &src_tensor);
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ // As per comment above, we tell MKLDNN that both the inputs are in same
+ // format. So we set common memory descriptor in MKL format, if any of the
+ // inputs are in MKL format. Let's get memory descriptor that we will use
+ // for both the inputs.
+ memory::desc common_md({}, memory::data_undef, memory::format_undef);
+ if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
+ common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
+ } else {
+ // Since both the inputs are in Tensorflow format, and have
+ // same shape, we can get memory descriptor from any input.
+ common_md = src_md;
+ }
+
T alpha = 0, beta = 0;
std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
alg_kind, src_md, alpha, beta);
relu_fwd_pd.reset(new relu_forward::primitive_desc(relu_fwd_desc,
cpu_engine));
- auto relu_bwd_desc = relu_backward::desc(alg_kind, diff_dst_md, src_md,
+ auto relu_bwd_desc = relu_backward::desc(alg_kind, common_md, common_md,
alpha, beta);
auto relu_bwd_pd = relu_backward::primitive_desc(relu_bwd_desc,
cpu_engine, *relu_fwd_pd);
@@ -547,9 +598,9 @@ class MklReluGradOpBase : public OpKernel {
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
tf_shape_diff_src, dnn_shape_diff_src);
- // diff_src memory descriptor is same as diff_dst memory descriptor.
- auto diff_src_md = diff_dst_md;
- diff_src.SetUsrMem(diff_src_md, diff_src_tensor);
+ // diff_src memory descriptor is same as memory descriptor for both
+ // inputs.
+ diff_src.SetUsrMem(common_md, diff_src_tensor);
PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
} catch (mkldnn::error &e) {
@@ -567,6 +618,14 @@ class MklReluGradOpBase : public OpKernel {
MklDnnData<T>* src, MklDnnData<T>* diff_src, MklDnnData<T>*
diff_dst) {
std::vector<primitive> net;
+
+ // Check if we need to reorder original input tensors into common_md layout
+ // that we set for primitive creation. diff_src_primitive_desc is same as
+ // common_md.
+ src->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(), &net);
+ diff_dst->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(),
+ &net);
+
net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
diff_dst->GetOpMem(), diff_src->GetOpMem()));
stream(stream::kind::eager).submit(net).wait();
@@ -622,7 +681,6 @@ class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
MklDnnShape dnn_shape_diff_dst;
GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
- int src_dims_size = src_tensor.dims();
MklDnnShape dnn_shape_diff_src;
dnn_shape_diff_src.SetMklTensor(false);
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
@@ -690,7 +748,6 @@ class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
MklDnnShape dnn_shape_diff_dst;
GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
- int src_dims_size = src_tensor.dims();
MklDnnShape dnn_shape_diff_src;
dnn_shape_diff_src.SetMklTensor(false);
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
@@ -762,7 +819,6 @@ class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
MklDnnShape dnn_shape_diff_dst;
GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
- int src_dims_size = src_tensor.dims();
MklDnnShape dnn_shape_diff_src;
dnn_shape_diff_src.SetMklTensor(false);
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc
index 7ac34d1c62..8f82784d93 100644
--- a/tensorflow/core/kernels/reverse_op.cc
+++ b/tensorflow/core/kernels/reverse_op.cc
@@ -182,9 +182,9 @@ class ReverseOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
-#define HANDLE_REVERSE(NDIMS) \
- case NDIMS: \
- HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
+#define HANDLE_REVERSE(NDIMS) \
+ case NDIMS: \
+ HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
return;
switch (input_dims) {
@@ -228,7 +228,7 @@ void HandleReverseV2Case(OpKernelContext* context,
result->tensor<T, NDIMS>());
}
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tidx>
class ReverseV2Op : public OpKernel {
public:
explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}
@@ -242,15 +242,15 @@ class ReverseV2Op : public OpKernel {
} else {
const int input_dims = input.dims();
const TensorShape& sparse_dims_shape = sparse_dims.shape();
- const auto& axes_sparse_flat = sparse_dims.flat<int32>();
+ const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();
OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
errors::InvalidArgument("'dims' must be 1-dimension, not ",
sparse_dims.dims()));
gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
- int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat(dummy));
- int32 canonical_axis = axis < 0 ? input_dims + axis : axis;
+ Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
+ Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims,
errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
" is out of valid range [", 0, ", ",
@@ -306,7 +306,13 @@ class ReverseV2Op : public OpKernel {
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<CPUDevice, T>)
+ ReverseV2Op<CPUDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<CPUDevice, T, int64>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_string(REGISTER_KERNELS);
#undef REGISTER_KERNELS
@@ -358,7 +364,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<GPUDevice, T>)
+ ReverseV2Op<GPUDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<GPUDevice, T, int64>)
TF_CALL_uint8(REGISTER_GPU_KERNELS);
TF_CALL_int8(REGISTER_GPU_KERNELS);
// TODO decide whether we want to enable the bool kernel.
@@ -387,7 +399,15 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.HostMemory("tensor")
.HostMemory("axis")
.HostMemory("output"),
- ReverseV2Op<CPUDevice, int32>);
+ ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32, int64>);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
@@ -402,7 +422,13 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tidx") \
.HostMemory("axis"), \
- ReverseV2Op<SYCLDevice, T>)
+ ReverseV2Op<SYCLDevice, T, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("axis"), \
+ ReverseV2Op<SYCLDevice, T, int64>)
TF_CALL_uint8(REGISTER_SYCL_KERNELS);
TF_CALL_int8(REGISTER_SYCL_KERNELS);
TF_CALL_float(REGISTER_SYCL_KERNELS);
@@ -422,6 +448,14 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2")
.HostMemory("tensor")
.HostMemory("axis")
.HostMemory("output"),
- ReverseV2Op<CPUDevice, int32>);
-#endif // TENSORFLOW_USE_SYCL
+ ReverseV2Op<CPUDevice, int32, int32>);
+REGISTER_KERNEL_BUILDER(Name("ReverseV2")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("tensor")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ReverseV2Op<CPUDevice, int32, int64>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc
index 5a2b18b41c..e836c764ac 100644
--- a/tensorflow/core/kernels/set_kernels.cc
+++ b/tensorflow/core/kernels/set_kernels.cc
@@ -216,7 +216,7 @@ void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
result->clear();
auto input_flat = input_tensor.flat<T>();
const auto start = std::inner_product(
- group_indices.begin(), group_indices.end(), input_strides.begin(), 0L);
+ group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
const TensorShape& input_shape = input_tensor.shape();
const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
for (int64 i = start; i < end; ++i) {
@@ -279,7 +279,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
const auto group_key = group.group();
const auto output_index = std::inner_product(
- group_key.begin(), group_key.end(), output_strides.begin(), 0L);
+ group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
out(output_index) = group_set.size();
}
}
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 43f17df898..a9e31cc336 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -273,6 +273,7 @@ class MklSliceOp : public OpKernel {
HANDLE_DIM(1);
HANDLE_DIM(2);
HANDLE_DIM(3);
+ HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
HANDLE_DIM(7);
diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc
index d583e4e6bb..70dbd15c46 100644
--- a/tensorflow/core/kernels/string_to_number_op.cc
+++ b/tensorflow/core/kernels/string_to_number_op.cc
@@ -49,43 +49,15 @@ class StringToNumberOp : public OpKernel {
auto output_flat = output_tensor->flat<OutputType>();
for (int i = 0; i < input_flat.size(); ++i) {
- Convert(input_flat(i), &output_flat(i), context);
+ OP_REQUIRES(
+ context,
+ strings::SafeStringToNumeric<OutputType>(input_flat(i).c_str(),
+ &output_flat(i)),
+ errors::InvalidArgument(kErrorMessage, input_flat(i).c_str()));
}
}
-
- private:
- void Convert(const string& s, OutputType* output_data,
- OpKernelContext* context);
};
-template <>
-void StringToNumberOp<float>::Convert(const string& s, float* output_data,
- OpKernelContext* context) {
- OP_REQUIRES(context, strings::safe_strtof(s.c_str(), output_data),
- errors::InvalidArgument(kErrorMessage, s));
-}
-
-template <>
-void StringToNumberOp<double>::Convert(const string& s, double* output_data,
- OpKernelContext* context) {
- OP_REQUIRES(context, strings::safe_strtod(s.c_str(), output_data),
- errors::InvalidArgument(kErrorMessage, s));
-}
-
-template <>
-void StringToNumberOp<int32>::Convert(const string& s, int32* output_data,
- OpKernelContext* context) {
- OP_REQUIRES(context, strings::safe_strto32(s, output_data),
- errors::InvalidArgument(kErrorMessage, s));
-}
-
-template <>
-void StringToNumberOp<int64>::Convert(const string& s, int64* output_data,
- OpKernelContext* context) {
- OP_REQUIRES(context, strings::safe_strto64(s, output_data),
- errors::InvalidArgument(kErrorMessage, s));
-}
-
// Registers the currently supported output types.
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER(Name("StringToNumber") \
diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc
index 42d1365e64..de0c66ad23 100644
--- a/tensorflow/core/kernels/where_op.cc
+++ b/tensorflow/core/kernels/where_op.cc
@@ -55,14 +55,14 @@ namespace functor {
namespace {
template <typename T>
int64 CountAccumulator(const T* begin, const T* end) {
- return std::accumulate(begin, end, 0L, [](int64 accum, const T& val) {
+ return std::accumulate(begin, end, 0LL, [](int64 accum, const T& val) {
return accum + (val != T(0));
});
}
template <>
int64 CountAccumulator<bool>(const bool* begin, const bool* end) {
- return std::accumulate(begin, end, 0L);
+ return std::accumulate(begin, end, 0LL);
}
} // namespace
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
index ca088f9988..90d0dba4a7 100644
--- a/tensorflow/core/lib/random/random_distributions_test.cc
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <math.h>
#include <algorithm>
#include <functional>
+#include <numeric>
#include <unordered_map>
#include <vector>
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index 31b6abbac6..3c45b90274 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -122,6 +122,38 @@ bool safe_strtof(const char* str, float* value);
// Values may be rounded on over- and underflow.
bool safe_strtod(const char* str, double* value);
+inline bool ProtoParseNumeric(StringPiece s, int32* value) {
+ return safe_strto32(s, value);
+}
+
+inline bool ProtoParseNumeric(StringPiece s, uint32* value) {
+ return safe_strtou32(s, value);
+}
+
+inline bool ProtoParseNumeric(StringPiece s, int64* value) {
+ return safe_strto64(s, value);
+}
+
+inline bool ProtoParseNumeric(StringPiece s, uint64* value) {
+ return safe_strtou64(s, value);
+}
+
+inline bool ProtoParseNumeric(StringPiece s, float* value) {
+ return safe_strtof(s.ToString().c_str(), value);
+}
+
+inline bool ProtoParseNumeric(StringPiece s, double* value) {
+ return safe_strtod(s.ToString().c_str(), value);
+}
+
+// Convert strings to number of type T.
+// Leading and trailing spaces are allowed.
+// Values may be rounded on over- and underflow.
+template <typename T>
+bool SafeStringToNumeric(StringPiece s, T* value) {
+ return ProtoParseNumeric(s, value);
+}
+
// Converts from an int64 to a human readable string representing the
// same number, using decimal powers. e.g. 1200000 -> "1.20M".
string HumanReadableNum(int64 value);
diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h
index 3d0c6e4a37..ed6d0af010 100644
--- a/tensorflow/core/lib/strings/proto_text_util.h
+++ b/tensorflow/core/lib/strings/proto_text_util.h
@@ -118,30 +118,6 @@ class ProtoTextOutput {
TF_DISALLOW_COPY_AND_ASSIGN(ProtoTextOutput);
};
-inline bool ProtoParseNumeric(StringPiece s, int32* value) {
- return ::tensorflow::strings::safe_strto32(s, value);
-}
-
-inline bool ProtoParseNumeric(StringPiece s, uint32* value) {
- return ::tensorflow::strings::safe_strtou32(s, value);
-}
-
-inline bool ProtoParseNumeric(StringPiece s, int64* value) {
- return ::tensorflow::strings::safe_strto64(s, value);
-}
-
-inline bool ProtoParseNumeric(StringPiece s, uint64* value) {
- return ::tensorflow::strings::safe_strtou64(s, value);
-}
-
-inline bool ProtoParseNumeric(StringPiece s, float* value) {
- return ::tensorflow::strings::safe_strtof(s.ToString().c_str(), value);
-}
-
-inline bool ProtoParseNumeric(StringPiece s, double* value) {
- return ::tensorflow::strings::safe_strtod(s.ToString().c_str(), value);
-}
-
inline void ProtoSpaceAndComments(Scanner* scanner) {
for (;;) {
scanner->AnySpace();
@@ -174,7 +150,7 @@ bool ProtoParseNumericFromScanner(Scanner* scanner, T* value) {
}
ProtoSpaceAndComments(scanner);
- return ProtoParseNumeric(numeric_str, value);
+ return SafeStringToNumeric<T>(numeric_str, value);
}
// Parse the next boolean value from <scanner>, returning false if parsing
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index 8cea0f0718..44c52850fa 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -83,7 +83,7 @@ string Uppercase(StringPiece s);
// Converts "^2ILoveYou!" to "i_love_you_". More specifically:
// - converts all non-alphanumeric characters to underscores
-// - replaces each occurence of a capital letter (except the very
+// - replaces each occurrence of a capital letter (except the very
// first character and if there is already an '_' before it) with '_'
// followed by this letter in lower case
// - Skips leading non-alpha characters
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index b4c6096eb4..13762cc221 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -884,7 +884,7 @@ For example,
# Draw the bounding box in an image summary.
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox_for_draw)
- tf.image_summary('images_with_box', image_with_box)
+ tf.summary.image('images_with_box', image_with_box)
# Employ the bounding box to distort the image.
distorted_image = tf.slice(image, begin, size)
@@ -976,7 +976,7 @@ For example,
# Draw the bounding box in an image summary.
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox_for_draw)
- tf.image_summary('images_with_box', image_with_box)
+ tf.summary.image('images_with_box', image_with_box)
# Employ the bounding box to distort the image.
distorted_image = tf.slice(image, begin, size)
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d2dfe23888..8ad2c06741 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -3361,7 +3361,11 @@ REGISTER_OP("_MklLRN")
.Input("input: T")
.Input("mkl_input: uint8")
.Output("output: T")
+#ifndef INTEL_MKL_DNN
.Output("workspace: T")
+#else
+ .Output("workspace: uint8")
+#endif
.Output("mkl_output: uint8")
.Output("mkl_workspace: uint8")
.Attr("depth_radius: int = 5")
@@ -3385,7 +3389,11 @@ REGISTER_OP("_MklLRNGrad")
.Input("input_grads: T")
.Input("input_image: T")
.Input("output_image: T")
+#ifndef INTEL_MKL_DNN
.Input("workspace: T")
+#else
+ .Input("workspace: uint8")
+#endif
.Input("mkl_input_grads: uint8")
.Input("mkl_input_image: uint8")
.Input("mkl_output_image: uint8")
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index aaeccc8324..6b6be757f6 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -11,6 +11,7 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
+ "if_windows",
)
filegroup(
@@ -261,6 +262,7 @@ tf_cc_test(
name = "gcs_dns_cache_test",
size = "small",
srcs = ["gcs_dns_cache_test.cc"],
+ linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]),
deps = [
":gcs_dns_cache",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 948334d27b..b357be8e63 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -3,6 +3,7 @@
load("@protobuf_archive//:protobuf.bzl", "proto_gen")
load("@protobuf_archive//:protobuf.bzl", "py_proto_library")
load("//tensorflow:tensorflow.bzl", "if_not_mobile")
+load("//tensorflow:tensorflow.bzl", "if_windows")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@@ -358,7 +359,9 @@ def tf_additional_proto_hdrs():
"platform/default/integral_types.h",
"platform/default/logging.h",
"platform/default/protobuf.h"
- ]
+ ] + if_windows([
+ "platform/windows/integral_types.h",
+ ])
def tf_additional_proto_srcs():
return [
diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h
index 93b82ecb7a..6308e58847 100644
--- a/tensorflow/core/platform/types.h
+++ b/tensorflow/core/platform/types.h
@@ -22,8 +22,10 @@ limitations under the License.
// Include appropriate platform-dependent implementations
#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES)
#include "tensorflow/core/platform/google/integral_types.h"
+#elif defined(PLATFORM_WINDOWS)
+#include "tensorflow/core/platform/windows/integral_types.h"
#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
- defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_WINDOWS)
+ defined(PLATFORM_GOOGLE_ANDROID)
#include "tensorflow/core/platform/default/integral_types.h"
#else
#error Define the appropriate PLATFORM_<foo> macro for this platform
diff --git a/tensorflow/core/platform/windows/integral_types.h b/tensorflow/core/platform/windows/integral_types.h
new file mode 100644
index 0000000000..4970b8ca6a
--- /dev/null
+++ b/tensorflow/core/platform/windows/integral_types.h
@@ -0,0 +1,25 @@
+ /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+#ifndef TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#define TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+
+#include "tensorflow/core/platform/default/integral_types.h"
+
+#include <cstddef>
+
+typedef std::ptrdiff_t ssize_t;
+
+#endif // TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index e816c282c8..f2401a0af4 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -616,7 +616,7 @@ SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
int index = 0;
for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
i++) {
- // The logic here is similiar as the above except that the above
+ // The logic here is similar as the above except that the above
// only count the number of indices while here we actually generate
// the output.
bool hit = true;
diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md
index f75f14dfb6..55c8216bb6 100644
--- a/tensorflow/docs_src/programmers_guide/datasets.md
+++ b/tensorflow/docs_src/programmers_guide/datasets.md
@@ -537,7 +537,7 @@ import cv2
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
- image_decoded = cv2.imread(image_string, cv2.IMREAD_GRAYSCALE)
+ image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
return image_decoded, label
# Use standard TensorFlow operations to resize the image to a fixed shape.
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 142e45a2e8..87cd95165e 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -120,7 +120,7 @@ def generate_batch(batch_size, num_skips, skip_window):
batch[i * num_skips + j] = buffer[skip_window]
labels[i * num_skips + j, 0] = buffer[context_word]
if data_index == len(data):
- buffer[:] = data[:span]
+ buffer.extend(data[0:span])
data_index = span
else:
buffer.append(data[data_index])
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
index fc914f86df..db6ae4f26c 100644
--- a/tensorflow/go/session.go
+++ b/tensorflow/go/session.go
@@ -65,6 +65,51 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
+// Device structure contains information about a device associated with a session, as returned by ListDevices()
+type Device struct {
+ Name, Type string
+ MemoryLimitBytes int64
+}
+
+// Return list of devices associated with a Session
+func (s *Session) ListDevices() ([]Device, error) {
+ var devices []Device
+
+ status := newStatus()
+ devices_list := C.TF_SessionListDevices(s.c, status.c)
+ if err := status.Err(); err != nil {
+ return nil, fmt.Errorf("SessionListDevices() failed: %v", err)
+ }
+ defer C.TF_DeleteDeviceList(devices_list)
+
+ for i := 0; i < int(C.TF_DeviceListCount(devices_list)); i++ {
+ device_name := C.TF_DeviceListName(devices_list, C.int(i), status.c)
+ if err := status.Err(); err != nil {
+ return nil, fmt.Errorf("DeviceListName(index=%d) failed: %v", i, err)
+ }
+
+ device_type := C.TF_DeviceListType(devices_list, C.int(i), status.c)
+ if err := status.Err(); err != nil {
+ return nil, fmt.Errorf("DeviceListType(index=%d) failed: %v", i, err)
+ }
+
+ memory_limit_bytes := C.TF_DeviceListMemoryBytes(devices_list, C.int(i), status.c)
+ if err := status.Err(); err != nil {
+ return nil, fmt.Errorf("DeviceListMemoryBytes(index=%d) failed: %v", i, err)
+ }
+
+ device := Device{
+ Name: C.GoString(device_name),
+ Type: C.GoString(device_type),
+ MemoryLimitBytes: int64(memory_limit_bytes),
+ }
+
+ devices = append(devices, device)
+ }
+
+ return devices, nil
+}
+
// Run the graph with the associated session starting with the supplied feeds
// to compute the value of the requested fetches. Runs, but does not return
// Tensors for operations specified in targets.
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index 73d78a8e57..05ace99a23 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -283,3 +283,19 @@ func TestSessionConfig(t *testing.T) {
t.Fatalf("Got %v, want -1", output[0].Value())
}
}
+
+func TestListDevices(t *testing.T) {
+ s, err := NewSession(NewGraph(), nil)
+ if err != nil {
+ t.Fatalf("NewSession(): %v", err)
+ }
+
+ devices, err := s.ListDevices()
+ if err != nil {
+ t.Fatalf("ListDevices(): %v", err)
+ }
+
+ if len(devices) == 0 {
+ t.Fatalf("no devices detected")
+ }
+}
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index af34aca3e3..bc9ddec2a5 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -263,6 +263,7 @@ _allowed_symbols.extend([
'GIT_VERSION',
'COMPILER_VERSION',
'CXX11_ABI_FLAG',
+ 'MONOLITHIC_BUILD',
])
# Remove all extra symbols that don't have a docstring or are not explicitly
@@ -282,6 +283,7 @@ _exported_dunders = set([
'__git_version__',
'__compiler_version__',
'__cxx11_abi_flag__',
+ '__monolithic_build__',
])
# Expose symbols minus dunders, unless they are whitelisted above.
diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py
index c85b22eb15..f193424133 100644
--- a/tensorflow/python/client/session_clusterspec_prop_test.py
+++ b/tensorflow/python/client/session_clusterspec_prop_test.py
@@ -77,7 +77,8 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'):
- const = constant_op.constant(17)
+ with ops.device('/cpu:0'):
+ const = constant_op.constant(17)
sess = session.Session(server1.target, config=config, graph=g)
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 2d24b86df3..3f1d63a543 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -100,6 +100,9 @@ tensorflow::ImportNumpy();
// _GLIBCXX_USE_CXX11_ABI flag value
%constant const int __cxx11_abi_flag__ = tf_cxx11_abi_flag();
+// Flag indicating whether the build is monolithic
+%constant const int __monolithic_build__ = tf_monolithic_build();
+
// Release the Python GIL for the duration of most methods.
%exception {
Py_BEGIN_ALLOW_THREADS;
diff --git a/tensorflow/python/debug/README.md b/tensorflow/python/debug/README.md
index b26411cd15..a2273b050b 100644
--- a/tensorflow/python/debug/README.md
+++ b/tensorflow/python/debug/README.md
@@ -28,7 +28,7 @@ models:
* Easy access through session wrappers
* Easy integration with common high-level APIs, such as
- [tf-learn](https://www.tensorflow.org/get_started/tflearn) and
+ [TensorFlow Estimators](https://www.tensorflow.org/programmers_guide/estimators) and
[Keras](https://keras.io/)
* Inspection of runtime tensor values and node connections
* Conditional breaking after runs that generate tensors satisfying given
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index f5a97eb197..cbe1a33ed0 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import re
import time
+import sys
import numpy as np
@@ -765,8 +766,12 @@ class FunctionTest(test.TestCase):
# We added more randomness to function names in C API.
# TODO(iga): Remove this if statement when we switch to C API.
if ops._USE_C_API: # pylint: disable=protected-access
- self.assertEqual("Foo_aCYSbwBkR5A",
- Foo.instantiate([dtypes.float32] * 3).name)
+ if sys.byteorder == 'big':
+ self.assertEqual("Foo_kEdkAG8SJvg",
+ Foo.instantiate([dtypes.float32] * 3).name)
+ else:
+ self.assertEqual("Foo_aCYSbwBkR5A",
+ Foo.instantiate([dtypes.float32] * 3).name)
else:
self.assertEqual("Foo_d643acf7",
Foo.instantiate([dtypes.float32] * 3).name)
diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py
index 81529e2b1e..f03b81eb28 100644
--- a/tensorflow/python/framework/versions.py
+++ b/tensorflow/python/framework/versions.py
@@ -25,11 +25,13 @@ __version__ = pywrap_tensorflow.__version__
__git_version__ = pywrap_tensorflow.__git_version__
__compiler_version__ = pywrap_tensorflow.__compiler_version__
__cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__
+__monolithic_build__ = pywrap_tensorflow.__monolithic_build__
VERSION = __version__
GIT_VERSION = __git_version__
COMPILER_VERSION = __compiler_version__
CXX11_ABI_FLAG = __cxx11_abi_flag__
+MONOLITHIC_BUILD = __monolithic_build__
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
GRAPH_DEF_VERSION_MIN_CONSUMER = (
@@ -42,6 +44,7 @@ __all__ = [
"__git_version__",
"__compiler_version__",
"__cxx11_abi_flag__",
+ "__monolithic_build__",
"COMPILER_VERSION",
"CXX11_ABI_FLAG",
"GIT_VERSION",
@@ -49,4 +52,5 @@ __all__ = [
"GRAPH_DEF_VERSION_MIN_CONSUMER",
"GRAPH_DEF_VERSION_MIN_PRODUCER",
"VERSION",
+ "MONOLITHIC_BUILD",
]
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index f5ada46eeb..d7403fe6ee 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2198,6 +2198,7 @@ cuda_py_test(
srcs = ["atrous_convolution_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_grad",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 17492e9255..1dbe7deb97 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -277,26 +277,34 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype)
for use_gpu in [False, True]:
- with self.test_session(use_gpu=use_gpu):
- x_tf = array_ops.reverse_v2(x_np, [0]).eval()
- self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
+ for axis_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = array_ops.reverse_v2(x_np,
+ constant_op.constant([0], dtype=axis_dtype)).eval()
+ self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
def _reverse2DimAuto(self, np_dtype):
x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype)
for reverse_f in [array_ops.reverse_v2, array_ops.reverse]:
for use_gpu in [False, True]:
- with self.test_session(use_gpu=use_gpu):
- x_tf_1 = reverse_f(x_np, [0]).eval()
- x_tf_2 = reverse_f(x_np, [-2]).eval()
- x_tf_3 = reverse_f(x_np, [1]).eval()
- x_tf_4 = reverse_f(x_np, [-1]).eval()
- x_tf_5 = reverse_f(x_np, [1, 0]).eval()
- self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
- self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
- self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
- self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
- self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])
+ for axis_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf_1 = reverse_f(x_np,
+ constant_op.constant([0], dtype=axis_dtype)).eval()
+ x_tf_2 = reverse_f(x_np,
+ constant_op.constant([-2], dtype=axis_dtype)).eval()
+ x_tf_3 = reverse_f(x_np,
+ constant_op.constant([1], dtype=axis_dtype)).eval()
+ x_tf_4 = reverse_f(x_np,
+ constant_op.constant([-1], dtype=axis_dtype)).eval()
+ x_tf_5 = reverse_f(x_np,
+ constant_op.constant([1, 0], dtype=axis_dtype)).eval()
+ self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
+ self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
+ self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
+ self.assertAllEqual(x_tf_4, np.asarray(x_np)[:, ::-1])
+ self.assertAllEqual(x_tf_5, np.asarray(x_np)[::-1, ::-1])
# This is the version of reverse that uses axis indices rather than
# bool tensors
diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py
index 3ac27d11c5..04248fb2ba 100644
--- a/tensorflow/python/kernel_tests/atrous_convolution_test.py
+++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -108,6 +109,18 @@ class AtrousConvolutionTest(test.TestCase):
add_check(check, y1, y2)
+ def test_unknown_spatial_dims_for_channel_last_format(self):
+ x = array_ops.placeholder(dtypes.float32, [1, None, None, 10])
+ w = array_ops.zeros([3, 3, 10, 20])
+ y = nn_ops.convolution(x, w, "VALID", dilation_rate=[2, 2], data_format="NHWC")
+ self.assertEqual(y.shape.as_list(), [1, None, None, 20])
+
+ def test_unknown_spatial_dims_for_channel_first_format(self):
+ x = array_ops.placeholder(dtypes.float32, [1, 10, None, None])
+ w = array_ops.zeros([3, 3, 10, 20])
+ y = nn_ops.convolution(x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW")
+ self.assertEqual(y.shape.as_list(), [1, 20, None, None])
+
@test_util.run_in_graph_and_eager_modes()
def testAtrousConvolution2D(self):
with self._delay_checks() as add_check:
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
index 7c18044c5c..9e51234605 100644
--- a/tensorflow/python/kernel_tests/bcast_ops_test.py
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops.gen_array_ops import _broadcast_args
from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args
from tensorflow.python.platform import test
@@ -135,6 +137,19 @@ class BcastOpsTest(test.TestCase):
self.assertAllEqual(r0, [0, 1, 3])
self.assertAllEqual(r1, [])
+ def testDataTypes(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ r = self._GetBroadcastShape(
+ constant_op.constant([2, 3, 5], dtype=dtype),
+ constant_op.constant([1], dtype=dtype))
+ self.assertAllEqual(r, [2, 3, 5])
+
+ r0, r1 = self._GetGradientArgs(
+ constant_op.constant([2, 3, 5], dtype=dtype),
+ constant_op.constant([1], dtype=dtype))
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1, 2])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index 44cd89022c..068860d5d4 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -26,7 +26,6 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-
class RecordInputOpTest(test.TestCase):
def generateTestData(self,
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 21561f3689..7506ab653b 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1478,7 +1478,7 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
# Draw the bounding box in an image summary.
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox_for_draw)
- tf.image_summary('images_with_box', image_with_box)
+ tf.summary.image('images_with_box', image_with_box)
# Employ the bounding box to distort the image.
distorted_image = tf.slice(image, begin, size)
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index b74971f654..4f24c3c5bf 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -157,6 +157,13 @@ def compute_weighted_loss(
ValueError: If `weights` is `None` or the shape is not compatible with
`losses`, or if the number of dimensions (rank) of either `losses` or
`weights` is missing.
+
+ Note:
+ When calculating the gradient of a weighted loss contributions from
+ both `losses` and `weights` are considered. If your `weights` depend
+ on some model parameters but you do not want this to affect the loss
+ gradient, you need to apply @{tf.stop_gradient} to `weights` before
+ passing them to `compute_weighted_loss`.
"""
Reduction.validate(reduction)
with ops.name_scope(scope, "weighted_loss", (losses, weights)):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index e04121ee31..25e1613a65 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -175,7 +175,7 @@ def _maybe_expand_labels(labels, predictions):
def _safe_div(numerator, denominator, name):
- """Divides two values, returning 0 if the denominator is <= 0.
+ """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
Args:
numerator: A real `Tensor`.
@@ -185,11 +185,11 @@ def _safe_div(numerator, denominator, name):
Returns:
0 if `denominator` <= 0, else `numerator` / `denominator`
"""
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.truediv(numerator, denominator),
- 0,
- name=name)
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero, name=name)
def _safe_scalar_div(numerator, denominator, name):
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a563e7c588..8c1083d9cc 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -452,6 +452,7 @@ class _WithSpaceToBatch(object):
self.input_shape = input_shape
self.spatial_dims = spatial_dims
self.dilation_rate = dilation_rate
+ self.data_format = data_format
self.op = build_op(num_spatial_dims, "VALID")
self.call = self._with_space_to_batch_call
@@ -496,6 +497,14 @@ class _WithSpaceToBatch(object):
result_converted = array_ops.batch_to_space_nd(
input=result, block_shape=dilation_rate, crops=crops)
+
+ # Recover channel information for output shape if channels are not last.
+ if self.data_format is not None and self.data_format.startswith("NC"):
+ if not result_converted.shape[1].value:
+ output_shape = result_converted.shape.as_list()
+ output_shape[1] = filter.shape[-1]
+ result_converted.set_shape(output_shape)
+
return result_converted
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
@@ -823,7 +832,8 @@ class Convolution(object):
padding=padding,
build_op=self._build_op,
filter_shape=filter_shape,
- spatial_dims=spatial_dims)
+ spatial_dims=spatial_dims,
+ data_format=data_format)
def _build_op(self, _, padding):
return _NonAtrousConvolution(
diff --git a/tensorflow/python/platform/sysconfig.py b/tensorflow/python/platform/sysconfig.py
index 57635fb4d9..f6c4f2227f 100644
--- a/tensorflow/python/platform/sysconfig.py
+++ b/tensorflow/python/platform/sysconfig.py
@@ -27,6 +27,7 @@ from __future__ import print_function
import os.path as _os_path
from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG
+from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD
from tensorflow.python.util.all_util import remove_undocumented
@@ -75,8 +76,9 @@ def get_link_flags():
The link flags.
"""
flags = []
- flags.append('-L%s' % get_lib())
- flags.append('-ltensorflow_framework')
+ if not _MONOLITHIC_BUILD:
+ flags.append('-L%s' % get_lib())
+ flags.append('-ltensorflow_framework')
return flags
_allowed_symbols = []
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
index 91373fa544..5c0c5783dc 100644
--- a/tensorflow/python/pywrap_tensorflow.py
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -60,6 +60,7 @@ try:
from tensorflow.python.pywrap_tensorflow_internal import __git_version__
from tensorflow.python.pywrap_tensorflow_internal import __compiler_version__
from tensorflow.python.pywrap_tensorflow_internal import __cxx11_abi_flag__
+ from tensorflow.python.pywrap_tensorflow_internal import __monolithic_build__
if _use_dlopen_global_flags:
pywrap_dlopen_global_flags.reset_dlopen_flags()
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index cb2b06d47c..44a3a745ad 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include <assert.h>
#include <complex>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
@@ -268,6 +269,11 @@ PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
#endif
+#if CUDA_VERSION >= 9000
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode)
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
+#endif
+
} // namespace wrap
static string ToString(cublasStatus_t status) {
@@ -299,6 +305,18 @@ static string ToString(cublasStatus_t status) {
}
}
+// Decide whether to enable TENSOR_OP_MATH
+static bool TensorOpMathEnabled() {
+ static bool is_enabled = [] {
+ bool is_disabled;
+ TF_CHECK_OK(
+ tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUBLAS_TENSOR_OP_MATH",
+ /*default_val=*/false, &is_disabled));
+ return !is_disabled;
+ }();
+ return is_enabled;
+}
+
// cuBLAS has interfaces that permit pointers to be passed from either the host
// memory space or the device memory space; however, you must instruct it as to
// which address space those pointers are in with cublasSetPointerMode.
@@ -360,6 +378,65 @@ class ScopedCublasPointerMode {
bool ok_; // Whether the change was successful.
};
+#if CUDA_VERSION >= 9000
+// cuBLAS has interfaces that permit computations to use the Volta hardware.
+// This must be enabled via the cublasGet/SetMathMode APIs.
+//
+// This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
+// you are about to perform in a given scope.
+//
+// The prior cuBLAS math mode is retained and restored when this object goes
+// out of scope.
+class ScopedCublasMathMode {
+ public:
+ // Note that, because the setting of the cublas math mode is fallible,
+ // construction of this scoped datatype must be paired with a call to
+ // Init().
+ //
+ // Parameters:
+ // handle: The cublas library handle to act upon in setting the math mode.
+ explicit ScopedCublasMathMode(CUDAExecutor *parent, cublasHandle_t handle)
+ : parent_(parent), handle_(handle), ok_(false) {}
+
+ // Attempts the switch to the requested scoped math mode, new_mode.
+ //
+ // Note that when false is returned, an appropriate error has already been
+ // logged.
+ bool Init(cublasMath_t new_mode) {
+ cublasStatus_t ret = wrap::cublasGetMathMode(parent_, handle_, &old_mode_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
+ return ok_ = false;
+ }
+
+ ret = wrap::cublasSetMathMode(parent_, handle_, new_mode);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
+ return ok_ = false;
+ }
+ return ok_ = true;
+ }
+
+ // Switches back to the prior math mode, if the switch operation was
+ // successful in the first place.
+ ~ScopedCublasMathMode() {
+ if (ok_) {
+ cublasStatus_t ret = wrap::cublasSetMathMode(parent_, handle_, old_mode_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set former cublas math mode: "
+ << ToString(ret);
+ }
+ }
+ }
+
+ private:
+ CUDAExecutor *parent_; // Executor establishing this math mode for.
+ cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
+ cublasMath_t old_mode_; // Prior cuBLAS math mode, to be restored.
+ bool ok_; // Whether the change was successful.
+};
+#endif // CUDA_VERSION >= 9000
+
bool CUDABlas::Init() {
cublasStatus_t ret = wrap::cublasCreate(parent_, &blas_);
if (ret != CUBLAS_STATUS_SUCCESS) {
@@ -532,7 +609,7 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
- Args... args) {
+ bool use_tensor_op_math, Args... args) {
mutex_lock lock{mu_};
CHECK(blas_ != nullptr);
@@ -545,7 +622,14 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
: CUBLAS_POINTER_MODE_DEVICE)) {
return false;
}
-
+#if CUDA_VERSION >= 9000
+ ScopedCublasMathMode math_mode{parent_, blas_};
+ if (use_tensor_op_math) {
+ if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
+ return false;
+ }
+ }
+#endif
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
@@ -1762,14 +1846,26 @@ bool CUDABlas::DoBlasGemm(
"precondition violation";
}
}
- // TODO(sesse): Consider supporting the Hgemm interface, which uses half
- // calculations internally (faster on newer devices, such as Pascal and TX1,
- // but less precise).
- return DoBlasInternal(
+
+ bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+ int cc_major, cc_minor;
+ stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor);
+
+ // GPUs < sm_70 don't support Volta hardware.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#endif
+
+ return DoBlasInternalImpl(
wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
- CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
- CUDAMemory(a), SE_CUDA_DATA_HALF, lda, CUDAMemory(b), SE_CUDA_DATA_HALF,
- ldb, &beta, CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
+ true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, CUDAMemory(a),
+ SE_CUDA_DATA_HALF, lda, CUDAMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
+ CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
+
#else
LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
<< "(need at least CUDA 7.5)";
@@ -2031,6 +2127,26 @@ bool CUDABlas::DoBlasGemmWithProfilingImpl(
return result;
}
+static bool UsesTensorOps(blas::AlgorithmType algo) {
+#if CUDA_VERSION >= 9000
+ cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
+ return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+#else
+ return false;
+#endif
+}
+
+template <typename InType>
+static bool TensorOpsAvailable(int cc_major) {
+#if CUDA_VERSION >= 9000
+ if (cc_major >= 7 && TensorOpMathEnabled() &&
+ std::is_same<InType, Eigen::half>::value) {
+ return true;
+ }
+#endif
+ return false;
+}
+
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@@ -2049,6 +2165,10 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
return false;
}
+ if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
+ return false;
+ }
+
struct TimerDeleter {
void operator()(CUDATimer *t) {
t->Destroy();
@@ -2098,10 +2218,19 @@ bool CUDABlas::GetBlasGemmAlgorithms(
// still return the out_algorithms. Caller needs to make sure that in this case,
// the returned vector is empty.
#if CUDA_VERSION >= 8000
- for (cublasGemmAlgo_t algo :
- {CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) {
+ for (cublasGemmAlgo_t algo : {
+ CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+#if CUDA_VERSION >= 9000
+ CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP
+#endif
+ }) {
out_algorithms->push_back(algo);
}
#endif
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 80cda97117..deb211c04b 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -84,7 +84,7 @@ class CUDABlas : public blas::BlasSupport {
template <typename FuncT, typename... Args>
bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
- Args... args);
+ bool use_tensor_op_math, Args... args);
// Convenience functions that call DoBlasInternalImpl with different values
// for err_on_failure.
@@ -92,13 +92,17 @@ class CUDABlas : public blas::BlasSupport {
bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
Args... args) {
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
- /*err_on_failure=*/true, args...);
+ /*err_on_failure=*/true, /*use_tensor_ops=*/false,
+ args...);
}
template <typename FuncT, typename... Args>
bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, Args... args) {
+ // Tensor ops are hard-coded off in this path, but can still be enabled with
+ // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
- /*err_on_failure=*/false, args...);
+ /*err_on_failure=*/false,
+ /*use_tensor_ops=*/false, args...);
}
// A helper function to implement DoBlasGemmBatched interfaces for generic
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 5519381d51..384445e6c1 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -559,10 +559,11 @@ class ScopedFilterDescriptor {
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
static bool TensorOpMathEnabled() {
static bool is_enabled = [] {
- bool ret;
- TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DISABLE_TENSOR_OP_MATH",
- /*default_val=*/false, &ret));
- return !ret;
+ bool is_disabled;
+ TF_CHECK_OK(
+ tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH",
+ /*default_val=*/false, &is_disabled));
+ return !is_disabled;
}();
return is_enabled;
}
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 82675a91c5..2135f6dd01 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -196,6 +196,10 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
+ if_linux_x86_64(["-msse3"])
+ if_ios_x86_64(["-msse4.1"])
+ select({
+ "//tensorflow:framework_shared_object": [],
+ "//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"],
+ })
+ + select({
clean_dep("//tensorflow:android"): android_copts,
clean_dep("//tensorflow:darwin"): [],
clean_dep("//tensorflow:windows"): get_win_copts(is_external),
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index b6f9414571..35917e94ad 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -125,6 +125,10 @@ tf_module {
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
+ name: "MONOLITHIC_BUILD"
+ mtype: "<type \'int\'>"
+ }
+ member {
name: "MetaGraphDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index 0c9f3bb5b3..dee98a027e 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -42,7 +42,6 @@ broken_cpu_cc_tests="\
//tensorflow/core/platform/cloud:gcs_file_system_test + \
//tensorflow/core/kernels/cloud:bigquery_table_accessor_test + \
//tensorflow/core/kernels/hexagon:graph_transferer_test + \
- //tensorflow/core/kernels/hexagon:quantized_matmul_op_for_hexagon_test + \
//tensorflow/core/kernels:remote_fused_graph_execute_utils_test + \
//tensorflow/core/kernels:requantize_op_test + \
//tensorflow/core/kernels:requantization_range_op_test + \
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index 3630dbd740..f2845c877f 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -16,10 +16,7 @@
"""Help include git hash in tensorflow bazel build.
This creates symlinks from the internal git repository directory so
-that the build system can see changes in the version state. We also
-remember what branch git was on so when the branch changes we can
-detect that the ref file is no longer correct (so we can suggest users
-run ./configure again).
+that the build system can see changes in the version state.
NOTE: this script is only used in opensource.
@@ -221,13 +218,14 @@ def generate(arglist):
if not data["git"]:
git_version = b"unknown"
else:
- old_branch = data["branch"]
+ old_branch = data["branch"]
new_branch = parse_branch_ref(head_symlink)
if new_branch != old_branch:
- raise RuntimeError(
- "Run ./configure again, branch was '%s' but is now '%s'" %
- (old_branch, new_branch))
- git_version = get_git_version(data["path"])
+ print("Warning, run ./configure again, to get __git_version__ to record "
+ "correct version")
+ git_version = get_git_version(data["path"])+'-inconsistent-git-version'
+ else:
+ git_version = get_git_version(data["path"])
write_version_info(dest_file, git_version)
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 58489b28c8..b5465b7fb3 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -316,3 +316,14 @@ tf_py_test(
],
main = "python/transform_graph_test.py",
)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)