aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-04-22 06:08:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-22 07:28:38 -0700
commit326942394e69074d50d5889218a24c9371eff259 (patch)
tree50c78852c36b828440761a16650718f224560f7b
parent3c0900a49c11b7975c7accc026153bbc2001c018 (diff)
Merge changes from github.
Change: 153925676
-rw-r--r--README.md15
-rw-r--r--RELEASE.md3
-rwxr-xr-xconfigure22
-rw-r--r--tensorflow/BUILD24
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl1
-rw-r--r--tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc2
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/mnist.py31
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py14
-rwxr-xr-xtensorflow/contrib/makefile/compile_linux_protobuf.sh2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py4
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py2
-rw-r--r--tensorflow/contrib/slim/README.md2
-rw-r--r--tensorflow/contrib/verbs/BUILD168
-rw-r--r--tensorflow/contrib/verbs/README.md77
-rw-r--r--tensorflow/contrib/verbs/design_diagram.pngbin0 -> 13625 bytes
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_client.cc47
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_client.h50
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.cc165
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.h72
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.cc68
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.h89
-rw-r--r--tensorflow/contrib/verbs/rdma.cc874
-rw-r--r--tensorflow/contrib/verbs/rdma.h277
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc133
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.h54
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc149
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.h64
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc172
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.h66
-rw-r--r--tensorflow/contrib/verbs/verbs_service.proto60
-rw-r--r--tensorflow/contrib/verbs/verbs_util.cc61
-rw-r--r--tensorflow/contrib/verbs/verbs_util.h41
-rw-r--r--tensorflow/core/BUILD14
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc26
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h18
-rw-r--r--tensorflow/core/framework/function_testlib.cc44
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc1275
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc561
-rw-r--r--tensorflow/core/graph/mkl_optimizer_merge.cc651
-rw-r--r--tensorflow/core/graph/mkl_optimizer_merge.h36
-rw-r--r--tensorflow/core/graph/mkl_optimizer_merge_test.cc470
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc55
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc153
-rw-r--r--tensorflow/core/grappler/optimizers/auto_parallel.cc2
-rw-r--r--tensorflow/core/kernels/BUILD32
-rw-r--r--tensorflow/core/kernels/fixed_length_record_reader_op.cc40
-rw-r--r--tensorflow/core/kernels/mkl_avgpooling_op.cc42
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc458
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc12
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc12
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc12
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc24
-rw-r--r--tensorflow/core/kernels/mkl_fused_batch_norm_op.cc689
-rw-r--r--tensorflow/core/kernels/mkl_lrn_op.cc722
-rw-r--r--tensorflow/core/kernels/mkl_maxpooling_op.cc74
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc32
-rw-r--r--tensorflow/core/kernels/mkl_reshape_op.cc149
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.cc10
-rw-r--r--tensorflow/core/kernels/quantized_conv_ops.cc4
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc143
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h5
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc37
-rw-r--r--tensorflow/core/ops/array_ops.cc58
-rw-r--r--tensorflow/core/ops/io_ops.cc12
-rw-r--r--tensorflow/core/ops/nn_ops.cc257
-rw-r--r--tensorflow/core/ops/ops.pbtxt53
-rw-r--r--tensorflow/core/ops/sparse_ops.cc3
-rw-r--r--tensorflow/core/platform/default/build_config.bzl12
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl8
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/mkl_util.h271
-rw-r--r--tensorflow/docs_src/community/style_guide.md50
-rw-r--r--tensorflow/docs_src/extend/adding_an_op.md6
-rw-r--r--tensorflow/docs_src/get_started/get_started.md2
-rw-r--r--tensorflow/docs_src/get_started/monitors.md9
-rw-r--r--tensorflow/docs_src/get_started/tflearn.md2
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md16
-rw-r--r--tensorflow/docs_src/install/install_linux.md28
-rw-r--r--tensorflow/docs_src/install/install_mac.md14
-rw-r--r--tensorflow/docs_src/install/install_sources.md5
-rw-r--r--tensorflow/docs_src/install/install_windows.md4
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py3
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_with_summaries.py17
-rw-r--r--tensorflow/python/BUILD5
-rw-r--r--tensorflow/python/framework/dtypes_test.py2
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py54
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py42
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py19
-rw-r--r--tensorflow/python/layers/convolutional_test.py2
-rw-r--r--tensorflow/python/layers/normalization_test.py2
-rw-r--r--tensorflow/python/layers/utils_test.py2
-rw-r--r--tensorflow/python/ops/batch_norm_benchmark.py2
-rw-r--r--tensorflow/python/ops/io_ops.py14
-rw-r--r--tensorflow/python/ops/nn_impl.py28
-rw-r--r--tensorflow/python/ops/sparse_grad.py14
-rw-r--r--tensorflow/python/ops/sparse_ops.py6
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.h2
-rw-r--r--tensorflow/tensorboard/DEVELOPMENT.md2
-rw-r--r--tensorflow/tensorboard/dist/tf-tensorboard.html20
-rw-r--r--tensorflow/tensorboard/gulp_tasks/bower.js4
-rw-r--r--tensorflow/tensorboard/gulp_tasks/compile.js34
-rw-r--r--tensorflow/tensorboard/gulp_tasks/test.js4
-rw-r--r--tensorflow/tensorboard/gulp_tasks/util.js6
-rw-r--r--tensorflow/tensorboard/gulp_tasks/vulcanize.js27
-rw-r--r--tensorflow/tensorboard/gulpfile.js20
-rw-r--r--tensorflow/tensorboard/package.json2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu4
-rw-r--r--tensorflow/tools/pip_package/setup.py2
-rw-r--r--third_party/jemalloc.BUILD33
-rw-r--r--third_party/llvm/llvm.BUILD5
115 files changed, 7578 insertions, 2236 deletions
diff --git a/README.md b/README.md
index 3ab4773681..951e7c3b9f 100644
--- a/README.md
+++ b/README.md
@@ -34,12 +34,13 @@ and discussion.**
People who are a little more adventurous can also try our nightly binaries:
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/))
+
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
@@ -62,7 +63,7 @@ $ python
## For more information
-* [TensorFlow website](http://tensorflow.org)
+* [TensorFlow website](https://tensorflow.org)
* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
diff --git a/RELEASE.md b/RELEASE.md
index 6087390c9c..fe6d052640 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -36,6 +36,7 @@
* New navigation bar in Curses-based UI
* NodeStepper (command `invoke_stepper`) now uses intermediate tensor dumps. It also uses `TensorHandles` as direct feeds during successive `cont` calls for improved performance and reduced memory consumption.
* Initial release of installation guides for Java, C, and Go.
+* Added Text Dashboard to TensorBoard.
## Deprecations
@@ -91,6 +92,8 @@
* Command history now persists across runs.
* Bug fix in graph validation related to `tf.while_loops`.
* Java Maven fixes for bugs with Windows installation.
+* Backport fixes and improvements from external keras.
+* Keras config file handling fix.
## Thanks to our Contributors
diff --git a/configure b/configure
index 47bdd5d018..fad3fdbebd 100755
--- a/configure
+++ b/configure
@@ -94,10 +94,10 @@ write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
if false; then # Disable building with MKL for now
while [ "$TF_NEED_MKL" == "" ]; do
fromuser=""
- read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
+ read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
fromuser="1"
case $INPUT in
- [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
+ [Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
@@ -244,6 +244,24 @@ if [[ "$TF_ENABLE_XLA" == "1" ]]; then
write_to_bazelrc 'build --define with_xla_support=true'
fi
+# Verbs configuration
+while [ "$TF_NEED_VERBS" == "" ]; do
+ read -p "Do you wish to build TensorFlow with "\
+"VERBS support? [y/N] " INPUT
+ case $INPUT in
+ [Yy]* ) echo "VERBS support will be enabled for "\
+"TensorFlow"; TF_NEED_VERBS=1;;
+ [Nn]* ) echo "No VERBS support will be enabled for "\
+"TensorFlow"; TF_NEED_VERBS=0;;
+ "" ) echo "No VERBS support will be enabled for "\
+"TensorFlow"; TF_NEED_VERBS=0;;
+ * ) echo "Invalid selection: " $INPUT;;
+ esac
+done
+
+if [[ "$TF_NEED_VERBS" == "1" ]]; then
+ write_to_bazelrc 'build --define with_verbs_support=true'
+fi
# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh "$PYTHON_BIN_PATH"
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 0f7f848cb1..248b18e020 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -85,6 +85,12 @@ config_setting(
)
config_setting(
+ name = "linux_ppc64le",
+ values = {"cpu": "ppc"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
name = "debug",
values = {
"compilation_mode": "dbg",
@@ -108,7 +114,7 @@ config_setting(
# TODO(jhseu): Enable on other platforms other than Linux.
config_setting(
- name = "with_jemalloc",
+ name = "with_jemalloc_linux_x86_64",
values = {
"cpu": "k8",
"define": "with_jemalloc=true",
@@ -117,6 +123,15 @@ config_setting(
)
config_setting(
+ name = "with_jemalloc_linux_ppc64le",
+ values = {
+ "cpu": "ppc",
+ "define": "with_jemalloc=true",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
name = "with_gcp_support",
values = {"define": "with_gcp_support=true"},
visibility = ["//visibility:public"],
@@ -134,6 +149,12 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "with_verbs_support",
+ values = {"define": "with_verbs_support=true"},
+ visibility = ["//visibility:public"],
+)
+
package_group(
name = "internal",
packages = ["//tensorflow/..."],
@@ -249,6 +270,7 @@ filegroup(
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
+ "//tensorflow/contrib/verbs:all_files",
"//tensorflow/contrib/xla_tf_graph:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/debug:all_files",
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 64e5bfd602..7d61bee8ca 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -282,5 +282,6 @@ def target_llvm_triple():
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
+ "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//conditions:default": "x86_64-pc-linux",
})
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index a758bb92aa..e520139e65 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -142,7 +142,7 @@ template <typename UInt>
string LittleEndianData(UInt data) {
static_assert(std::is_unsigned<UInt>::value, "UInt must be unsigned");
string str;
- for (int i = 0; i < sizeof(UInt); ++i) {
+ for (size_t i = 0; i < sizeof(UInt); ++i) {
const unsigned char bits = static_cast<unsigned char>(data & 0xFFU);
char ch;
::memcpy(&ch, &bits, sizeof(bits));
diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
index fd50070dac..13f213c197 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
@@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import random_seed
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
@@ -109,12 +110,16 @@ class DataSet(object):
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
- reshape=True):
+ reshape=True,
+ seed=None):
"""Construct a DataSet.
one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
- `[0, 1]`.
+ `[0, 1]`. Seed arg provides for convenient deterministic testing.
"""
+ seed1, seed2 = random_seed.get_seed(seed)
+ # If op level seed is not set, use whatever graph level seed is returned
+ numpy.random.seed(seed1 if seed is None else seed2)
dtype = dtypes.as_dtype(dtype).base_dtype
if dtype not in (dtypes.uint8, dtypes.float32):
raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
@@ -208,11 +213,13 @@ def read_data_sets(train_dir,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
- validation_size=5000):
+ validation_size=5000,
+ seed=None):
if fake_data:
def fake():
- return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
+ return DataSet(
+ [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
train = fake()
validation = fake()
@@ -254,12 +261,16 @@ def read_data_sets(train_dir,
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]
- train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
- validation = DataSet(validation_images,
- validation_labels,
- dtype=dtype,
- reshape=reshape)
- test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
+ train = DataSet(
+ train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
+ validation = DataSet(
+ validation_images,
+ validation_labels,
+ dtype=dtype,
+ reshape=reshape,
+ seed=seed)
+ test = DataSet(
+ test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
return base.Datasets(train=train, validation=validation, test=test)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index ce2eb4e052..6e10fdb977 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -52,7 +52,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables as variables_lib
@@ -63,7 +62,6 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
@@ -82,18 +80,6 @@ def boston_input_fn(num_epochs=None):
return features, labels
-def boston_input_fn_with_queue(num_epochs=None):
- features, labels = boston_input_fn(num_epochs=num_epochs)
-
- # Create a minimal queue runner.
- fake_queue = data_flow_ops.FIFOQueue(30, dtypes.int32)
- queue_runner = queue_runner_impl.QueueRunner(fake_queue,
- [constant_op.constant(0)])
- queue_runner_impl.add_queue_runner(queue_runner)
-
- return features, labels
-
-
def iris_input_fn():
iris = base.load_iris()
features = array_ops.reshape(
diff --git a/tensorflow/contrib/makefile/compile_linux_protobuf.sh b/tensorflow/contrib/makefile/compile_linux_protobuf.sh
index 480fbcc215..6eb061a3c9 100755
--- a/tensorflow/contrib/makefile/compile_linux_protobuf.sh
+++ b/tensorflow/contrib/makefile/compile_linux_protobuf.sh
@@ -38,7 +38,7 @@ then
exit 1
fi
-./configure --prefix="${GENDIR}"
+./configure --prefix="${GENDIR}" --with-pic
if [ $? -ne 0 ]
then
echo "./configure command failed."
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
index 9a96d4e856..3a5cbf604d 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -68,8 +68,8 @@ class LSTMBlockCellTest(test.TestCase):
m3 = array_ops.zeros([1, 2])
g, ((out_m0, out_m1),
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
- [lstm_ops.LSTMBlockCell(2)] * 2, state_is_tuple=True)(x, (
- (m0, m1), (m2, m3)))
+ [lstm_ops.LSTMBlockCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
x.name: np.array([[1., 1.]]),
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 023164d826..37622af59f 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -473,7 +473,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
if probability_fn is None:
probability_fn = nn_ops.softmax
else:
- if not callable(cell_input_fn):
+ if not callable(probability_fn):
raise TypeError(
"probability_fn must be callable, saw type: %s"
% type(probability_fn).__name__)
diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md
index dae50e67c5..c8842dd57b 100644
--- a/tensorflow/contrib/slim/README.md
+++ b/tensorflow/contrib/slim/README.md
@@ -447,7 +447,7 @@ vgg = tf.contrib.slim.nets.vgg
images, labels = ...
# Create the model.
-predictions = vgg.vgg16(images)
+predictions = vgg.vgg_16(images)
# Define the loss functions and get the total loss.
loss = slim.losses.softmax_cross_entropy(predictions, labels)
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
new file mode 100644
index 0000000000..e747fa4c9e
--- /dev/null
+++ b/tensorflow/contrib/verbs/BUILD
@@ -0,0 +1,168 @@
+# Description:
+# Verbs RDMA communication interfaces and implementations for TensorFlow.
+
+package(default_visibility = [
+ "//tensorflow:__subpackages__",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+# For platform specific build config
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_cc",
+)
+
+tf_proto_library_cc(
+ name = "verbs_service_proto",
+ srcs = ["verbs_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+cc_library(
+ name = "verbs_util",
+ srcs = ["verbs_util.cc"],
+ hdrs = ["verbs_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "grpc_verbs_service",
+ srcs = ["grpc_verbs_service.cc"],
+ hdrs = ["grpc_verbs_service.h"],
+ deps = [
+ ":grpc_verbs_service_impl",
+ ":rdma_mgr",
+ ":verbs_service_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_call",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+ "@grpc//:grpc++_unsecure",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "grpc_verbs_service_impl",
+ srcs = ["grpc_verbs_service_impl.cc"],
+ hdrs = ["grpc_verbs_service_impl.h"],
+ deps = [
+ ":verbs_service_proto_cc",
+ "@grpc//:grpc++_unsecure",
+ ],
+)
+
+cc_library(
+ name = "grpc_verbs_client",
+ srcs = ["grpc_verbs_client.cc"],
+ hdrs = ["grpc_verbs_client.h"],
+ deps = [
+ ":grpc_verbs_service_impl",
+ ":verbs_service_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:call_options",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "rdma_rendezvous_mgr",
+ srcs = ["rdma_rendezvous_mgr.cc"],
+ hdrs = ["rdma_rendezvous_mgr.h"],
+ deps = [
+ ":rdma_mgr",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ ],
+)
+
+cc_library(
+ name = "rdma_mgr",
+ srcs = ["rdma_mgr.cc"],
+ hdrs = ["rdma_mgr.h"],
+ deps = [
+ ":grpc_verbs_client",
+ ":rdma",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ ],
+)
+
+cc_library(
+ name = "rdma",
+ srcs = ["rdma.cc"],
+ hdrs = ["rdma.h"],
+ linkopts = select({
+ "//tensorflow:with_verbs_support": ["-libverbs"],
+ "//conditions:default": [],
+ }),
+ deps = [
+ ":verbs_util",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
+ "//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ ],
+)
+
+cc_library(
+ name = "verbs_server_lib",
+ srcs = ["verbs_server_lib.cc"],
+ hdrs = ["verbs_server_lib.h"],
+ linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
+ deps = [
+ ":grpc_verbs_service",
+ ":rdma_mgr",
+ ":rdma_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md
new file mode 100644
index 0000000000..37a543dda8
--- /dev/null
+++ b/tensorflow/contrib/verbs/README.md
@@ -0,0 +1,77 @@
+## How to compile and use Rdma-enabled tensorflow
+1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based Rdma support, answer yes to this question:
+
+ ```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```
+
+2. To turn on Rdma connection, add the protocol "grpc+verbs" in server definition:
+
+ ```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
+
+## Overview
+The design is based on Tensorflow r1.0. An Rdma path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the Rdma path, exchanging computation graphs, etc.
+
+During the server setup, an Rdma manager is created to manage low-level Rdma components such as Rdma channel and Rdma adapter, an Rdma rendezvous manager is created to oversee send/recv operations between servers. Following the distributed Tensorflow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer.
+
+Tensorflow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for Rdma operations where pinned memory is required. Two remedies are possible, either the memory is pinned, transfer, then unpinned for each and every tensor to be transferred, or a buffer is pre-allocated and pinned for each tensor. The former incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. The latter incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. The second approach is adopted in this design. Each Rdma channel, representing a Rdma connection to a peer, contains a table of pinned buffers for all the seen tensors that requires transfer. It is assumed that the tensor size rarely changes across different steps. So only one buffer is created for the same tensor across all the steps. In the rare case when the tensor size does increases, the old buffer is discarded and new buffer of larger size is created and pinned.
+
+When a tensor is prepared fro transfer, it is first converted to TensorProto, then the proto is serialized to byte array and copied to the pinned buffer. The content of the buffer is transferred to the remote node via Rdma write. On the remote side, the process is reversed. This is illustrated in the diagram below. The conversion of TensorProto is introduced to simplify transfer of string-tensors. Also since the TensorProto lives in host memory, even if the origin tensor lives in the device, the pinned buffers are all allocated in the host memory.
+![Tensorflow Rdma path](./design_diagram.png)
+
+The following improvements can be made in the future. First, conversion to TensorProto and serialization can be avoided for numeric (float/int) tensors since their internal buffer can be access directly as byte array. Second, the pinned buffer may be allocated on device if the tensor is located in the device. This avoids extra device-to-host copy at the expense of extra device memory consumption.
+## Design details
+
+### Rdma components
+
+* **Rdma adapter:** The base for Rdma communications. It may contain multiple channels and buffers. It is responsible for handling various incoming Rdma messages.
+* **Rdma channel:** Responsible for Rdma connection to a particular node. It manages multiple buffers. A channel has a callback table which stores all the callbacks for the requested tensors.
+* **Rdma buffer:** Responsible for sending or receiving data. It has a fixed size memory to store the data. It has a queue to store the pending jobs. There are three types of buffers, message buffer, ACK buffer and tensor buffer. A channel has two message buffers, two ack buffers and many tensor buffers.
+* **Rdma manager:** Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
+* **Rdma rendezvous manager:** manages multiple rdma rendezvous.
+* **Rdma rendezvous:** a derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
+
+### The SEND operation
+
+In tensorflow, when rendezvous sends a tensor, it merely puts a tensor in a local table in the corresponding rendezvous. If the tensor has been requested, a callback exists in the table. "send" will activate the callback, which tries to send the tensor across the node.
+
+
+### The RECV operation
+
+When a tensor is requested, rendezvous' recv function is called. The function first places a callback in the channel's callback table, which will be activated once the tensor is sent from the source. In the next step, a message is sent to notify the source of the requested tensor. Once the source receives the message, it will check locally for the tensor, if not found, a callback is placed in the table, otherwise, the tensor id will be placed at corresponding Rdma buffer's job queue for future transmission. When a tensor is scheduled to be transmitted, the Rdma buffer needs to have the memory allocated and initialized (registered with the remote buffer info). If the memory is not ready, the transmission is deferred, a message is sent to the destination to establish the memory first. The other case a transimssion can be deferred is when the buffer is still being used by an on-going transmission.
+
+### Three types of Rdma buffers
+
+* **Message buffer:** responsible for sending message only.
+* **Ack buffer:** once a message is sent, the recipient needs to send an ack via the ack buffer to free up the message buffer. An ack buffer is exclusively for its coupled message buffer.
+* **Tensor buffer:** responsible for sending tensors. The recipient needs to send back a message to free up the sending buffer.
+
+### Rdma packet format
+
+|type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|data_type|tensor_shape|tensor_bytes|tensor_buffer|
+
+### Six types of Rdma messages
+* RDMA_MESSAGE_ACK
+* RDMA_MESSAGE_BUFFER_IDLE
+* RDMA_MESSAGE_BUFFER_REQUEST
+* RDMA_MESSAGE_BUFFER_RESPONSE
+* RDMA_MESSAGE_TENSOR_REQUEST
+* RDMA_MESSAGE_TENSOR_WRITE
+
+### Actions upon receiving Rdma messages
+* RDMA_MESSAGE_ACK
+ * sender: mark local ack buffer idle.
+ * receiver: mark remote message buffer idle, send next item.
+* RDMA_MESSAGE_BUFFER_IDLE
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, set remote tensor buffer idle, send next item.
+* RDMA_MESSAGE_BUFFER_REQUEST
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, find or create tensor buffer, send BUFFER_RESPONSE.
+* RDMA_MESSAGE_BUFFER_RESPONSE
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, set remote buffer info, set local and remote buffer idle, send next item.
+* RDMA_MESSAGE_TENSOR_REQUEST
+ * sender: mark local message buffer idle, send next item.
+ * receiver: send ack, find or create tensor buffer, enqueue tensor id, send next item.
+* RDMA_MESSAGE_TENSOR_WRITE
+ * sender: mark local message buffer idle, send next item.
+ * receiver: run callback.
diff --git a/tensorflow/contrib/verbs/design_diagram.png b/tensorflow/contrib/verbs/design_diagram.png
new file mode 100644
index 0000000000..f0ad27455f
--- /dev/null
+++ b/tensorflow/contrib/verbs/design_diagram.png
Binary files differ
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.cc b/tensorflow/contrib/verbs/grpc_verbs_client.cc
new file mode 100644
index 0000000000..608a9140d3
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.cc
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
+ const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response) {
+ ::grpc::ClientContext ctx;
+ ctx.set_fail_fast(false);
+ SetDeadline(&ctx, call_options->GetTimeout());
+ return FromGrpcStatus(stub_->GetRemoteAddress(&ctx, *request, response));
+}
+
+Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response) {
+ CallOptions call_options;
+ call_options.SetTimeout(-1); // no time out
+ return GetRemoteAddress(&call_options, request, response);
+}
+
+void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
+ int64 time_in_ms) {
+ if (time_in_ms > 0) {
+ ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
+ }
+}
+
+} // namespace tensorflow \ No newline at end of file
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h
new file mode 100644
index 0000000000..358977f925
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// GrpcVerbsClient is a client that uses gRPC to talk to the Verbs service.
+class GrpcVerbsClient {
+ public:
+ explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
+ : stub_(grpc::VerbsService::NewStub(client_channel)) {}
+ ~GrpcVerbsClient() {}
+
+ Status GetRemoteAddress(CallOptions* call_options,
+ const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response);
+ Status GetRemoteAddress(const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response);
+
+ private:
+ std::unique_ptr<grpc::VerbsService::Stub> stub_;
+
+ void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc
new file mode 100644
index 0000000000..e73b2700bd
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "grpc++/alarm.h"
+#include "grpc++/grpc++.h"
+#include "grpc++/server_builder.h"
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+
+namespace tensorflow {
+
+GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder)
+ : is_shutdown_(false), worker_env_(worker_env) {
+ builder->RegisterService(&verbs_service_);
+ cq_ = builder->AddCompletionQueue().release();
+}
+
+GrpcVerbsService::~GrpcVerbsService() {
+ delete shutdown_alarm_;
+ delete cq_;
+}
+
+void GrpcVerbsService::Shutdown() {
+ bool did_shutdown = false;
+ {
+ mutex_lock l(shutdown_mu_);
+ if (!is_shutdown_) {
+ LOG(INFO) << "Shutting down GrpcWorkerService.";
+ is_shutdown_ = true;
+ did_shutdown = true;
+ }
+ }
+ if (did_shutdown) {
+ shutdown_alarm_ =
+ new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
+ }
+}
+
+// This macro creates a new request for the given RPC method name
+// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
+// `this->cq_`.
+//
+// This macro is invoked one or more times for each RPC method to
+// ensure that there are sufficient completion queue entries to
+// handle incoming requests without blocking.
+//
+// The implementation of the request handler for each RPC method
+// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
+// to keep accepting new requests.
+#define ENQUEUE_REQUEST(method, supports_cancel) \
+ do { \
+ mutex_lock l(shutdown_mu_); \
+ if (!is_shutdown_) { \
+ Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
+ method##Request, method##Response>:: \
+ EnqueueRequest(&verbs_service_, cq_, \
+ &grpc::VerbsService::AsyncService::Request##method, \
+ &GrpcVerbsService::method##Handler, \
+ (supports_cancel)); \
+ } \
+ } while (0)
+
+// This method blocks forever handling requests from the completion queue.
+void GrpcVerbsService::HandleRPCsLoop() {
+ for (int i = 0; i < 10; ++i) {
+ ENQUEUE_REQUEST(GetRemoteAddress, false);
+ }
+
+ void* tag;
+ bool ok;
+
+ while (cq_->Next(&tag, &ok)) {
+ UntypedCall<GrpcVerbsService>::Tag* callback_tag =
+ static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
+ if (callback_tag) {
+ callback_tag->OnCompleted(this, ok);
+ } else {
+ cq_->Shutdown();
+ }
+ }
+}
+
+void GrpcVerbsService::GetRemoteAddressHandler(
+ WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
+ Status s = GetRemoteAddressSync(&call->request, &call->response);
+ call->SendResponse(ToGrpcStatus(s));
+ ENQUEUE_REQUEST(GetRemoteAddress, false);
+}
+
+// synchronous method
+Status GrpcVerbsService::GetRemoteAddressSync(
+ const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response) {
+ // analyzing request
+ // the channel setting part is redundant.
+ const string remote_host_name = request->host_name();
+ RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
+ CHECK(rc);
+ RdmaAddress ra;
+ ra.lid = request->channel().lid();
+ ra.qpn = request->channel().qpn();
+ ra.psn = request->channel().psn();
+ rc->SetRemoteAddress(ra, false);
+ rc->Connect();
+ int i = 0;
+ int idx[] = {1, 0, 3, 2};
+ std::vector<RdmaBuffer*> mb(rc->message_buffers());
+ CHECK_EQ(request->mr_size(), 4);
+ for (const auto& mr : request->mr()) {
+ // the connections are crossed, i.e.
+ // local tx_message_buffer <---> remote rx_message_buffer_
+ // local rx_message_buffer <---> remote tx_message_buffer_
+ // local tx_ack_buffer <---> remote rx_ack_buffer_
+ // local rx_ack_buffer <---> remote tx_ack_buffer_
+ // hence idx[] = {1, 0, 3, 2}.
+ RdmaBuffer* rb = mb[idx[i]];
+ RemoteMR rmr;
+ rmr.remote_addr = mr.remote_addr();
+ rmr.rkey = mr.rkey();
+ rb->SetRemoteMR(rmr, false);
+ i++;
+ }
+ CHECK(i == RdmaChannel::kNumMessageBuffers);
+
+ // setting up response
+ response->set_host_name(
+ worker_env_->session_mgr->LegacySession()->worker_name);
+ Channel* channel_info = response->mutable_channel();
+ channel_info->set_lid(rc->self().lid);
+ channel_info->set_qpn(rc->self().qpn);
+ channel_info->set_psn(rc->self().psn);
+ for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
+ MemoryRegion* mr = response->add_mr();
+ mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
+ mr->set_rkey(mb[i]->self()->rkey);
+ }
+ return Status::OK();
+}
+
+// Create a GrpcVerbsService, then assign it to a given handle.
+void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder) {
+ *handle = new GrpcVerbsService(worker_env, builder);
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h
new file mode 100644
index 0000000000..aa509602b5
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+namespace grpc {
+class ServerBuilder;
+class ServerCompletionQueue;
+class Alarm;
+} // namespace grpc
+
+namespace tensorflow {
+
+class GrpcVerbsService : public AsyncServiceInterface {
+ public:
+ GrpcVerbsService(const WorkerEnv* worker_env, ::grpc::ServerBuilder* builder);
+ ~GrpcVerbsService();
+ void HandleRPCsLoop() override;
+ void Shutdown() override;
+ void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
+
+ private:
+ template <class RequestMessage, class ResponseMessage>
+ using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
+ RequestMessage, ResponseMessage>;
+ void GetRemoteAddressHandler(
+ WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
+ Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
+ GetRemoteAddressResponse* response);
+
+ ::grpc::ServerCompletionQueue* cq_;
+ grpc::VerbsService::AsyncService verbs_service_;
+ mutex shutdown_mu_;
+ bool is_shutdown_ GUARDED_BY(shutdown_mu_);
+ ::grpc::Alarm* shutdown_alarm_;
+ // not owned
+ RdmaMgr* rdma_mgr_;
+ const WorkerEnv* const worker_env_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
+};
+
+// Create a GrpcVerbsService, then assign it to a given handle.
+void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
new file mode 100644
index 0000000000..e0ba78dbfd
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
@@ -0,0 +1,68 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+
+#include "grpc++/impl/codegen/async_stream.h"
+#include "grpc++/impl/codegen/async_unary_call.h"
+#include "grpc++/impl/codegen/channel_interface.h"
+#include "grpc++/impl/codegen/client_unary_call.h"
+#include "grpc++/impl/codegen/method_handler_impl.h"
+#include "grpc++/impl/codegen/rpc_service_method.h"
+#include "grpc++/impl/codegen/service_type.h"
+#include "grpc++/impl/codegen/sync_stream.h"
+
+namespace tensorflow {
+
+namespace grpc {
+
+static const char* grpcVerbsService_method_names[] = {
+ "/tensorflow.VerbsService/GetRemoteAddress",
+};
+
+std::unique_ptr<VerbsService::Stub> VerbsService::NewStub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+ const ::grpc::StubOptions& options) {
+ std::unique_ptr<VerbsService::Stub> stub(new VerbsService::Stub(channel));
+ return stub;
+}
+
+VerbsService::Stub::Stub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel)
+ : channel_(channel),
+ rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
+ ::grpc::RpcMethod::NORMAL_RPC, channel) {}
+
+::grpc::Status VerbsService::Stub::GetRemoteAddress(
+ ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+ GetRemoteAddressResponse* response) {
+ return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_GetRemoteAddress_,
+ context, request, response);
+}
+
+VerbsService::AsyncService::AsyncService() {
+ for (int i = 0; i < 1; ++i) {
+ AddMethod(new ::grpc::RpcServiceMethod(grpcVerbsService_method_names[i],
+ ::grpc::RpcMethod::NORMAL_RPC,
+ nullptr));
+ ::grpc::Service::MarkMethodAsync(i);
+ }
+}
+
+VerbsService::AsyncService::~AsyncService() {}
+
+} // namespace grpc
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
new file mode 100644
index 0000000000..f7ea774b66
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -0,0 +1,89 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+
+#include "grpc++/impl/codegen/async_stream.h"
+#include "grpc++/impl/codegen/async_unary_call.h"
+#include "grpc++/impl/codegen/proto_utils.h"
+#include "grpc++/impl/codegen/rpc_method.h"
+#include "grpc++/impl/codegen/service_type.h"
+#include "grpc++/impl/codegen/status.h"
+#include "grpc++/impl/codegen/stub_options.h"
+#include "grpc++/impl/codegen/sync_stream.h"
+
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+
+namespace grpc {
+class CompletionQueue;
+class Channel;
+class RpcService;
+class ServerCompletionQueue;
+class ServerContext;
+} // namespace grpc
+
+namespace tensorflow {
+
+namespace grpc {
+
+// Implementation of `tensorflow.VerbsService`, based on the
+// definition in "//tensorflow/contrib/verbs/verbs_service.proto",
+// and the gRPC generated stub and service classes.
+// See the proto file for the definition of methods and messages.
+class VerbsService GRPC_FINAL {
+ public:
+ class StubInterface {
+ public:
+ virtual ~StubInterface() {}
+ virtual ::grpc::Status GetRemoteAddress(
+ ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+ GetRemoteAddressResponse* response) = 0;
+ };
+ class Stub GRPC_FINAL : public StubInterface {
+ public:
+ Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
+ ::grpc::Status GetRemoteAddress(
+ ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+ GetRemoteAddressResponse* response) GRPC_OVERRIDE;
+
+ private:
+ std::shared_ptr< ::grpc::ChannelInterface> channel_;
+ const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_;
+ };
+ static std::unique_ptr<Stub> NewStub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+ const ::grpc::StubOptions& options = ::grpc::StubOptions());
+
+ class AsyncService : public ::grpc::Service {
+ public:
+ AsyncService();
+ virtual ~AsyncService();
+ void RequestGetRemoteAddress(
+ ::grpc::ServerContext* context, GetRemoteAddressRequest* request,
+ ::grpc::ServerAsyncResponseWriter<GetRemoteAddressResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(0, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ };
+};
+
+} // namespace grpc
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
new file mode 100644
index 0000000000..53d840f5d1
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -0,0 +1,874 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include <cstdlib>
+#include "tensorflow/contrib/verbs/verbs_util.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+namespace {
+// hash name to 32-bit integer
+uint32_t NameHash(const string& name) {
+ return Hash32(name.data(), name.size(), 0x1234ABCD);
+}
+
+// convenience function for printing message
+string MessageTypeToString(RdmaMessageType rmt) {
+ switch (rmt) {
+ case RDMA_MESSAGE_ACK:
+ return "RDMA_MESSAGE_ACK";
+ break;
+ case RDMA_MESSAGE_BUFFER_IDLE:
+ return "RDMA_MESSAGE_BUFFER_IDLE";
+ break;
+ case RDMA_MESSAGE_BUFFER_REQUEST:
+ return "RDMA_MESSAGE_BUFFER_REQUEST";
+ break;
+ case RDMA_MESSAGE_BUFFER_RESPONSE:
+ return "RDMA_MESSAGE_BUFFER_RESPONSE";
+ break;
+ case RDMA_MESSAGE_TENSOR_REQUEST:
+ return "RDMA_MESSAGE_TENSOR_REQUEST";
+ break;
+ case RDMA_MESSAGE_TENSOR_WRITE:
+ return "RDMA_MESSAGE_TENSOR_WRITE";
+ break;
+ default:
+ return "UNKNOWN MESSAGE";
+ }
+}
+} // namespace
+
+ibv_context* open_default_device() {
+ ibv_device** dev_list;
+ ibv_device* ib_dev;
+ dev_list = ibv_get_device_list(NULL);
+ CHECK(dev_list) << "No InfiniBand device found";
+ ib_dev = dev_list[0];
+ CHECK(ib_dev) << "No InfiniBand device found";
+ ibv_context* context = ibv_open_device(ib_dev);
+ CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
+ return context;
+}
+
+ibv_pd* alloc_protection_domain(ibv_context* context) {
+ ibv_pd* pd = ibv_alloc_pd(context);
+ CHECK(pd) << "Failed to allocate protection domain";
+ return pd;
+}
+
+RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
+ : context_(open_default_device()),
+ pd_(alloc_protection_domain(context_)),
+ worker_env_(worker_env) {
+ event_channel_ = ibv_create_comp_channel(context_);
+ CHECK(event_channel_) << "Failed to create completion channel";
+ cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
+ 0);
+ CHECK(cq_) << "Failed to create completion queue";
+ CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
+ polling_thread_.reset(Env::Default()->StartThread(
+ ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
+ VLOG(2) << "Start RdmaAdapter: " << name();
+}
+
+RdmaAdapter::~RdmaAdapter() {
+ polling_thread_.reset();
+ CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
+ CHECK(!ibv_destroy_comp_channel(event_channel_))
+ << "Failed to destroy channel";
+ CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
+ CHECK(!ibv_close_device(context_)) << "Failed to release context";
+}
+
+string RdmaAdapter::name() const { return string(context_->device->name); }
+
+// Function to process incoming messages
+// There are two types of messages:
+// 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
+// 2. IBV_WC_RDMA_WRITE (send))
+void RdmaAdapter::Process_CQ() {
+ while (true) {
+ ibv_cq* cq;
+ void* cq_context;
+ CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
+ CHECK(cq == cq_);
+ ibv_ack_cq_events(cq, 1);
+ CHECK(!ibv_req_notify_cq(cq_, 0));
+
+ int ne =
+ ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
+ CHECK_GE(ne, 0);
+ for (int i = 0; i < ne; ++i) {
+ CHECK(wc_[i].status == IBV_WC_SUCCESS)
+ << "Failed status \n"
+ << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+ << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
+ if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
+ RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
+ // put back a recv wr.
+ rc->Recv();
+ // imm_data is the index of RX buffer in the buffer table.
+ uint32_t imm_data = wc_[i].imm_data;
+ RdmaBuffer* rb = rc->FindBuffer(imm_data);
+ RdmaMessage rm;
+ RdmaMessage::ParseMessage(rm, rb->buffer_);
+ VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
+
+ if (rm.type_ == RDMA_MESSAGE_ACK) {
+ // receive an ack to a message
+ rb = rc->tx_message_buffer_;
+ rb->SetBufferStatus(remote, idle);
+ rb->SendNextItem();
+ } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
+ // received a request-for-tensor message
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find or create buffer
+ RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
+ string key_with_step_id =
+ VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
+ tb->EnqueueItem(key_with_step_id);
+ // send the next tensor
+ worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+ } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
+ // receive tensor-buffer-ready message
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find buffer
+ RdmaBuffer* tb = rc->FindBuffer(rm.name_);
+ tb->SetBufferStatus(remote, idle);
+ worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+ } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
+ // remote host requests to create a tensor buffer;
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find or create the buffer
+ RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR);
+ RemoteMR rmr;
+ rmr.remote_addr = rm.remote_addr_;
+ rmr.rkey = rm.rkey_;
+ tb->SetRemoteMR(rmr, true);
+ tb->CreateCPUBuffer(rm.buffer_size_);
+ // create RDMA_MESSAGE_BUFFER_RESPONSE message
+ RdmaMessage br;
+ br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE;
+ br.name_size_ = rm.name_.size();
+ br.name_ = rm.name_;
+ br.buffer_size_ = rm.buffer_size_;
+ br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_);
+ br.rkey_ = tb->self_->rkey;
+ string message = RdmaMessage::CreateMessage(br);
+ RdmaBuffer* mb = rc->tx_message_buffer_;
+ mb->EnqueueItem(message);
+ mb->SendNextItem();
+ } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
+ // remote creates a buffer and responds
+ // send ack to release remote tx message buffer
+ RdmaBuffer* ab = rc->tx_ack_buffer_;
+ ab->SendNextItem();
+ // find buffer
+ RdmaBuffer* tb = rc->FindBuffer(rm.name_);
+ CHECK(rm.buffer_size_ == tb->size_)
+ << "rm.buffer_size = " << rm.buffer_size_
+ << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
+ RemoteMR rmr;
+ rmr.remote_addr = rm.remote_addr_;
+ rmr.rkey = rm.rkey_;
+ tb->SetRemoteMR(rmr, true);
+ tb->SetBufferStatus(local, idle);
+ tb->SetBufferStatus(remote, idle);
+ worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+ } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ // tensor RDMA write completed
+ worker_env_->compute_pool->Schedule([rm, rc]() {
+ string key_with_step_id =
+ VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
+ rc->RunRecvCallback(key_with_step_id);
+ });
+ }
+ } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
+ RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
+ rb->SetBufferStatus(local, idle);
+ RdmaMessage rm;
+ RdmaMessage::ParseMessage(rm, rb->buffer_);
+ VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
+ if (rm.type_ != RDMA_MESSAGE_ACK) {
+ worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
+ }
+ }
+ }
+ }
+}
+
+RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
+ const string remote_name)
+ : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
+ // Create queue pair
+ {
+ struct ibv_qp_init_attr attr;
+ memset(&attr, 0, sizeof(ibv_qp_init_attr));
+ attr.send_cq = adapter_->cq_;
+ attr.recv_cq = adapter_->cq_;
+ attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+ attr.cap.max_send_sge = 1;
+ attr.cap.max_recv_sge = 1;
+ attr.qp_type = IBV_QPT_RC;
+
+ qp_ = ibv_create_qp(adapter_->pd_, &attr);
+ CHECK(qp_) << "Failed to create queue pair";
+ }
+
+ // Init queue pair
+ {
+ struct ibv_qp_attr attr;
+ memset(&attr, 0, sizeof(ibv_qp_attr));
+ attr.qp_state = IBV_QPS_INIT;
+ attr.pkey_index = 0;
+ attr.port_num = 1;
+ attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
+
+ int mask =
+ IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
+ CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
+ }
+
+ // Local address
+ {
+ struct ibv_port_attr attr;
+ CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
+ << "Query port";
+ self_.lid = attr.lid;
+ self_.qpn = qp_->qp_num;
+ self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
+ }
+
+ // create message and ack buffers, then initialize the tables.
+ {
+ const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
+ "tx_ack_buffer", "rx_ack_buffer"};
+ tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
+ rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
+ tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
+ rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]);
+ message_buffers_.reserve(kNumMessageBuffers);
+ message_buffers_.push_back(tx_message_buffer_);
+ message_buffers_.push_back(rx_message_buffer_);
+ message_buffers_.push_back(tx_ack_buffer_);
+ message_buffers_.push_back(rx_ack_buffer_);
+ // create buffer on host
+ tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
+ rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
+ tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
+ rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
+ // bt_mu_.lock() is not used in constructor.
+ for (int i = 0; i < kNumMessageBuffers; i++) {
+ uint32_t index = NameHash(buffer_names[i]);
+ buffer_table_.insert({index, message_buffers_[i]});
+ buffer_index_name_table_.insert({index, buffer_names[i]});
+ buffer_name_index_table_.insert({buffer_names[i], index});
+ }
+
+ // Initiate recv
+ for (int i = 0; i < 100; i++) {
+ Recv();
+ }
+ }
+}
+
+RdmaChannel::~RdmaChannel() {
+ CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
+ delete tx_message_buffer_;
+ delete rx_message_buffer_;
+ delete tx_ack_buffer_;
+ delete rx_ack_buffer_;
+}
+
+void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
+ mutex_lock lock{mu_};
+ if ((override) || (!remote_set_)) {
+ remote_.lid = ra.lid;
+ remote_.qpn = ra.qpn;
+ remote_.psn = ra.psn;
+ remote_set_ = true;
+ } else {
+ CHECK(remote_.lid == ra.lid);
+ CHECK(remote_.qpn == ra.qpn);
+ CHECK(remote_.psn == ra.psn);
+ }
+}
+
+// Adding tokens to the completion queue
+// Tokens are needed to process future messages.
+void RdmaChannel::Recv() {
+ struct ibv_recv_wr wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.wr_id = (uint64_t)this;
+ struct ibv_recv_wr* bad_wr;
+ CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
+}
+
+// Lookup 32-bit buffer index from buffer name
+// Args:
+// buffer_name: name of the buffer
+// Returns:
+// 32-bit index
+uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
+ mutex_lock lock{bt_mu_};
+ BufferNameIndexTable::iterator iter =
+ buffer_name_index_table_.find(buffer_name);
+ CHECK(iter != buffer_name_index_table_.end());
+ return iter->second;
+}
+
+// Find a buffer by its 32-bit index
+// Args:
+// index: 32-bit hash code of the tensor buffer name
+// Returns:
+// name of the tensor buffer
+RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) {
+ mutex_lock lock{bt_mu_};
+ BufferTable::iterator iter = buffer_table_.find(index);
+ CHECK(iter != buffer_table_.end());
+ return iter->second;
+}
+
+// Find a buffer by its name
+// Args:
+// name: name of the buffer
+// Returns:
+// the named rdma buffer
+RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
+ uint32_t index = LookupBufferIndex(name);
+ return FindBuffer(index);
+}
+
+// Find a buffer if it exists, otherwise create one.
+// The memory inside the created buffer is not allocated.
+// Args:
+// name: the name of the buffer
+// buffer_type: TENSOR, MESSAGE or ACK.
+// Returns:
+// the named buffer
+RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
+ BufferType buffer_type) {
+ mutex_lock lock{bt_mu_};
+ RdmaBuffer* rb;
+ // find index
+ BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name);
+ if (iter != buffer_name_index_table_.end()) {
+ uint32_t index = iter->second;
+ // find buffer
+ BufferTable::iterator iter = buffer_table_.find(index);
+ CHECK(iter != buffer_table_.end());
+ rb = iter->second;
+ } else {
+ uint32_t index = NameHash(name);
+ if (buffer_type == TENSOR) {
+ rb = new RdmaTensorBuffer(this, name);
+ } else if (buffer_type == MESSAGE) {
+ rb = new RdmaMessageBuffer(this, name);
+ } else if (buffer_type == ACK) {
+ rb = new RdmaAckBuffer(this, name);
+ }
+ buffer_name_index_table_.insert({name, index});
+ buffer_index_name_table_.insert({index, name});
+ buffer_table_.insert({index, rb});
+ }
+ CHECK(rb);
+ return rb;
+}
+
+// Insert callback to the callback_table.
+// The callback is activated when the corresponding tensor is received.
+// Arg:
+// key: the name of the tensor
+// recv_done: the callback associated with the tensor.
+// Returns:
+// None
+void RdmaChannel::InsertRecvCallback(const string& key,
+ std::function<void()> recv_done) {
+ mutex_lock lock{ct_mu_};
+ callback_table_.insert({key, recv_done});
+}
+
+// Remove callback from the callback_table.
+// Arg:
+// key: the name of the tensor
+// Returns:
+// None
+void RdmaChannel::RemoveRecvCallback(const string& key) {
+ mutex_lock lock{ct_mu_};
+ callback_table_.erase(key);
+}
+
+// Run named callback in the callback_table.
+// Arg:
+// key: the name of the tensor
+// Returns:
+// None
+void RdmaChannel::RunRecvCallback(const string& key) {
+ std::function<void()> recv_done;
+ {
+ mutex_lock lock{ct_mu_};
+ CallbackTable::iterator iter = callback_table_.find(key);
+ CHECK(iter != callback_table_.end());
+ recv_done = iter->second;
+ }
+ recv_done();
+}
+
+void RdmaChannel::Connect() {
+ {
+ mutex_lock lock{mu_};
+ CHECK(remote_set_) << "remote channel is not set";
+ }
+ Connect(remote_);
+}
+
+// Setup channel to a remote node
+// Args:
+// remoteAddr: the rdma address of a remote channel.
+// Returns:
+// None
+void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
+ mutex_lock lock{mu_};
+ if (!connected_) {
+ struct ibv_qp_attr attr;
+ memset(&attr, 0, sizeof(ibv_qp_attr));
+ attr.qp_state = IBV_QPS_RTR;
+ attr.path_mtu = IBV_MTU_4096;
+ attr.dest_qp_num = remoteAddr.qpn;
+ attr.rq_psn = remoteAddr.psn;
+ attr.max_dest_rd_atomic = 1;
+ attr.min_rnr_timer = 12;
+ attr.ah_attr.is_global = 0;
+ attr.ah_attr.dlid = remoteAddr.lid;
+ attr.ah_attr.sl = 0;
+ attr.ah_attr.src_path_bits = 0;
+ attr.ah_attr.port_num = 1;
+
+ int r;
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
+ << "QP to Ready to Receive " << r;
+
+ memset(&attr, 0, sizeof(ibv_qp_attr));
+ attr.qp_state = IBV_QPS_RTS;
+ attr.sq_psn = self_.psn;
+ attr.timeout = 14;
+ attr.retry_cnt = 7;
+ attr.rnr_retry = 7; /* infinite */
+ attr.max_rd_atomic = 1;
+
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
+ << "QP to Ready to Send " << r;
+
+ connected_ = true;
+ } else {
+ LOG(INFO) << "channel already connected";
+ }
+}
+
+RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
+ : channel_(channel), name_(name) {}
+
+RdmaBuffer::~RdmaBuffer() {
+ CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
+ FreeBuffer();
+}
+
+void RdmaBuffer::FreeBuffer() {
+ if ((buffer_ != nullptr) && buffer_on_host_) {
+ free(buffer_);
+ }
+ // TODO
+ // release buffer if it is on device.
+ // We don't support RDMABuffer on device at this moment.
+}
+
+// Allocate CPU memory for the Rdma buffer
+// Args:
+// size: to-be-allocated memory size
+// lock: whether or not mutex_lock the process to protect concurrency.
+// Returns:
+// None
+void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
+ CHECK(size > 0);
+ if (lock) {
+ mu_.lock();
+ }
+ if (local_status_ != none) {
+ // delete existing buffer
+ CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
+ FreeBuffer();
+ }
+ size_ = size;
+ buffer_ = malloc(size_);
+ self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ CHECK(self_) << "Failed to register memory region";
+ buffer_on_host_ = true;
+ local_status_ = idle;
+ if (lock) {
+ mu_.unlock();
+ }
+}
+
+// Set address of remote memory region
+// Args:
+// rmr: address of remote memory region
+// override: whether override existing information
+// Returns:
+// None
+void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
+ mutex_lock lock{mu_};
+ if ((override) || (remote_status_ == none)) {
+ remote_.remote_addr = rmr.remote_addr;
+ remote_.rkey = rmr.rkey;
+ remote_status_ = idle;
+ } else {
+ CHECK(remote_.remote_addr == rmr.remote_addr);
+ CHECK(remote_.rkey == rmr.rkey);
+ }
+}
+
+// Put a task in the buffer's job queue
+void RdmaBuffer::EnqueueItem(string item) {
+ mutex_lock lock{mu_};
+ queue_.push(item);
+}
+
+// Rdma-Write the content of the buffer
+void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+ struct ibv_sge list;
+ list.addr = (uint64_t)buffer_;
+ list.length = buffer_size;
+ list.lkey = self_->lkey;
+
+ struct ibv_send_wr wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.wr_id = (uint64_t)this;
+ wr.sg_list = &list;
+ wr.num_sge = 1;
+ wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+ wr.send_flags = IBV_SEND_SIGNALED;
+ wr.imm_data = imm_data;
+ wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
+ wr.wr.rdma.rkey = remote_.rkey;
+
+ struct ibv_send_wr* bad_wr;
+ CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
+}
+
+RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
+ : RdmaBuffer(channel, name) {}
+
+RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
+ : RdmaBuffer(channel, name) {}
+
+RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
+ : RdmaBuffer(channel, name) {}
+
+// Send the next ack from the buffer's job queue.
+void RdmaAckBuffer::SendNextItem() {
+ uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
+ RdmaMessage rm;
+ rm.name_ = "rx_ack_buffer";
+ rm.type_ = RDMA_MESSAGE_ACK;
+ rm.name_size_ = rm.name_.size();
+ string message = RdmaMessage::CreateMessage(rm);
+ memcpy(buffer_, message.data(), message.size());
+ Write(imm_data, message.size());
+}
+
+// Send the next message from the buffer's job queue.
+void RdmaMessageBuffer::SendNextItem() {
+ uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
+ mu_.lock();
+ if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
+ local_status_ = busy;
+ remote_status_ = busy;
+ string message = queue_.front();
+ queue_.pop();
+ // local/remote_status_ won't be set back to idle
+ // unitl Write() is successful
+ mu_.unlock();
+ memcpy(buffer_, message.data(), message.size());
+ Write(imm_data, message.size());
+ } else {
+ mu_.unlock();
+ }
+}
+
+// Send the next tensor from the buffer's job queue.
+void RdmaTensorBuffer::SendNextItem() {
+ // get the key
+ string key_with_step_id = "";
+ {
+ mutex_lock lock{mu_};
+ if (!queue_.empty()) {
+ key_with_step_id = queue_.front();
+ queue_.pop();
+ }
+ }
+ // send the tensor if a key is acquired.
+ if (key_with_step_id != "") {
+ VLOG(2) << "try to send tensor: " << key_with_step_id;
+ string key;
+ int64 step_id;
+ VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
+ CHECK(key.compare(name_) == 0);
+ Rendezvous::ParsedKey parsed;
+ Rendezvous::ParseKey(key, &parsed);
+ Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
+ parsed](const Status& status,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& in, bool is_dead) {
+ CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
+ << " error message: " << status.error_message();
+ size_t buffer_size = RdmaMessage::kMessageTotalBytes;
+ size_t tensor_bytes = 0;
+ TensorProto proto;
+ // Figures out which device the tensor is hosted on.
+ Device* src_dev = nullptr;
+ Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
+ parsed.src_device, &src_dev);
+ CHECK(s.ok()) << "src device not found";
+ // Does the device have the right incarnation number we expect?
+ CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
+ << "RecvTensor expects a different device incarnation: "
+ << parsed.src_incarnation << " vs. "
+ << src_dev->attributes().incarnation()
+ << ". Your worker job was probably restarted. Check your "
+ << "worker job for the reason why it was restarted.";
+ Device* dst_dev = nullptr;
+ // destination is on CPU.
+ s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
+ &dst_dev);
+ CHECK(s.ok()) << "dst device not found";
+ AllocatorAttributes dst_alloc_attr;
+ dst_alloc_attr.set_on_host(true);
+ // string tensor needs to be serialized
+ if (src_dev->tensorflow_gpu_device_info() &&
+ (!send_args.alloc_attrs.on_host())) {
+ CHECK(send_args.device_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ // "val" is on a GPU. Uses GPUUtil to fill the proto.
+ s = VerbsUtil::SetProtoFromGPUSync(
+ in, src_dev, send_args.device_context, &proto, is_dead);
+ CHECK(s.ok()) << "set proto from gpu sync";
+ } else {
+ // tensor is in CPU memory.
+ in.AsProtoTensorContent(&proto);
+ }
+ tensor_bytes = proto.ByteSize();
+ // maybe some margin for string tensor?
+ buffer_size += tensor_bytes;
+ // prepare message
+ RdmaMessage rm;
+ rm.name_size_ = key.size();
+ rm.name_ = key;
+ rm.tensor_shape_ = in.shape();
+ rm.data_type_ = in.dtype();
+ rm.step_id_ = step_id;
+ rm.is_dead_ = is_dead;
+ rm.tensor_bytes_ = tensor_bytes;
+ rm.buffer_size_ = buffer_size;
+ mu_.lock();
+ if (local_status_ == none ||
+ (buffer_size > size_ && local_status_ == idle &&
+ remote_status_ == idle)) {
+ if ((local_status_ != none) && (buffer_size > size_)) {
+ CHECK(rm.data_type_ == DT_STRING)
+ << "Only string tensor allows to change size";
+ }
+ CreateCPUBuffer(buffer_size, false);
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ // ask the remote to create the same buffer
+ rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
+ rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
+ rm.rkey_ = self_->rkey;
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+ } else if ((local_status_ == idle) && (remote_status_ == idle)) {
+ // both buffers are ready, send the tensor
+ local_status_ = busy;
+ remote_status_ = busy;
+ // local/remote_status_ won't be set back to idle
+ // unitl Write() is successful
+ mu_.unlock();
+ CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
+ (buffer_size <= size_ && rm.data_type_ == DT_STRING))
+ << "tensor and buffer size do not agree!"
+ << " buffer_size = " << size_
+ << " requested tensor size = " << buffer_size << in.DebugString();
+ uint32_t imm_data = LookupBufferIndex(key);
+ rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
+ string message = RdmaMessage::CreateMessage(rm);
+ memcpy(buffer_, message.data(), message.size());
+ if (!is_dead) {
+ // copy the tensor buffer content
+ void* output =
+ static_cast<void*>(static_cast<char*>(buffer_) +
+ RdmaMessage::kTensorBufferStartIndex);
+ CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
+ proto.SerializeToArray(output, tensor_bytes);
+ } else {
+ buffer_size = RdmaMessage::kMessageTotalBytes;
+ }
+ Write(imm_data, buffer_size);
+ } else {
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ }
+ };
+ // Use default session (legacy_session_)
+ // TODO use WorkerSessionForSession
+ // need to pass in session handle
+ channel_->adapter_->worker_env_->session_mgr->LegacySession()
+ ->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
+ }
+}
+
+// Create a RdmaMessage according to the pre-defined format
+// Args:
+// rm: the message structure
+// Returns:
+// message in string format
+string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
+ // Rdma Message format
+ // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
+ // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
+ // ...| XB | XB | 8B |...
+ //
+ // ACK: type|13|"rx_ack_buffer"
+ // TENSOR_REQUEST: type|name_size|tensor_name|step_id
+ // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
+ // |data_type|tensor_shape|tensor_bytes
+ // BUFFER_IDLE: type|name_size|buffer_name
+ // BUFFER_REQUEST:
+ // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
+ // BUFFER_RESPONSE:
+ // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
+ char message[kMessageTotalBytes];
+ // type
+ message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
+ // size of name
+ memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_));
+ // name
+ memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
+ // buffer_size, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
+ memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
+ sizeof(rm.buffer_size_));
+ memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
+ sizeof(rm.remote_addr_));
+ memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
+ }
+ // step_id
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
+ memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
+ }
+ // is_dead, data_type, tensor_shape, tensor_bytes
+ if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
+
+ memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
+ sizeof(rm.data_type_));
+ memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
+ sizeof(rm.tensor_shape_));
+ memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
+ sizeof(rm.tensor_bytes_));
+ }
+ return string(message, kMessageTotalBytes);
+}
+
+// Parse a RdmaMessage according to the pre-defined format
+// Args:
+// rm: the message structure where the parsed message will be saved
+// buffer: the place where the raw message is stored
+// Returns:
+// None
+void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
+ char* message = static_cast<char*>(buffer);
+ // type
+ rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
+ // name_size_
+ memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_));
+ // name
+ rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
+ // buffer_size, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
+ memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
+ sizeof(rm.buffer_size_));
+ memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
+ sizeof(rm.remote_addr_));
+ memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
+ }
+ // step_id
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
+ memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
+ }
+ // data_type, tensor_bytes, tensor_shape, is_dead
+ if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
+ memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
+ sizeof(rm.data_type_));
+ memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
+ sizeof(rm.tensor_shape_));
+ memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
+ sizeof(rm.tensor_bytes_));
+ }
+}
+
+} // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
new file mode 100644
index 0000000000..ae2aa63e3f
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include <infiniband/verbs.h>
+#include <cstring> // for memset
+#include <functional>
+#include <memory> // for shared_ptr
+#include <queue>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// structure to save the address of remote channels.
+struct RdmaAddress {
+ uint32_t lid;
+ uint32_t qpn;
+ uint32_t psn;
+};
+// structure to save information for remote memory regions.
+struct RemoteMR {
+ uint64_t remote_addr;
+ uint32_t rkey;
+};
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
+enum BufferType { ACK, MESSAGE, TENSOR };
+enum RdmaMessageType {
+ RDMA_MESSAGE_ACK,
+ RDMA_MESSAGE_BUFFER_IDLE,
+ RDMA_MESSAGE_BUFFER_REQUEST,
+ RDMA_MESSAGE_BUFFER_RESPONSE,
+ RDMA_MESSAGE_TENSOR_REQUEST,
+ RDMA_MESSAGE_TENSOR_WRITE
+};
+class RdmaBuffer;
+// Class that represents the Rdma Adapter.
+// Responsible for creation of the completion queue, and handling
+// of work completions.
+class RdmaAdapter {
+ friend class RdmaChannel;
+ friend class RdmaBuffer;
+ friend class RdmaAckBuffer;
+ friend class RdmaMessageBuffer;
+ friend class RdmaTensorBuffer;
+ friend class RdmaMgr;
+ friend class RdmaRemoteRendezvous;
+
+ public:
+ RdmaAdapter(const WorkerEnv* worker_env);
+ ~RdmaAdapter();
+ // Adapter name, e.g. mlx5_0.
+ string name() const;
+ void Process_CQ();
+
+ protected:
+ static const int MAX_CONCURRENT_WRITES = 1000;
+ ibv_context* context_;
+ // ibverbs protection domain
+ ibv_pd* pd_;
+ // Completion event channel, to wait for work completions
+ ibv_comp_channel* event_channel_;
+ // Completion queue, to poll on work completions
+ ibv_cq* cq_;
+ // Pre-allocated work completions array used for polling
+ ibv_wc wc_[MAX_CONCURRENT_WRITES * 2];
+ // worker env for thread
+ const WorkerEnv* worker_env_;
+ // thread for cq.
+ std::unique_ptr<Thread> polling_thread_;
+};
+
+// Class that represents a connection to a remote Rdma peer.
+// Responsible for connecting queue pairs.
+class RdmaChannel {
+ friend class RdmaAdapter;
+ friend class RdmaBuffer;
+ friend class RdmaAckBuffer;
+ friend class RdmaMessageBuffer;
+ friend class RdmaTensorBuffer;
+ friend class RdmaMgr;
+ friend class RdmaRemoteRendezvous;
+
+ public:
+ explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
+ const string remote_name_);
+ ~RdmaChannel();
+ inline const RdmaAddress& self() { return self_; }
+ RdmaAddress address() const;
+ inline const std::vector<RdmaBuffer*>& message_buffers() const {
+ return message_buffers_;
+ }
+ void Connect(const RdmaAddress& remoteAddr);
+ void Connect();
+ void Recv();
+ RdmaBuffer* FindBuffer(const uint32_t index);
+ RdmaBuffer* FindBuffer(const string& name);
+ RdmaBuffer* FindOrCreateBuffer(const string& name,
+ BufferType buffer_type = TENSOR);
+ uint32_t LookupBufferIndex(const string& buffer_name);
+ void SetRemoteAddress(const RdmaAddress& ra, bool override);
+ void InsertRecvCallback(const string& key, std::function<void()> recv_done);
+ void RemoveRecvCallback(const string& key);
+ void RunRecvCallback(const string& key);
+ static const int kNumMessageBuffers = 4;
+
+ protected:
+ const RdmaAdapter* adapter_;
+ RdmaAddress self_;
+ string local_name_;
+ string remote_name_;
+ ibv_qp* qp_;
+ mutex mu_;
+ bool connected_ GUARDED_BY(bt_mu_) = false;
+ RdmaAddress remote_ GUARDED_BY(bt_mu_);
+ bool remote_set_ GUARDED_BY(bt_mu_) = false;
+ mutex ct_mu_;
+ typedef std::unordered_map<string, std::function<void()> > CallbackTable;
+ CallbackTable callback_table_ GUARDED_BY(ct_mu_);
+ mutex bt_mu_;
+ typedef std::unordered_map<unsigned int, RdmaBuffer*> BufferTable;
+ BufferTable buffer_table_ GUARDED_BY(bt_mu_);
+ typedef std::unordered_map<uint32_t, string> BufferIndexNameTable;
+ BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_);
+ typedef std::unordered_map<string, uint32_t> BufferNameIndexTable;
+ BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_);
+ RdmaBuffer* tx_message_buffer_;
+ RdmaBuffer* rx_message_buffer_;
+ RdmaBuffer* tx_ack_buffer_;
+ RdmaBuffer* rx_ack_buffer_;
+ std::vector<RdmaBuffer*> message_buffers_;
+};
+
+// Class that represents a buffer for Rdma writes and reads.
+class RdmaBuffer {
+ friend class RdmaChannel;
+ friend class RdmaAdapter;
+ friend class RdmaMgr;
+ friend class RdmaRemoteRendezvous;
+
+ public:
+ explicit RdmaBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaBuffer();
+
+ inline void* buffer() const { return buffer_; }
+ inline ibv_mr* self() const { return self_; }
+ inline void SetBufferStatus(Location loc, BufferStatus status) {
+ mu_.lock();
+ if (loc == local) {
+ local_status_ = status;
+ } else {
+ remote_status_ = status;
+ }
+ mu_.unlock();
+ }
+ void FreeBuffer();
+ void EnqueueItem(string Item);
+ virtual void SendNextItem(){};
+ void CreateCPUBuffer(size_t size, bool lock = true);
+ void SetRemoteMR(RemoteMR rmi, bool override);
+ uint32_t LookupBufferIndex(const string& buffer_name) {
+ return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
+ }
+ void Write(uint32_t imm_data, size_t buffer_size);
+
+ protected:
+ const RdmaChannel* channel_;
+ void* buffer_ = nullptr;
+ bool buffer_on_host_ = true;
+ size_t size_ = 0;
+ const string name_;
+ ibv_mr* self_ = nullptr;
+ mutex mu_;
+ RemoteMR remote_;
+ std::queue<string> queue_ GUARDED_BY(mu_);
+ BufferStatus local_status_ GUARDED_BY(mu_) = none;
+ BufferStatus remote_status_ GUARDED_BY(mu_) = none;
+};
+
+class RdmaAckBuffer : public RdmaBuffer {
+ public:
+ explicit RdmaAckBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaAckBuffer() override {}
+ void SendNextItem() override;
+};
+
+class RdmaMessageBuffer : public RdmaBuffer {
+ friend class RdmaChannel;
+ friend class RdmaAapater;
+
+ public:
+ explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaMessageBuffer() override {}
+ void SendNextItem() override;
+};
+
+class RdmaTensorBuffer : public RdmaBuffer {
+ public:
+ explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
+ virtual ~RdmaTensorBuffer() override {}
+ void SendNextItem() override;
+};
+
+struct RdmaMessage {
+ RdmaMessageType type_;
+ uint16_t name_size_;
+ string name_;
+ int64 step_id_;
+ uint64_t buffer_size_;
+ uint64_t remote_addr_;
+ uint32_t rkey_;
+ bool is_dead_;
+ DataType data_type_;
+ TensorShape tensor_shape_;
+ size_t tensor_bytes_;
+
+ // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
+ // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
+ // ...| XB | XB | 8B |...
+ //
+ static const size_t kNameCapacity = 512;
+ static const size_t kTypeStartIndex = 0;
+ static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
+ static const size_t kNameStartIndex =
+ kNameSizeStartIndex + sizeof(name_size_);
+ static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
+ static const size_t kBufferSizeStartIndex =
+ kStepIdStartIndex + sizeof(step_id_);
+ static const size_t kRemoteAddrStartIndex =
+ kBufferSizeStartIndex + sizeof(buffer_size_);
+ static const size_t kRkeyStartIndex =
+ kRemoteAddrStartIndex + sizeof(remote_addr_);
+ static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
+ static const size_t kDataTypeStartIndex =
+ kIsDeadStartIndex + sizeof(is_dead_);
+ static const size_t kTensorShapeStartIndex =
+ kDataTypeStartIndex + sizeof(data_type_);
+ static const size_t kTensorBytesStartIndex =
+ kTensorShapeStartIndex + sizeof(TensorShape);
+ static const size_t kTensorBufferStartIndex =
+ kTensorBytesStartIndex + sizeof(tensor_bytes_);
+ static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
+ static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
+ static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
+ static string CreateMessage(const RdmaMessage& rm);
+ static void ParseMessage(RdmaMessage& rm, void* buffer);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
new file mode 100644
index 0000000000..e28b80c6f6
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include <vector>
+#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
+ GrpcChannelCache* const channel_cache)
+ : worker_env_(worker_env), channel_cache_(channel_cache) {
+ rdma_adapter_ = new RdmaAdapter(worker_env_);
+ // hardcoded to default session (legacy_session_)
+ // TODO: use WorkerSessionForSession
+ // need to pass in session handle
+ local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
+ std::vector<string> workers;
+ worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
+ &workers);
+ num_remote_workers_ = workers.size() - 1;
+ VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
+ for (size_t i = 0; i < workers.size(); i++) {
+ if (local_worker_.compare(workers[i]) != 0) {
+ channel_table_.insert(
+ {workers[i],
+ new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
+ }
+ }
+}
+
+// Setup Rdma channels between peers.
+// This is done at the beginning of the server setup.
+
+void RdmaMgr::SetupChannels() {
+ for (const auto& p : channel_table_) {
+ string worker_name = p.first;
+ LOG(INFO) << "connecting to remote node " << worker_name;
+ RdmaChannel* rc = p.second;
+ GetRemoteAddressRequest req;
+ GetRemoteAddressResponse resp;
+ // get the channel cache
+ SharedGrpcChannelPtr client_channel =
+ channel_cache_->FindWorkerChannel(worker_name);
+ GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
+ CHECK(client != nullptr) << "No worker known as " << worker_name;
+
+ // setting up request
+ req.set_host_name(local_worker_);
+ Channel* channel_info = req.mutable_channel();
+ channel_info->set_lid(rc->self_.lid);
+ channel_info->set_qpn(rc->self_.qpn);
+ channel_info->set_psn(rc->self_.psn);
+ for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
+ MemoryRegion* mr = req.add_mr();
+ mr->set_remote_addr(
+ reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
+ mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
+ }
+ // synchronous call
+ Status s = client->GetRemoteAddress(&req, &resp);
+ // save obtained remote addresses
+ // connect to the remote channel
+ if (s.ok()) {
+ CHECK(worker_name.compare(resp.host_name()) == 0);
+ RdmaAddress ra;
+ ra.lid = resp.channel().lid();
+ ra.qpn = resp.channel().qpn();
+ ra.psn = resp.channel().psn();
+ rc->SetRemoteAddress(ra, false);
+ rc->Connect();
+ int i = 0;
+ int idx[] = {1, 0, 3, 2};
+ for (const auto& mr : resp.mr()) {
+ // the connections are crossed, i.e.
+ // local tx_message_buffer <---> remote rx_message_buffer_
+ // local rx_message_buffer <---> remote tx_message_buffer_
+ // local tx_ack_buffer <---> remote rx_ack_buffer_
+ // local rx_ack_buffer <---> remote tx_ack_buffer_
+ // hence idx[] = {1, 0, 3, 2}.
+ RdmaBuffer* rb = rc->message_buffers_[idx[i]];
+ RemoteMR rmr;
+ rmr.remote_addr = mr.remote_addr();
+ rmr.rkey = mr.rkey();
+ rb->SetRemoteMR(rmr, false);
+ i++;
+ }
+ CHECK(i == RdmaChannel::kNumMessageBuffers);
+ } else {
+ LOG(ERROR) << s.error_message();
+ }
+ delete client;
+ }
+}
+
+RdmaMgr::~RdmaMgr() {
+ for (const auto& p : channel_table_) delete p.second;
+ channel_table_.clear();
+ delete rdma_adapter_;
+}
+
+// Find a channel via the given name.
+// Args:
+// name: peer name, e.g. worker1
+// Returns
+// channel object that is connected to the named peer.
+RdmaChannel* RdmaMgr::FindChannel(const string& name) {
+ ChannelTable::iterator iter = channel_table_.find(name);
+ CHECK(iter != channel_table_.end());
+ return iter->second;
+}
+
+} // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
new file mode 100644
index 0000000000..b156f64096
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+
+namespace tensorflow {
+
+class RdmaMgr {
+ public:
+ explicit RdmaMgr(const WorkerEnv* const worker_env,
+ GrpcChannelCache* const channel_cache);
+ ~RdmaMgr();
+ RdmaChannel* FindChannel(const string& key);
+ void SetupChannels();
+ const string& local_worker() { return local_worker_; }
+
+ private:
+ string local_worker_;
+ size_t num_remote_workers_;
+ const WorkerEnv* const worker_env_;
+ GrpcChannelCache* const channel_cache_;
+ RdmaAdapter* rdma_adapter_;
+ typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
+ ChannelTable channel_table_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
new file mode 100644
index 0000000000..8cbdfaa943
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -0,0 +1,149 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
+#include <unordered_set>
+#include "tensorflow/contrib/verbs/verbs_util.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
+ public:
+ RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
+ int64 step_id, RdmaMgr* rdma_mgr)
+ : BaseRemoteRendezvous(env, worker_name, step_id, true),
+ rdma_mgr_(rdma_mgr) {}
+
+ protected:
+ void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ private:
+ ~RdmaRemoteRendezvous() override {}
+ RdmaMgr* rdma_mgr_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RdmaRemoteRendezvous);
+};
+
+void RdmaRemoteRendezvous::RecvFromRemoteAsync(
+ const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
+ DoneCallback done) {
+ Status s;
+ // parse src_name and dst_name
+ string src_name, dst_name, unused;
+ if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
+ &unused)) {
+ s = errors::Internal("Could not parse src name.");
+ }
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor{}, false);
+ return;
+ }
+ if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
+ &unused)) {
+ s = errors::Internal("Could not parse dst name.");
+ }
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor{}, false);
+ return;
+ }
+ CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
+ RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
+ string key(std::move(parsed.FullKey().ToString()));
+ string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
+ // insert callback
+ rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
+ recv_args, parsed, done]() {
+ Status s;
+ Device* src_dev;
+ s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), true);
+ return;
+ }
+ Device* dst_dev;
+ s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), true);
+ return;
+ }
+ RdmaBuffer* rb = rc->FindBuffer(key);
+ RdmaMessage rm;
+ CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
+ RdmaMessage::ParseMessage(rm, rb->buffer_);
+ CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
+ Tensor val;
+ if (!rm.is_dead_) {
+ void* input = static_cast<char*>(rb->buffer_) +
+ RdmaMessage::kTensorBufferStartIndex;
+ TensorProto proto;
+ CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+ rb->size_);
+ CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+ << "fail to parse proto from array";
+ s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+ }
+
+ rc->RemoveRecvCallback(key_with_step_id);
+ // create message
+ RdmaMessage br;
+ br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
+ br.name_size_ = key.size();
+ br.name_ = key;
+ string message = RdmaMessage::CreateMessage(br);
+ RdmaBuffer* tb = rc->tx_message_buffer_;
+ tb->EnqueueItem(message);
+ tb->SendNextItem();
+ done(s, Args(), recv_args, val, rm.is_dead_);
+ });
+ // append key to message queue
+ RdmaBuffer* rb = rc->tx_message_buffer_;
+ RdmaMessage rm;
+ rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
+ rm.name_size_ = key.size();
+ rm.name_ = key;
+ rm.step_id_ = step_id_;
+ string message = RdmaMessage::CreateMessage(rm);
+ rb->EnqueueItem(message);
+ rb->SendNextItem();
+}
+
+RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env,
+ const string& worker_name,
+ WorkerCacheInterface* worker_cache)
+ : BaseRendezvousMgr(env, worker_name) {}
+
+BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
+ const WorkerEnv* worker_env,
+ const string& worker_name) {
+ return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_);
+}
+
+} // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
new file mode 100644
index 0000000000..57cd4bf5e4
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received. Each global unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+// fork execution of an graph executor using "rendez" on thread 1;
+// fork execution of another graph executor using "rendez" on thread 2;
+// ...
+// join threads 1 and 2;
+//
+// In the example above, execution in thread 1 and 2 communicates with
+// each other by send/recv operations through the "rend".
+//
+// Tensors sent and recved through rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
+class RdmaRendezvousMgr : public BaseRendezvousMgr {
+ public:
+ explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name,
+ WorkerCacheInterface* worker_cache);
+ void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
+
+ protected:
+ BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
+ const string& worker_name) override;
+
+ private:
+ RdmaMgr* rdma_mgr_;
+ TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
new file mode 100644
index 0000000000..b061c81d2d
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/verbs_server_lib.h"
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+namespace {
+// static utility function
+RendezvousMgrInterface* NewRdmaRendezvousMgr(
+ const WorkerEnv* env, const string& worker_name,
+ WorkerCacheInterface* worker_cache) {
+ return new RdmaRendezvousMgr(env, worker_name, worker_cache);
+}
+
+} // namespace
+
+VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
+ : GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
+
+VerbsServer::~VerbsServer() {
+ TF_CHECK_OK(Stop());
+ TF_CHECK_OK(Join());
+ delete rdma_mgr_;
+ delete verbs_service_;
+ delete channel_cache_;
+}
+
+Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
+ GrpcChannelCache** channel_cache) {
+ string name_prefix =
+ strings::StrCat("/job:", server_def.job_name(), "/replica:0",
+ "/task:", server_def.task_index());
+
+ GrpcChannelSpec channel_spec;
+ TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
+
+ *channel_cache =
+ NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def));
+
+ const string host_port = (*channel_cache)->TranslateTask(name_prefix);
+ int requested_port;
+
+ if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
+ &requested_port)) {
+ return errors::Internal("Could not parse port for local server from \"",
+ (*channel_cache)->TranslateTask(name_prefix),
+ "\".");
+ }
+ if (requested_port != bound_port()) {
+ return errors::InvalidArgument("Requested port ", requested_port,
+ " differs from expected port ",
+ bound_port());
+ }
+
+ return Status::OK();
+}
+
+Status VerbsServer::Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendezvous_mgr_func) {
+ Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
+ {
+ mutex_lock l(mu_);
+ CHECK_EQ(verbs_state_, DISCONNECTED);
+ CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
+ rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
+ // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
+ verbs_service_->SetRdmaMgr(rdma_mgr_);
+ // hardcoded to default session (legacy_session_)
+ // TODO: use WorkerSessionForSession
+ // need to pass in session handle
+ dynamic_cast<RdmaRendezvousMgr*>(
+ worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get())
+ ->SetRdmaMgr(rdma_mgr_);
+ }
+ return s;
+}
+
+Status VerbsServer::Start() {
+ Status s = GrpcServer::Start();
+ {
+ mutex_lock l(mu_);
+ if (verbs_state_ == DISCONNECTED) {
+ // verbs_thread needs to be initiated
+ // before rdma_mgr sets up the rdma channels.
+ verbs_thread_.reset(worker_env()->env->StartThread(
+ ThreadOptions(), "TF_verbs_service",
+ [this] { verbs_service_->HandleRPCsLoop(); }));
+ rdma_mgr_->SetupChannels();
+ verbs_state_ = CONNECTED;
+ }
+ }
+ return s;
+}
+
+Status VerbsServer::Join() {
+ Status s = GrpcServer::Join();
+ {
+ mutex_lock l(mu_);
+ if (verbs_state_ == CONNECTED) {
+ verbs_state_ = DISCONNECTED;
+ verbs_thread_.reset();
+ }
+ }
+ return s;
+}
+
+/* static */
+Status VerbsServer::Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<ServerInterface>* out_server) {
+ std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
+ ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
+ ::grpc::ServerBuilder* builder) {
+ return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
+ };
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
+ *out_server = std::move(ret);
+ return Status::OK();
+}
+
+namespace {
+
+class VerbsServerFactory : public ServerFactory {
+ public:
+ bool AcceptsOptions(const ServerDef& server_def) override {
+ return server_def.protocol() == "grpc+verbs";
+ }
+
+ Status NewServer(const ServerDef& server_def,
+ std::unique_ptr<ServerInterface>* out_server) override {
+ return VerbsServer::Create(server_def, Env::Default(), out_server);
+ }
+};
+
+// Registers a `ServerFactory` for `VerbsServer` instances.
+class VerbsServerRegistrar {
+ public:
+ VerbsServerRegistrar() {
+ gpr_allocation_functions alloc_fns;
+ alloc_fns.malloc_fn = port::Malloc;
+ alloc_fns.realloc_fn = port::Realloc;
+ alloc_fns.free_fn = port::Free;
+ gpr_set_allocation_functions(alloc_fns);
+ ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
+ }
+};
+static VerbsServerRegistrar registrar;
+
+} // namespace
+} // namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.h b/tensorflow/contrib/verbs/verbs_server_lib.h
new file mode 100644
index 0000000000..855380129f
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_server_lib.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+namespace tensorflow {
+
+class VerbsServer : public GrpcServer {
+ protected:
+ VerbsServer(const ServerDef& server_def, Env* env);
+
+ public:
+ static Status Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<ServerInterface>* out_server);
+
+ // Destruction is only supported in the factory method. Clean
+ // shutdown is not currently implemented for this server type.
+ virtual ~VerbsServer() override;
+
+ // Implementations of ServerInterface methods.
+ Status Start() override;
+ Status Join() override;
+
+ protected:
+ Status Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendezvous_mgr_func);
+ Status ChannelCacheFactory(const ServerDef& server_def,
+ GrpcChannelCache** channel_cache);
+
+ private:
+ RdmaMgr* rdma_mgr_;
+
+ // Guards state transitions.
+ mutex mu_;
+
+ enum State { DISCONNECTED, CONNECTED };
+ State verbs_state_ GUARDED_BY(mu_);
+
+ GrpcVerbsService* verbs_service_ = nullptr;
+ std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
+ GrpcChannelCache* channel_cache_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_VERBS
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
diff --git a/tensorflow/contrib/verbs/verbs_service.proto b/tensorflow/contrib/verbs/verbs_service.proto
new file mode 100644
index 0000000000..b985febfb8
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_service.proto
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option java_outer_classname = "VerbsServiceProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.contrib.verbs";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// GRPC Helper messages used to exchange RDMA information.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message Channel {
+ int32 lid = 1;
+ int32 qpn = 2;
+ int32 psn = 3;
+}
+
+message MemoryRegion {
+ uint64 remote_addr = 1;
+ uint32 rkey = 2;
+}
+message GetRemoteAddressRequest {
+ string host_name = 1;
+ Channel channel = 2;
+ repeated MemoryRegion mr = 3;
+}
+
+message GetRemoteAddressResponse {
+ string host_name = 1;
+ Channel channel = 2;
+ repeated MemoryRegion mr = 3;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// VerbsService
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service VerbsService {
+ rpc GetRemoteAddress(GetRemoteAddressRequest)
+ returns (GetRemoteAddressResponse);
+}
diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc
new file mode 100644
index 0000000000..c3350f7958
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_util.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/verbs_util.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+namespace tensorflow {
+
+// static sync wrapper:
+Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead) {
+ Notification n;
+ Status status;
+ GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return status;
+}
+
+// static
+string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
+ return strings::StrCat(key, ";", step_id);
+}
+
+// static
+void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
+ int64& step_id) {
+ StringPiece s(key_with_step_id);
+ // a key (with step_id) has exact 6 parts if split by ";"
+ // part 1: src_device;
+ // part 2: src_incarnation;
+ // part 3: dst_device;
+ // part 4: name;
+ // part 5: frame_iter.frame_id:frame_iter.iter_id
+ // part 6: step_id
+ std::vector<string> parts = str_util::Split(s, ';');
+ CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
+ strings::safe_strto64(parts[5], &step_id);
+ parts.pop_back(); // remove step_id
+ key.assign(str_util::Join(parts, ";")); // stitch them together
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h
new file mode 100644
index 0000000000..cbc01adae4
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_util.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+#define TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class TensorProto;
+
+class VerbsUtil {
+ public:
+ // synchronous wrapper of SetProtoFromGPU
+ static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead);
+ static string AppendStepidToKey(const string& key, int64 step_id);
+ static void GetKeyAndStepId(const string& key_with_step_id, string& key,
+ int64& step_id);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_RDMA_UTIL_H_
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d614349387..71fba99aad 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -108,6 +108,7 @@ load(
"tf_additional_cloud_op_deps",
"tf_additional_cloud_kernel_deps",
"tf_lib_proto_parsing_deps",
+ "tf_additional_verbs_lib_defines",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
@@ -732,9 +733,13 @@ cc_library(
"//tensorflow/core/kernels:math_not_windows",
"//tensorflow/core/kernels:quantized_ops",
]) + if_mkl([
+ "//tensorflow/core/kernels:mkl_concat_op",
"//tensorflow/core/kernels:mkl_conv_op",
+ "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
+ "//tensorflow/core/kernels:mkl_lrn_op",
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
+ "//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
]),
)
@@ -1272,7 +1277,9 @@ cc_library(
"platform/tracing.h",
],
copts = tf_copts(),
- defines = tf_additional_lib_defines() + ["SNAPPY"],
+ defines = tf_additional_lib_defines() + [
+ "SNAPPY",
+ ] + tf_additional_verbs_lib_defines(),
linkopts = select({
"//tensorflow:freebsd": [],
"//conditions:default": ["-ldl"],
@@ -2089,7 +2096,6 @@ tf_cc_test_mkl(
size = "small",
srcs = [
"graph/mkl_layout_pass_test.cc",
- "graph/mkl_optimizer_merge_test.cc",
"graph/mkl_tfconversion_pass_test.cc",
],
linkstatic = tf_kernel_tests_linkstatic(),
@@ -2110,9 +2116,13 @@ tf_cc_test_mkl(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/core/kernels:mkl_concat_op",
"//tensorflow/core/kernels:mkl_conv_op",
+ "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
+ "//tensorflow/core/kernels:mkl_lrn_op",
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
+ "//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 2be510ee9b..a79ea1b45d 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include <vector>
+#if defined(PLATFORM_GOOGLE)
#include "grpc++/create_channel.h"
+#endif
#if defined(PLATFORM_WINDOWS)
// winsock2.h is used in grpc, so Ws2_32.lib is needed
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 1aafa862cb..7160962b16 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -62,6 +62,13 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
plugins) override {}
};
+// static utility function
+RendezvousMgrInterface* NewRpcRendezvousMgr(
+ const WorkerEnv* env, const string& worker_name,
+ WorkerCacheInterface* worker_cache) {
+ return new RpcRendezvousMgr(env, worker_name, worker_cache);
+}
+
} // namespace
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
@@ -93,7 +100,8 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool
}
-Status GrpcServer::Init() {
+Status GrpcServer::Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendevous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@@ -170,6 +178,10 @@ Status GrpcServer::Init() {
worker_impl_ = NewGrpcWorker(&worker_env_);
worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
+ // extra service:
+ if (service_func != nullptr) {
+ service_func(&worker_env_, &builder);
+ }
server_ = builder.BuildAndStart();
if (!server_) {
@@ -182,7 +194,9 @@ Status GrpcServer::Init() {
// Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
- new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache));
+ rendevous_mgr_func == nullptr ?
+ new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
+ rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache),
@@ -211,6 +225,10 @@ Status GrpcServer::Init() {
return Status::OK();
}
+Status GrpcServer::Init() {
+ return Init(nullptr, nullptr);
+}
+
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) {
@@ -248,6 +266,7 @@ Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
+
if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
&requested_port)) {
return errors::Internal("Could not parse port for local server from \"",
@@ -346,7 +365,8 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server) {
std::unique_ptr<GrpcServer> ret(
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
- TF_RETURN_IF_ERROR(ret->Init());
+ ServiceInitFunction service_func = nullptr;
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
*out_server = std::move(ret);
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index c6ba260104..3b66291a9a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -36,6 +36,17 @@ namespace tensorflow {
class GrpcWorker;
class Master;
+// function that creates a RendezvousMgr.
+typedef std::function<RendezvousMgrInterface*(
+ const WorkerEnv*, const std::string& worker_name,
+ WorkerCacheInterface* worker_cache)>
+ RendezvousMgrCreationFunction;
+
+// function that registers a service to the server. The service needs to
+// be registered before builder.BuildAndStart().
+typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
+ ServiceInitFunction;
+
class GrpcServer : public ServerInterface {
protected:
GrpcServer(const ServerDef& server_def, Env* env);
@@ -55,6 +66,9 @@ class GrpcServer : public ServerInterface {
const string target() const override;
protected:
+ Status Init(ServiceInitFunction service_func,
+ RendezvousMgrCreationFunction rendezvous_mgr_func);
+
Status Init();
// A subclass can override this method to support secure credentials.
@@ -78,6 +92,10 @@ class GrpcServer : public ServerInterface {
// This method may only be called after `this->Init()` returns successfully.
int bound_port() const { return bound_port_; }
+ WorkerEnv* worker_env() { return &worker_env_; }
+
+ const ServerDef& server_def() const { return server_def_; }
+
private:
// The overall server configuration.
const ServerDef server_def_;
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index fb1ad0102f..e45f156e1e 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -126,25 +126,33 @@ FunctionDef XTimes16() {
{{"y", "y:y:0"}});
}
-FunctionDef WXPlusB() {
- return FDH::Define(
- // Name
- "WXPlusB",
- // Args
- {"w: T", "x: T", "b: T"},
- // Return values
- {"y: T"},
- // Attr def
- {"T: {float, double}"},
- // Nodes
- {{{"mm"},
- "MatMul",
- {"w", "x"},
- {{"T", "$T"},
- {"transpose_a", false},
- {"transpose_b", false},
+FunctionDef WXPlusB(){return FDH::Define(
+ // Name
+ "WXPlusB",
+ // Args
+ {"w: T", "x: T", "b: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double}"},
+ // Nodes
+ {
+ {{"mm"},
+ "MatMul",
+ {"w", "x"},
+ {
+ {"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
+#ifdef INTEL_MKL
+ }},
+#else
{"_kernel", "eigen"}}},
- {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
+#endif
+ {
+ {"y"}, "Add", {"mm", "b"}, {
+ { "T", "$T" }
+ }
+ }
+ });
}
FunctionDef Swap() {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 309c4cd774..09b632a165 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -48,7 +48,7 @@ namespace tensorflow {
// 1) Propagating Mkl layout as an additional output tensor
// (we will loosely call a tensor that carries Mkl layout as Mkl tensor
// henceforth.) from every Mkl supported NN layer.
-// 2) Context-based rewrite: This is neded in order to optimize
+// 2) Context-based rewrite: This is needed in order to optimize
// gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and
// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
@@ -63,12 +63,12 @@ namespace tensorflow {
// P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
-// P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
+// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
-// Meaning of A_m, B_m and C_m is explained in B.1.
+// The meaning of A_m, B_m and C_m is explained in B.1.
//
// Merge rules:
-// - Merge for Conv2D and BiasAdd happens only when output of Conv2D _only_
+// - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
// goes to BiasAdd.
// - Also, the intersection of attributes of both the nodes must have same
// values.
@@ -76,7 +76,7 @@ namespace tensorflow {
//
// Example of B.1 : Rewriting nodes to Mkl nodes
// ---------------------------------------------
-// Consider Relu layer. Current definition of Relu layer looks like:
+// Consider a Relu node. Current definition of Relu node looks like:
//
// O = Relu(A)
//
@@ -87,58 +87,59 @@ namespace tensorflow {
//
// O, O_m = MklRelu(A, A_m)
//
-// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here A input is
-// same as A input of Relu; O output is same as O output of Relu. O_m is the
+// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
+// same as input A of Relu; output O is same as output O of Relu. O_m is the
// additional output tensor that will be set by MklRelu, and it represents
// Mkl tensor corresponding to O -- in other words, O_m is some kind of
// metadata for O. A_m is additional input of Relu, and it represents metadata
// for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives
-// this metadata from previous layer (in the graph).
+// this metadata from previous node in the graph.
//
-// When previous layer in the graph is Mkl layer, A_m will represent a valid
-// Mkl tensor. But when previous Mkl layer is not an Mkl layer, then A_m
-// represents a dummy Mkl tensor.
+// When a previous node in the graph is an Mkl node, A_m will represent a valid
+// Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
+// a dummy Mkl tensor.
//
// Rewriting rules:
-// - Selection of an op for rewriting happens by registering an op with this
-// pass. If an op is not registered, then it is not rewritten.
+// - Selection of a node for rewriting happens by registering the op type of
+// the node with the rewriting pass. If the op type is not registered, then
+// all nodes of this op type will not be rewritten.
// - Number of inputs after rewriting:
-// Since for every input Tensorflow tensor, the rewritten layer gets Mkl
-// tensor, rewritten op gets 2*N inputs, where N is the number of inputs
-// for original op.
+// Since for every input Tensorflow tensor, the rewritten node gets Mkl
+// tensor(s), rewritten node gets 2*N inputs, where N is the number of
+// inputs for the original node.
// - Number of outputs after rewriting:
-// Since for every output Tensorflow tensor, the rewritten layer generates
-// Mkl tensor, rewritten op generates 2*N outputs, where N is the number
-// of outputs of original op.
+// Since for every output Tensorflow tensor, the rewritten node generates
+// Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
+// number of outputs of the original node.
// - Ordering of Tensorflow tensors and Mkl tensors:
-// Since every op generates twice the number of inputs and outputs, one
-// could imagine different ordering among Tensorflow tensors and Mkl
-// tensors. E.g., let's assume an op 'Conv2D' takes (A, B) as input, then
-// new op 'MklConv2D' can take (A, A_m, B, B_m) as input or it can also
-// take (A, B, A_m, B_m) as input. Among N inputs one can get N!
-// permutations.
+// Since every rewritten node generates twice the number of inputs and
+// outputs, one could imagine various orderings among Tensorflow tensors
+// and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
+// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
+// in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
+// order. Among N inputs one can get N! permutations.
//
-// So the question is: which one do we follow? Currently, we follow an
-// intuitive order where Mkl tensor follows a corresponding Tensorflow
-// tensor immediately. In the context of above example, it will be: (A,
-// A_m, B, B_m). We follow same ordering rule for output tensors.
-//
-// NOTE: Current rewriting approach rewrites an op to Mkl op without any
-// conditions. But in the future, it may be possible to consider
-// conditions such as input shapes and sizes to rewrite an op.
+// So the question is: which order do we follow? We support 2 types of
+// orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
+// follows an intuitive order where an Mkl tensor follows the
+// corresponding Tensorflow tensor immediately. In the context of the
+// above example, it will be: A, A_m, B, B_m. Note that the ordering rule
+// applies to both the inputs and outputs. Contiguous ordering means
+// all the Tensorflow tensors are contiguous followed by all the Mkl
+// tensors. We use contiguous ordering as default.
//
// Graph rewrite algorithm:
// Algorithm: Graph Rewrite
-// Input: Graph G, Names of nodes to rewrite and their new nodes
-// Output: Modified Graph G' if nodes are modified, G otherwise.
+// Input: Graph G, Names of the nodes to rewrite and their new names
+// Output: Modified Graph G' if the nodes are modified, G otherwise.
// Start:
-// N = Topological_Sort(G) // N is set of nodes in toposort order.
+// N = Topological_Sort(G) // N is a set of nodes in toposort order.
// foreach node n in N
// do
-// if (Is_MKL_Layer(n)) // Can this layer accept Mkl layout as input.
+// if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input.
// then
// E = set of <incoming edge and its src_output slot> of n
-// E' = {} // new set of edges for rewritten node
+// E' = {} // a new set of edges for rewritten node
// foreach <e,s> in E
// do
// E' U {<e,s>} // First copy edge which generates Tensorflow
@@ -146,42 +147,44 @@ namespace tensorflow {
// m = Source node of edge e
// if Is_Rewritten(m) // Did we rewrite this node in this pass?
// then
-// E' U {<m,s+1>} // If yes, then m will generate Mkl tensor
-// // as output.
+// E' U {<m,s+1>} // If yes, then m will generate an Mkl
+// // tensor as an additional output.
// else
-// d = Generate_Dummy_Mkl_Tensor() // If not, generate dummy
+// d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy
// // Mkl tensor.
-// E' U {<d,0>} // Dummy Mkl tensor has only 1 output slot.
+// E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot.
// fi
// done
// n' = Build_New_Node(G,new_name,E')
-// Mark_Rewritten(n') // Mark new node as being rewritten.
+// Mark_Rewritten(n') // Mark the new node as being rewritten.
// fi
// done
//
// Explanation:
-// For graph rewrite, we visit nodes of the graph in the topological
-// sort order. With this ordering, we visit nodes in top-to-bottom
-// fashion. We need this order because while visiting a node we want
-// all of its input nodes (parents) visited (and rewritten if
-// applicable). This is because if we need to rewrite a current node
+// For graph rewrite, we visit nodes of the input graph in the
+// topological sort order. With this ordering, we visit nodes in the
+// top-to-bottom fashion. We need this order because while visiting a
+// node we want that all of its input nodes are visited and rewritten if
+// applicable. This is because if we need to rewrite a given node
// then all of its input nodes need to be fixed (in other words they
-// cannot be removed later.)
+// cannot be deleted later.)
//
-// While visiting each node, we first check if it is Mkl layer. If
-// it is, then we rewrite that node after constructing new inputs to
-// the node. If it is not Mkl layer, then we do not rewrite the node.
+// While visiting a node, we first check if the op type of the node is
+// an Mkl op. If it is, then we rewrite that node after constructing
+// new inputs to the node. If the op type of the node is not Mkl op,
+// then we do not rewrite that node.
//
// Handling workspace propagation for certain ops:
//
// Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
-// passing of workspace from their corresponding forward ops. But
-// TensorFlow does not have a notion of workspace and as a result
-// does not allow producing additional outputs from these forward ops.
-// For these ops, we need to add an additional edge between forward
-// ops and their corresponding backward ops, and this edge carries
-// workspace tensor value and another edge carries Mkl tensor for
-// workspace tensor.
+// passing of a workspace from their respective forward ops. Workspace
+// tensors provide memory for storing results of intermediate operations
+// which are helpful in backward propagation. TensorFlow does not have
+// a notion of a workspace and as a result does not allow producing
+// additional outputs from these forward ops. For these ops, we need
+// to add 2 extra edges between forward ops and their corresponding
+// backward ops - the first extra edge carries a workspace tensor and
+// the second one carries an Mkl tensor for the workspace tensor.
//
// Example:
//
@@ -190,59 +193,61 @@ namespace tensorflow {
// A = MaxPool(T)
// B = MaxPoolGrad(X, A, Y)
//
-// We will transform this graph to propagate workspace as:
+// We will transform this graph to propagate the workspace as:
+// (with the contiguous ordering)
//
-// A, A_m, W, W_m = MklMaxPool(T, T_m)
-// B, B_m = MklMaxPoolGrad(X, X_m, A, A_m, Y, Y_m, W, W_m)
+// A, W, A_m, W_m = MklMaxPool(T, T_m)
+// B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
-// Here W is the workspace tensor. Transformed tensors with name
-// suffix _m are Mkl tensors and this transformation has been done
+// Here W is the workspace tensor. Transformed tensor names with the
+// suffix _m are Mkl tensors, and this transformation has been done
// using the algorithm discussed earlier. The transformation for
-// workspace only adds extra outputs (W, W_m) for forward op and
-// connects them to corresponding backward ops.
+// workspace propagation only adds extra outputs (W, W_m) for a forward
+// op and connects them to the corresponding backward ops.
//
// Terms:
//
// Forward op name = name of the op in the forward pass
-// where workspace originates (MaxPool in this example)
+// where a workspace tensor originates (MaxPool in this example)
// Backward op name = name of the op in the backward pass that receives
-// workspace from forward op (MaxPoolGrad in the example)
-// Slot = Number of the output or input slot that will be
-// used by the workspace (2 for MklMaxPool as W is 3rd
-// output of MaxPool (0 is 1st); 6 for MklMaxPoolGrad)
+// a workspace tensor from the forward op (MaxPoolGrad in the example)
+// Slot = Position of the output or input slot that will be
+// used by the workspace tensor (1 for MklMaxPool as W is the 2nd
+// output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
//
// Question:
//
-// How do we associate backward op to forward op? There can be more
-// than one op with exact same name.
+// How do we associate a backward op to a forward op? There can be more
+// than one op with the exact same name.
//
-// In this example we associate MaxPoolGrad with MaxPool. But there
+// In this example, we associate MaxPoolGrad with MaxPool. But there
// could be more than one MaxPool ops. To solve this problem, we look
-// for _direct_ edge between forward op and backward op (tensor A is
-// flowing along this edge in the example.)
+// for _direct_ edge between a forward op and a backward op (tensor A is
+// flowing along this edge in the example).
//
-// How do we transform forward and backward op when there is no direct
-// edge between them? In such case, we generate dummy tensors as
+// How do we transform forward and backward ops when there is no direct
+// edge between them? In such a case, we generate dummy tensors for
// workspace tensors. For the example, transformation of MaxPool will
-// be exactly same --- it is just that MaxPool won't generate any
-// workspace tensor. For MaxPoolGrad, transformation will also be same,
-// but instead of connecting W and W_m with outputs of MaxPool, we will
-// produce dummy tensors for them, and we will set workspace_enabled
-// attribute to false.
+// be exactly same as it would be when there is a direct edge between
+// the forward and the backward op --- it is just that MaxPool won't
+// generate any workspace tensor. For MaxPoolGrad, the transformation
+// will also be same, but instead of connecting W and W_m with the
+// outputs of MaxPool, we will produce dummy tensors for them, and we
+// will set workspace_enabled attribute to false.
//
// Example of B.2 : Context-based node rewrite
// -------------------------------------------
// Consider BiasAddGrad op as:
//
-// O = MklConv2D(A, A_m, B, B_m, C, C_m)
+// O = _MklConv2D(A, B, C, A_m, B_m, C_m)
// P = BiasAddGrad(O)
//
-// Then we rewrite is as:
+// Then we rewrite it as:
//
// P = Conv2DWithBiasBackpropBias(O, O_m)
//
-// 'Distance' between input of BiasAddGrad and MklConv2D in terms of hops is
-// the context matching depth. If MklConv2DWithBias is not within the context
+// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
+// the context matching depth. If _MklConv2DWithBias is not within the context
// matching depth, then we do not rewrite BiasAddGrad.
// How many hops do we search for matching node in the backward dataflow graph?
@@ -255,53 +260,85 @@ static size_t kNodeMergeContextMaxDepth = 10;
class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
+ // NOTE: names are alphabetically sorted.
+ csinfo_.avg_pool = "AvgPool";
+ csinfo_.avg_pool_grad = "AvgPoolGrad";
+ csinfo_.bias_add = "BiasAdd";
+ csinfo_.bias_add_grad = "BiasAddGrad";
+ csinfo_.concat = "Concat";
+ csinfo_.concatv2 = "ConcatV2";
csinfo_.conv2d = "Conv2D";
- csinfo_.mklconv2d = "MklConv2D";
- csinfo_.mklconv2dwithbias = "MklConv2DWithBias";
- csinfo_.mklconv2dwithbiasbackpropbias = "MklConv2DWithBiasBackpropBias";
- csinfo_.biasadd = "BiasAdd";
+ csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
+ csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
+ csinfo_.fused_batch_norm = "FusedBatchNorm";
+ csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+ csinfo_.lrn = "LRN";
+ csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
- csinfo_.biasaddgrad = "BiasAddGrad";
+ csinfo_.max_pool = "MaxPool";
+ csinfo_.max_pool_grad = "MaxPoolGrad";
+ csinfo_.mkl_conv2d = "_MklConv2D";
+ csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
+ csinfo_.mkl_conv2d_with_bias_backprop_bias =
+ "_MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
- csinfo_.relugrad = "ReluGrad";
- csinfo_.maxpool = "MaxPool";
- csinfo_.maxpoolgrad = "MaxPoolGrad";
- csinfo_.avgpool = "AvgPool";
- csinfo_.avgpoolgrad = "AvgPoolGrad";
- csinfo_.conv2dgradinput = "Conv2DBackpropInput";
- csinfo_.conv2dgradfilter = "Conv2DBackpropFilter";
-
- rinfo_.push_back(
- {csinfo_.conv2d, csinfo_.mklconv2d, 2, CopyAttrsConv2D, AlwaysRewrite});
- rinfo_.push_back({csinfo_.conv2dgradfilter,
- GetMklOpName(csinfo_.conv2dgradfilter), 3,
+ csinfo_.reshape = "Reshape";
+ csinfo_.relu_grad = "ReluGrad";
+ csinfo_.split = "Split";
+
+ // NOTE: names are alphabetically sorted.
+ rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool_grad,
+ GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
+ AlwaysRewrite});
+ rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
+ CopyAttrsConcat, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
+ CopyAttrsConcatV2, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
+ CopyAttrsConv2D, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv2d_grad_filter,
+ GetMklOpName(csinfo_.conv2d_grad_filter), 3,
CopyAttrsConv2D, AlwaysRewrite});
- rinfo_.push_back({csinfo_.conv2dgradinput,
- GetMklOpName(csinfo_.conv2dgradinput), 3, CopyAttrsConv2D,
+ rinfo_.push_back({csinfo_.conv2d_grad_input,
+ GetMklOpName(csinfo_.conv2d_grad_input), 3,
+ CopyAttrsConv2D, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.fused_batch_norm,
+ GetMklOpName(csinfo_.fused_batch_norm), 5,
+ CopyAttrsFusedBatchNorm, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.fused_batch_norm_grad,
+ GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
+ CopyAttrsFusedBatchNorm, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
+ AlwaysRewrite});
+ rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
+ CopyAttrsLRN, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.max_pool_grad,
+ GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
CopyAttrsRelu, AlwaysRewrite});
- rinfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool), 1,
- CopyAttrsPooling, AlwaysRewrite});
- rinfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad), 3,
- CopyAttrsPooling, AlwaysRewrite});
- rinfo_.push_back({csinfo_.avgpool, GetMklOpName(csinfo_.avgpool), 1,
- CopyAttrsPooling, AlwaysRewrite});
- rinfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad), 2,
- CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
+ CopyAttrsReshape, AlwaysRewrite});
+
+ // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet.
// Add info about which ops to add workspace edge to and the slots.
- wsinfo_.push_back({csinfo_.maxpool, csinfo_.maxpoolgrad, 0, 1, 2, 6});
+ wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
+ wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
- minfo_.push_back(
- {csinfo_.mklconv2d, csinfo_.biasadd, 0, csinfo_.mklconv2dwithbias});
+ minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
+ csinfo_.mkl_conv2d_with_bias});
// We use maxhop of 10 based on empirical observations. Also, these are
// maxhops in backward data-flow graph. Since input of forward nodes
// (Conv2D) directly goes to backward nodes, we do not expect the
// hop-distance would be more than few nodes.
- cinfo_.push_back({csinfo_.biasaddgrad, csinfo_.mklconv2dwithbias,
+ cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
kNodeMergeContextMaxDepth});
}
@@ -318,73 +355,80 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
bool RunPass(std::unique_ptr<Graph>* g);
private:
- /// Structure to specify name of original op, its new name after rewrite,
- /// the number of inputs to the original op, and the function to be used
- /// to copy attributes for the op
+ /// Structure to specify the name of an original node, its new name after
+ /// rewrite, the number of inputs to the original node, the function to
+ /// be used to copy attributes for the op, and the rule (if any) which
+ /// must hold for rewriting the node
typedef struct {
- string name; // Original name of the op in the graph
- string newname; // New name of op in the graph
- int numins; // Number of inputs to the original op
- // Function handler to copy attributes from old node to new node.
- std::function<void(const Node*, NodeBuilder*)> copyattrs;
- std::function<bool(const Node*)> rewriterule; // Rule under which to
- // rewrite this node.
+ string name; // Original name of op of the node in the graph
+ string new_name; // New name of the op of the node in the graph
+ int num_ins; // The number of inputs to the original op type
+ // A function handler to copy attributes from an old node to a new node.
+ std::function<void(const Node*, NodeBuilder*)> copy_attrs;
+ std::function<bool(const Node*)> rewrite_rule; // A rule under which to
+ // rewrite this node.
} RewriteInfo;
- /// Structure to specify forward op, backward op, and the slot numbers
- /// in forward and backward op where we will add workspace edge.
+ /// Structure to specify a forward op, a backward op, and the slot numbers
+ /// in the forward and backward ops where we will add a workspace edge.
typedef struct {
- string fwdop; // Name of the forward op in the graph
- string bwdop; // Name of the backward op in the graph
- int fwdslot; // Output slot in the forward op node where actual
- // output tensor resides
- int bwdslot; // Input slot in the backward op node where actual
- // input tensor resides
- int wsfwdslot; // Output slot in the forward op node where workspace
- // edge is added
- int wsbwdslot; // Input slot in the backward op node where workspace
- // edge is added
+ string fwd_op; // Name of a forward op in the graph
+ string bwd_op; // Name of a backward op in the graph
+ int fwd_slot; // Output slot in the forward op node where actual
+ // output tensor resides
+ int bwd_slot; // Input slot in the backward op node where actual
+ // input tensor resides
+ int ws_fwd_slot; // Output slot in the forward op node where workspace
+ // edge is added
+ int ws_bwd_slot; // Input slot in the backward op node where workspace
+ // edge is added
} WorkSpaceInfo;
/// Structure to specify information used in node merge
typedef struct {
- string pred; // Predecessor node string
- string succ; // Successor node string
- int op; // What operand no the predecessor node corresponds
- // to successor node?
- string newnode; // Name of the node after merge
+ string pred; // Predecessor node string
+ string succ; // Successor node string
+ int op; // The operand no the predecessor node corresponds
+ // to the successor node
+ string new_node; // Name of the node after merge
} MergeInfo;
- /// Structure to specify the context information used in node rewrite rule
+ /// Structure to specify the context information used in a node rewrite rule
typedef struct {
- string node; // Name of the node to be rewritten
- string fwd; // Node name in forward pass that this node
- // corresponds to
- size_t maxhop; // Maximum number of hops the fwd is located
- // from this node. If fwd is farther than maxhop
- // then we do not rewrite the node.
+ string node; // Name of the node to be rewritten
+ string fwd; // Name of the node in the forward pass that this node
+ // corresponds to
+ size_t max_hop; // Maximum number of hops the fwd is located
+ // from this node. If the fwd is farther than max_hop
+ // then we do not rewrite the node.
} ContextInfo;
/// Structure to store all constant strings
+ /// NOTE: names are alphabetically sorted.
struct {
- string relu;
- string relugrad;
- // Conv ops
+ string avg_pool;
+ string avg_pool_grad;
+ string bias_add;
+ string bias_add_grad;
+ string concat;
+ string concatv2;
string conv2d;
- string mklconv2d;
- string conv2dgradinput;
- string conv2dgradfilter;
- string mklconv2dwithbias;
- string mklconv2dwithbiasbackpropbias;
- // Pooling ops
- string maxpool;
- string maxpoolgrad;
- string avgpool;
- string avgpoolgrad;
- // Others
- string biasadd;
+ string conv2d_grad_input;
+ string conv2d_grad_filter;
+ string fused_batch_norm;
+ string fused_batch_norm_grad;
+ string lrn;
+ string lrn_grad;
string matmul;
- string biasaddgrad;
+ string max_pool;
+ string max_pool_grad;
+ string mkl_conv2d;
+ string mkl_conv2d_with_bias;
+ string mkl_conv2d_with_bias_backprop_bias;
+ string relu;
+ string relu_grad;
+ string split;
+ string reshape;
} csinfo_;
/// Maintain info about nodes to rewrite
@@ -393,7 +437,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Maintain info about nodes to add workspace edge
std::vector<WorkSpaceInfo> wsinfo_;
- /// Maintain info to be merged
+ /// Maintain info about nodes to be merged
std::vector<MergeInfo> minfo_;
/// Maintain info about nodes to rewrite
@@ -403,7 +447,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::unordered_set<const Node*> visited_nodes_;
private:
- // Predicate to check if we rewrote node 'n'
+ // Check if we rewrote node 'n'
//
// If we rewrote the node, then the rewritten node will produce
// Mkl tensor as output. If we did not rewrite the node, then
@@ -420,12 +464,49 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Clear all visited nodes
inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
+ // Is this a graph node that can accept variable number of inputs?
+ // Return true if yes, false otherwise.
+ //
+ // Concat, Split are vararg nodes.
+ inline bool IsVarArgNode(Node* n) {
+ if (n->type_string() == csinfo_.concat ||
+ n->type_string() == csinfo_.concatv2 ||
+ n->type_string() == csinfo_.split) {
+ return true;
+ }
+ return false;
+ }
+
+ // Is OpDef::ArgDef a list type? It could be N * T or list(type).
+ // Refer to opdef.proto for details of list type.
+ inline bool ArgIsList(const OpDef::ArgDef& arg) const {
+ return !arg.type_list_attr().empty() || !arg.number_attr().empty();
+ }
+
+ // Get length of a list in 'n' if 'arg' is of list type. Refer to
+ // description of ArgIsList for definition of list type.
+ inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
+ CHECK_EQ(ArgIsList(arg), true);
+ int N = 0;
+ const string attr_name = !arg.type_list_attr().empty()
+ ? arg.type_list_attr()
+ : arg.number_attr();
+ if (!arg.type_list_attr().empty()) {
+ std::vector<DataType> value;
+ TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
+ N = value.size();
+ } else {
+ TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
+ }
+ return N;
+ }
+
// Get the name of Mkl op from original TensorFlow op
// We prefix 'Mkl' to the original op to get Mkl op.
// TODO(nhasabni) We should move this to mkl_util.h.
inline string GetMklOpName(const string& name) const {
// Prefix that we add to Tensorflow op name to construct Mkl op name.
- const char* const kMklOpPrefix = "Mkl";
+ const char* const kMklOpPrefix = "_Mkl";
return string(kMklOpPrefix) + name;
}
@@ -440,7 +521,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// Input nodes succ and pred may be deleted if the call to
// this function is successful. Attempt to use the pointers
- // after the call to function may result is undefined behaviors.
+ // after the call to function may result in undefined behaviors.
//
// @input g - input graph, succ - successor node, pred - predecessor node
// @return Status::OK(), if merging is successful and supported.
@@ -470,13 +551,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// gradient op in the backward direction.
//
// @input n - Node (gradient op) whose contextinfo is to be searched,
- // fwdn - pointer to node from the forward pass that this node
- // belongs to. fwdn cannot be NULL.
+ // fwd_node - pointer to node from the forward pass that this node
+ // belongs to. fwd_node cannot be NULL.
// @return Matching contextinfo in case a match is found; null otherwise.
- // Also updates *fwdn with pointer to forward node that this context
- // matches.
+ // Also updates *fwd_node with pointer to forward node that this
+ // context matches.
static const ContextInfo* SearchMatchingContext(const Node* n,
- const Node** fwdn);
+ const Node** fwd_node);
// Rewrites input node to a new node specified by its matching rewrite info.
//
@@ -494,46 +575,132 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Otherwise, it is not updated.
Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
+ // Get nodes that will feed a list of TF tensors to the new
+ // node that we are constructing.
+ //
+ // @input g - input graph,
+ // @input inputs - inputs to old node that we are using for constructing
+ // new inputs,
+ // @input input_idx - the index in the 'inputs' vector pointing to the
+ // current input that we have processed so far
+ // @output input_idx - index will be incremented by the number of nodes
+ // from 'inputs' that are processed
+ // @input list_length - The expected length of list of TF tensors
+ // @output output_nodes - the list of new nodes creating TF tensors
+ //
+ // @return None
+ void GetNodesProducingTFTensorList(
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+ int* input_idx, int list_length,
+ std::vector<NodeBuilder::NodeOut>* output_nodes);
+
+ // Get nodes that will feed a list of Mkl tensors to the new
+ // node that we are constructing.
+ //
+ // @input g - input graph,
+ // @input inputs - inputs to old node that we are using for constructing
+ // new inputs,
+ // @input input_idx - the index in the 'inputs' vector pointing to the
+ // current input that we have processed so far
+ // @output input_idx - index will be incremented by the number of nodes
+ // from 'inputs' that are processed
+ // @input list_length - The expected length of list of Mkl tensors
+ // @output output_nodes - the list of new nodes creating Mkl tensors
+ //
+ // @return None
+ void GetNodesProducingMklTensorList(
+ std::unique_ptr<Graph>* g,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+ int* input_idx, int list_length,
+ std::vector<NodeBuilder::NodeOut>* output_nodes);
+
+ // Get a node that will feed an Mkl tensor to the new
+ // node that we are constructing. The output node could be (1) 'n'
+ // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
+ // if 'n' is not an Mkl layer.
+ //
+ // @input g - input graph,
+ // @input n - Node based on which we are creating Mkl node,
+ // @input n_output_slot - the output slot of node 'n'
+ // which is feeding to the node that we are constructing
+ // @output mkl_node - the new node that will feed Mkl tensor
+ // @output mkl_node_output_slot - the slot number of mkl_node that
+ // will feed the tensor
+ // @return None
+ void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
+ int n_output_slot, Node** mkl_node,
+ int* mkl_node_output_slot);
+
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
- // in graph 'g'. Original node is input in 'orign'.
+ // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
+ // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
+ // producing workspace edges if 'are_workspace_tensors_available' is true.
+ // Otherwise, 'workspace_tensors' is empty vector.
//
- // For details, refer to 'Number of inputs after rewriting' section in the
+ // For details, refer to 'Ordering of inputs after rewriting' section in the
// documentation above.
//
// Returns Status::OK() if setting up inputs is successful, otherwise
// returns appropriate status code.
+ int SetUpContiguousInputs(
+ std::unique_ptr<Graph>* g,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+ NodeBuilder* nb, Node* old_node,
+ std::vector<NodeBuilder::NodeOut>* workspace_tensors,
+ bool are_workspace_tensors_available);
+
+ // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
+ // in graph 'g'. Original node is input in 'orig_node'.
+ //
+ // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
+ // section in the documentation above.
+ //
+ // Returns Status::OK() if setting up inputs is successful, otherwise
+ // returns appropriate status code.
Status SetUpInputs(std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- NodeBuilder* nb, Node* orign);
-
- // Add workspace edge on the input or output side of Node 'orign' by using
- // NodeBuilder 'nb' for the new node provided. If 'orign' does not dictate
- // adding workspace edge then do not add it.
- void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orign,
- NodeBuilder* nb);
+ NodeBuilder* nb, Node* orig_node);
+
+ // Add workspace edge on the input or output side of Node 'orig_node' by using
+ // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
+ // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
+ // tensors, if they need to be added, will be set into these tensors.
+ // If we set workspace tensors, then are_ws_tensors_added should be true.
+ void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
+ NodeBuilder* nb,
+ std::vector<NodeBuilder::NodeOut>* ws_tensors,
+ bool* are_ws_tensors_added);
// Functions specific to operators to copy attributes
// We need operator-specific function to copy attributes because the framework
// does not provide any generic function for it.
- static void CopyAttrsConv2D(const Node* orign, NodeBuilder* nb);
- static void CopyAttrsBiasAddGrad(const Node* orign, NodeBuilder* nb);
- static void CopyAttrsPooling(const Node* orign, NodeBuilder* nb);
- static void CopyAttrsRelu(const Node* orign, NodeBuilder* nb);
+ // NOTE: names are alphabetically sorted.
+ static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
- // using node for original node 'orign' and return it in '*out'.
+ // using node for original node 'orig_node' and return it in '*out'.
// TODO(nhasabni) We should move this to mkl_util.h
void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
- Node* orign);
+ Node* orig_node);
void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
- Node* orign);
+ Node* orig_node);
};
std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
-// We register Mkl rewrite pass for phase 1 in pre-placement group.
-// Do not change the ordering of the Mkl passes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
+// We register Mkl rewrite pass for phase 1 in post rewrite group.
+// We register it here so that we get a complete picture of all users of Mkl
+// nodes. Do not change the ordering of the Mkl passes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1,
MklLayoutRewritePass);
//////////////////////////////////////////////////////////////////////////
@@ -543,7 +710,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
static void FillInputs(const Node* n,
gtl::InlinedVector<Node*, 4>* control_edges,
gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
- DCHECK_EQ(in->size(), n->num_inputs());
control_edges->clear();
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) {
@@ -561,9 +727,43 @@ static void FillInputs(const Node* n,
}
}
+void MklLayoutRewritePass::GetNodesProducingTFTensorList(
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+ int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
+ CHECK_LT(*input_idx, inputs.size());
+ CHECK_GT(list_length, 0);
+ CHECK_NOTNULL(output_nodes);
+ output_nodes->reserve(list_length);
+
+ while (list_length != 0) {
+ CHECK_GT(list_length, 0);
+ CHECK_LE(*input_idx, inputs.size());
+ Node* n = inputs[*input_idx].first;
+ int slot = inputs[*input_idx].second;
+ const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
+ // If input node 'n' is producing a list/array output at output
+ // slot 'slot' then we need to find out the length of that list/array.
+ if (ArgIsList(arg)) {
+ int N = GetTensorListLength(arg, n);
+ CHECK_LE(N, list_length);
+ for (int j = 0; j < N; j++) {
+ output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
+ }
+ (*input_idx)++;
+ list_length -= N;
+ } else {
+ // But if input node 'n' is just producing a single tensor at
+ // output slot 'slot' then we just add that single node.
+ output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
+ (*input_idx)++;
+ list_length--;
+ }
+ }
+}
+
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
- Node** out, Node* orign) {
+ Node** out, Node* orig_node) {
// We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
// dummy Mkl tensor. 8 = 2*size_t.
const DataType dt = DataTypeToEnum<uint8>::v();
@@ -574,63 +774,228 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
8);
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(
- NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orign->def().device()) // We place this node on same
- // device as device of original
- // node.
- .Finalize(&**g, out));
- (*out)->set_assigned_device_name(orign->assigned_device_name());
+ TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // the same device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
+ (*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
-Status MklLayoutRewritePass::SetUpInputs(
+void MklLayoutRewritePass::GetNodesProducingMklTensorList(
+ std::unique_ptr<Graph>* g,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+ int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
+ CHECK_LT(*input_idx, inputs.size());
+ CHECK_GT(list_length, 0);
+ CHECK_NOTNULL(output_nodes);
+ output_nodes->reserve(list_length);
+
+ while (list_length != 0) {
+ CHECK_GT(list_length, 0);
+ CHECK_LE(*input_idx, inputs.size());
+ Node* n = inputs[*input_idx].first;
+ int slot = inputs[*input_idx].second;
+ const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
+ // We need to check first if the input edge is going to carry a
+ // single tensor or a list of tensors. If it is a list of tensors,
+ // then we need to create list of Mkl dummy nodes.
+ if (ArgIsList(arg)) {
+ // If input node 'n' is producing a list/array output at output
+ // slot 'slot' then we need to find out the length of that list/array.
+ int N = GetTensorListLength(arg, n);
+ CHECK_LE(N, list_length);
+ Node* mkl_node = nullptr;
+ int mkl_node_output_slot = 0;
+ // If it is a list, then create a list of Mkl dummy nodes.
+ for (int j = 0; j < N; j++) {
+ GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
+ output_nodes->push_back(
+ NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
+ }
+ (*input_idx)++;
+ list_length -= N;
+ } else {
+ // If it is not a list, then create a single Mkl tensor node.
+ Node* mkl_node = nullptr;
+ int mkl_node_output_slot = 0;
+ GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
+ output_nodes->push_back(
+ NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
+ (*input_idx)++;
+ list_length--;
+ }
+ }
+}
+
+// Get an input node that will feed Mkl tensor to the new
+// node that we are constructing. An input node could be (1) 'n'
+// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
+// if 'n' is not an Mkl layer.
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+ std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
+ int* mkl_node_output_slot) {
+ CHECK_NOTNULL(n);
+ CHECK_NOTNULL(mkl_node);
+ CHECK_NOTNULL(mkl_node_output_slot);
+ if (IsRewrittenNode(n)) {
+ // If we have visited this node and rewritten it, then it will generate
+ // an edge that will receive Mkl tensor from a node.
+ // First, let's assert that this op is Mkl layer.
+ DataType T;
+ TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
+ // If this op has been rewritten, then its name must have been same as
+ // Mkl op.
+ CHECK_EQ(mkl_op_registry::IsMklOp(n->type_string(), T), true);
+ // output slot number for Mkl tensor would be N+slot number of TensorFlow
+ // tensor, where N is total number of TensorFlow tensors.
+ *mkl_node = n;
+ *mkl_node_output_slot =
+ GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
+ } else {
+ // If we have not visited the node and rewritten it, then we need
+ // to create a dummy node that will feed a dummy Mkl tensor to this node.
+ // DummyMklTensor node has no input and generates only 1 output
+ // (dummy Mkl tensor) as output slot number 0.
+ GetDummyMklTensorNode(g, mkl_node, n);
+ CHECK_NOTNULL(*mkl_node);
+ *mkl_node_output_slot = 0;
+ }
+}
+
+int MklLayoutRewritePass::SetUpContiguousInputs(
std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, NodeBuilder* nb,
- Node* orign) {
- std::vector<NodeBuilder::NodeOut> new_inputs;
-
- // 1. Let's setup inputs for the new node.
- for (int i = 0; i < inputs.size(); i++) {
- Node* n = inputs[i].first;
- // First let's copy original TF tensor input as it is.
- new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second));
-
- // Second, let's add edge to propagate Mkl tensors from input Mkl layers,
- // or generate a dummy Mkl tensor representing not-mkl-tensor case.
- if (IsRewrittenNode(n)) {
- // If we have visited this node and rewritten it, then it will generate
- // an edge that will receive Mkl tensor from a node.
- // First, let's assert that this op is Mkl layer.
- DataType T;
- TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
- // If this op has been rewritten, then its name must have been same as
- // Mkl op.
- CHECK_EQ(mkl_layer_registry::IsMklLayer(n->type_string(), T), true);
- // src slot number for Mkl tensor would be the one next to TF tensor
- // slot number.
- new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second + 1));
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+ NodeBuilder* nb, Node* old_node,
+ std::vector<NodeBuilder::NodeOut>* workspace_tensors,
+ bool are_workspace_tensors_available) {
+ CHECK_NOTNULL(workspace_tensors);
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+
+ // Number of input slots to original op
+ // Input slots are represented by .Input() calls in REGISTER_OP.
+ int old_node_input_slots = old_node->op_def().input_arg_size();
+ // Actual number of inputs can be greater than or equal to number
+ // of Input slots because inputs of type list could be unfolded.
+ CHECK_GE(old_node_inputs.size(), old_node_input_slots);
+ int nn_slot_idx = 0; // slot index for inputs of new node
+
+ // Let's copy all inputs (TF tensors) of original node to new node.
+ int iidx = 0;
+ for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
+ // An input slot could be a single tensor or a list. We need
+ // to handle this case accordingly.
+ CHECK_LT(iidx, old_node_inputs.size());
+ const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
+ if (ArgIsList(arg)) {
+ std::vector<NodeBuilder::NodeOut> new_node_inputs;
+ int N = GetTensorListLength(arg, old_node);
+ GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
+ &new_node_inputs);
+ nb->Input(new_node_inputs);
+ nn_slot_idx++;
} else {
- // If we have not visited the node and rewritten it, then we need
- // to create a dummy node that will feed a non-Mkl tensor to this node.
- // DummyMklTensor node has no input and generates only 1 output
- // (dummy Mkl tensor) as output slot number 0.
- Node* dmt = nullptr;
- GetDummyMklTensorNode(g, &dmt, orign);
- CHECK_NOTNULL(dmt);
- new_inputs.push_back(NodeBuilder::NodeOut(dmt, 0));
+ nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
+ iidx++;
+ nn_slot_idx++;
}
}
- // The total number of inputs to new node _must_ be 2 times the number
- // of inputs to the original node: N original Tensorflow tensors and
- // N for Mkl tensors corresponding to each Tensorflow tensors.
- CHECK_EQ(new_inputs.size(), inputs.size() * 2);
+ // If workspace tensors are available for this op and we are using
+ // contiguous ordering then we need to add Tensorflow tensor for
+ // workspace here because Tensorflow tensor for workspace is the
+ // last tensor in the list of Tensorflow tensors.
+ if (are_workspace_tensors_available) {
+ CHECK_EQ(workspace_tensors->size(), 2);
+ // Tensorflow tensor
+ nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
+ nn_slot_idx++;
+ }
+
+ // Let's now setup all Mkl inputs to new node.
+ // Number of Mkl inputs must be same as number of TF inputs.
+ iidx = 0;
+ for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
+ // An input slot could be a single tensor or a list. We need
+ // to handle this case accordingly.
+ CHECK_LT(iidx, old_node_inputs.size());
+ const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
+ if (ArgIsList(arg)) {
+ std::vector<NodeBuilder::NodeOut> new_node_inputs;
+ int N = GetTensorListLength(arg, old_node);
+ GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
+ &new_node_inputs);
+ nb->Input(new_node_inputs);
+ nn_slot_idx++;
+ } else {
+ Node* mkl_node = nullptr;
+ int mkl_node_output_slot = 0;
+ GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
+ old_node_inputs[iidx].second, &mkl_node,
+ &mkl_node_output_slot);
+ nb->Input(mkl_node, mkl_node_output_slot);
+ iidx++;
+ nn_slot_idx++;
+ }
+ }
+
+ // If workspace tensors are available for this op and we are using
+ // contiguous ordering then we need to add Mkl tensor for
+ // workspace here because Mkl tensor for workspace is the
+ // last tensor in the list of Mkl tensors.
+ if (are_workspace_tensors_available) {
+ CHECK_EQ(workspace_tensors->size(), 2);
+ // Mkl tensor
+ nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
+ nn_slot_idx++;
+ }
+
+ return nn_slot_idx;
+}
+
+Status MklLayoutRewritePass::SetUpInputs(
+ std::unique_ptr<Graph>* g,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+ NodeBuilder* nb, Node* old_node) {
+ // Let's check if we need to add workspace tensors for this node.
+ // We add workspace edge only for MaxPool, LRN and BatchNorm.
+ std::vector<NodeBuilder::NodeOut> workspace_tensors;
+ bool are_workspace_tensors_available = false;
+ AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
+ &are_workspace_tensors_available);
+
+ int new_node_input_slots = 0;
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ // TODO(nhasabni): implement this function just for same of completion.
+ // We do not use interleaved ordering right now.
+ return Status(
+ error::Code::UNIMPLEMENTED,
+ "Interleaved ordering of tensors is currently not supported.");
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ new_node_input_slots = SetUpContiguousInputs(
+ g, old_node_inputs, nb, old_node, &workspace_tensors,
+ are_workspace_tensors_available);
+ }
- // 2. Let's add the new inputs.
- for (auto ni : new_inputs) {
- nb->Input(ni.node, ni.index);
+ // Sanity check
+ int old_node_input_slots = old_node->op_def().input_arg_size();
+ if (!are_workspace_tensors_available) {
+ // If we are not adding workspace tensors for this op, then the total
+ // number of input slots to the new node _must_ be 2 times the number
+ // of input slots to the original node: N original Tensorflow tensors and
+ // N for Mkl tensors corresponding to each Tensorflow tensors.
+ CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
+ } else {
+ // If we are adding workspace tensors for this op, then the total
+ // The total number of input slots to new node _must_ be 2 times the number
+ // of input slots to the original node: N original Tensorflow tensors and
+ // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
+ // (for workspace Tensorflow tensor and workspace Mkl tensor).
+ CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
}
return Status::OK();
@@ -642,7 +1007,7 @@ Status MklLayoutRewritePass::SetUpInputs(
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
- std::unique_ptr<Graph>* g, Node** out, Node* orign) {
+ std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
// We use a tensor of shape {1} and value 0 to represent
// dummy float tensor. We need this as a dummy workspace tensor.
// Workspace tensor has type float.
@@ -654,39 +1019,42 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
4);
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(
- NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orign->def().device()) // We place this node on same
- // device as device of original
- // node.
- .Finalize(&**g, out));
- (*out)->set_assigned_device_name(orign->assigned_device_name());
+ TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // same the device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
+ (*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
-void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
- Node* orign,
- NodeBuilder* nb) {
- bool workspace_edge_added = false;
+void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
+ std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
+ std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
+ bool workspace_edge_added = false; // Default initializer
+ CHECK_NOTNULL(are_ws_tensors_added);
+ *are_ws_tensors_added = false; // Default initializer
+
DataType T;
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) {
- if (orign->type_string() == ws.fwdop &&
- mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()), T)) {
+ if (orig_node->type_string() == ws.fwd_op &&
+ mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
- // edge from this node's fwdslot to bwdop's bwdslot. If there is
+ // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
// an edge, then we just add an attribute on this node for setting
// workspace_passed to true. We don't add actual workspace edge
// in this node. Actual workspace edge gets added in the backward
// op for this node.
- for (const Edge* e : orign->out_edges()) {
- if (e->src_output() == ws.fwdslot &&
- e->dst()->type_string() == ws.bwdop &&
- e->dst_input() == ws.bwdslot) {
+ for (const Edge* e : orig_node->out_edges()) {
+ if (e->src_output() == ws.fwd_slot &&
+ e->dst()->type_string() == ws.bwd_op &&
+ e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true);
VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
- << orign->type_string();
+ << orig_node->type_string();
workspace_edge_added = true;
// We found the edge that we were looking for, so break.
break;
@@ -698,34 +1066,40 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// node.
nb->Attr("workspace_enabled", false);
}
- } else if (orign->type_string() == ws.bwdop &&
- mkl_layer_registry::IsMklLayer(
- GetMklOpName(orign->type_string()), T)) {
+ } else if (orig_node->type_string() == ws.bwd_op &&
+ mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()),
+ T)) {
// If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this
- // op. Corresponding fwd op is specified in 'fwdop' field of
- // workspace info. fwdslot and bwdslot in workspace info specify
+ // op. Corresponding fwd op is specified in 'fwd_op' field of
+ // workspace info. fwd_slot and bwd_slot in workspace info specify
// an edge between which slots connect forward and backward op.
// Once all these criteria match, we add a workspace edge between
- // wsfwdslot and wsbwdslot. It's corresponding Mkl tensor is added
- // in wsfwdslot+1 and wsbwdslot+1.
- for (const Edge* e : orign->in_edges()) {
- if (e->src_output() == ws.fwdslot &&
+ // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
+ // determined by interleaved/contiguous ordering. Function
+ // DataIndexToMetaDataIndex tells us the location of Mkl tensor
+ // from the location of the Tensorflow tensor.
+ for (const Edge* e : orig_node->in_edges()) {
+ if (e->src_output() == ws.fwd_slot &&
// We would have rewritten the forward op, so we need to use
// GetMklOpName call to get its Mkl name.
- e->src()->type_string() == GetMklOpName(ws.fwdop) &&
- e->dst_input() == ws.bwdslot) {
+ e->src()->type_string() == GetMklOpName(ws.fwd_op) &&
+ e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true);
+ CHECK_NOTNULL(ws_tensors);
// Add workspace edge between fwd op and bwd op.
- nb->Input(e->src(), ws.wsfwdslot);
+ ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
// Add Mkl tensor edge for workspace edge between fwd op and bwd op.
- nb->Input(e->src(), ws.wsfwdslot + 1);
+ ws_tensors->push_back(NodeBuilder::NodeOut(
+ e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
+ e->src()->num_outputs())));
+ *are_ws_tensors_added = true;
// In terms of input ordering, we add these calls to add Input
// here because workspace edge (and its Mkl tensor) is the last
// edge in the fwdop and bwdop. So all inputs before workspace
// tensor have been added by SetUpInputs function.
VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
- << orign->type_string();
+ << orig_node->type_string();
workspace_edge_added = true;
// We found the edge that we were looking for, so break.
break;
@@ -740,15 +1114,18 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
nb->Attr("workspace_enabled", false);
Node* dmt_ws = nullptr; // Dummy tensor for workspace
Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
- GetDummyWorkspaceTensorNode(g, &dmt_ws, orign);
- GetDummyMklTensorNode(g, &dmt_mkl_ws, orign);
+ GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
+ GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
CHECK_NOTNULL(dmt_ws);
CHECK_NOTNULL(dmt_mkl_ws);
- nb->Input(dmt_ws, 0); // We add dummy tensor as workspace tensor.
- nb->Input(dmt_mkl_ws, 0); // We add dummy tensor as Mkl
- // tensor for workspace tensor.
+ CHECK_NOTNULL(ws_tensors);
+ // We add dummy tensor as workspace tensor.
+ ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
+ // We add dummy tensor as Mkl tensor for workspace tensor.
+ ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
+ *are_ws_tensors_added = true;
VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
- << orign->type_string();
+ << orig_node->type_string();
}
} else {
// If this node does not match any workspace info, then we do not
@@ -761,7 +1138,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
@@ -769,11 +1147,12 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
bool use_cudnn_on_gpu;
// Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+ TF_CHECK_OK(
+ GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
// Add attributes to new node.
nb->Attr("T", T);
@@ -783,16 +1162,16 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
-void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign,
+void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
string data_format;
std::vector<int32> strides;
// Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
// Add attributes to new node.
nb->Attr("T", T);
@@ -800,7 +1179,30 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign,
nb->Attr("data_format", data_format);
}
-void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
+void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ int depth_radius;
+ float bias;
+ float alpha;
+ float beta;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("depth_radius", depth_radius);
+ nb->Attr("bias", bias);
+ nb->Attr("alpha", alpha);
+ nb->Attr("beta", beta);
+}
+
+void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
string data_format;
@@ -808,11 +1210,11 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
std::vector<int32> ksize, strides;
// Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "ksize", &ksize));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
// Add attributes to new node.
nb->Attr("T", T);
@@ -822,14 +1224,97 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
nb->Attr("data_format", data_format);
}
-void MklLayoutRewritePass::CopyAttrsRelu(const Node* orign, NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+}
+
+void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ string data_format;
+ int num_split;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("num_split", num_split);
+ nb->Attr("data_format", data_format);
+}
+
+void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ int N;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("N", N);
+}
+
+void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ int N;
+ DataType tidx;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("N", N);
+ nb->Attr("Tidx", tidx);
+}
+
+void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ float epsilon;
+ string data_format;
+ bool is_training;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("epsilon", epsilon);
+ nb->Attr("data_format", data_format);
+ nb->Attr("is_training", is_training);
+}
+
+void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
+ DataType Tshape;
// Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
// Add attributes to new node.
nb->Attr("T", T);
+ nb->Attr("Tshape", Tshape);
}
//////////////////////////////////////////////////////////////////////////
@@ -889,8 +1374,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
CHECK_NOTNULL(succ);
CHECK_NOTNULL(pred);
- if (succ->type_string() == csinfo_.biasadd &&
- pred->type_string() == csinfo_.mklconv2d) {
+ if (succ->type_string() == csinfo_.bias_add &&
+ pred->type_string() == csinfo_.mkl_conv2d) {
// 1. Get all attributes from input nodes.
DataType T_pred, T_succ;
string padding;
@@ -947,7 +1432,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// 2. Get inputs from both the nodes.
// Find the 2 inputs from the conv and the bias from the add Bias.
// Get operand 0, 1 of conv2D and their Mkl tensors.
- CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs.
+ CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs.
// Get operand 1 of add_bias
// BiasAdd must have 2 inputs: Conv, bias
CHECK_EQ(succ->in_edges().size(), 2);
@@ -960,13 +1445,29 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// We will use the node name of BiasAdd as the name of new node
// Build new node. We use same name as original node, but change the op
// name.
- NodeBuilder nb(succ->name(), csinfo_.mklconv2dwithbias);
- nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
- nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1
- nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D
- nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2
- nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
- nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd
+ NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias);
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
+ // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
+ // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
+ // we follow contiguous ordering.
+ nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1
+ nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D
+ nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2
+ nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
+ nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
+ // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
+ // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
+ // we follow contiguous ordering.
+ nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D
+ nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
+ nb.Input(pred_in[2].first, pred_in[2].second); // Mkl for In1 of Conv2D
+ nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 of Conv2D
+ nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd
+ }
// Copy attributes from Conv2D to Conv2DWithBias.
CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
@@ -975,30 +1476,30 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
nb.Device(succ->def().device());
// Create node.
- Node* newn;
- nb.Finalize(&**g, &newn);
- CHECK_NOTNULL(newn);
+ Node* new_node;
+ nb.Finalize(&**g, &new_node);
+ CHECK_NOTNULL(new_node);
// Set the Mkl layer label for this op.
- newn->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
+ new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
// Incoming edges are fixed, we will fix the outgoing edges now.
for (const Edge* e : succ->out_edges()) {
- (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+ (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
}
// Copy device assigned to old node to new node.
// It's ok to use pred or succ as we have enforced a check that
// both have same device assigned.
- newn->set_assigned_device_name(pred->assigned_device_name());
+ new_node->set_assigned_device_name(pred->assigned_device_name());
VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
<< ", and node: " << succ->DebugString()
- << ", into node:" << newn->DebugString();
+ << ", into node:" << new_node->DebugString();
(*g)->RemoveNode(succ);
(*g)->RemoveNode(pred);
- MarkRewrittenNode(newn);
+ MarkRewrittenNode(new_node);
return Status::OK();
}
@@ -1011,35 +1512,39 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// Helper functions for node rewrite
//////////////////////////////////////////////////////////////////////////
-Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
+Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
+ Node* orig_node,
const RewriteInfo* ri) {
CHECK_NOTNULL(ri);
- CHECK_NOTNULL(orign);
+ CHECK_NOTNULL(orig_node);
- VLOG(1) << "MklLayoutRewritePass: Original node:" << orign->DebugString();
+ VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
// Check if this is scenario 2 (context-based rewrite).
// Get the matching ContextInfo if it is.
- const Node* fwdn = nullptr;
+ const Node* fwd_node = nullptr;
const ContextInfo* ci = nullptr;
bool is_context_based_rewrite = false;
- if ((ci = SearchMatchingContext(orign, &fwdn)) != nullptr) {
- CHECK_NOTNULL(fwdn);
+ if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
+ CHECK_NOTNULL(fwd_node);
is_context_based_rewrite = true;
// Sanity checks for context-based rewrite (if any)
- if (orign->type_string() == csinfo_.biasaddgrad &&
- ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) {
+ if (orig_node->type_string() == csinfo_.bias_add_grad &&
+ ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
DataType orig_T, ctx_T;
string orig_data_format, ctx_data_format;
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &orig_T));
- TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &orig_data_format));
- TF_CHECK_OK(GetNodeAttr(fwdn->def(), "T", &ctx_T));
- TF_CHECK_OK(GetNodeAttr(fwdn->def(), "data_format", &ctx_data_format));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
+ TF_CHECK_OK(
+ GetNodeAttr(orig_node->def(), "data_format", &orig_data_format));
+ TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T));
+ TF_CHECK_OK(
+ GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format));
if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
- orign->assigned_device_name() != fwdn->assigned_device_name() ||
- orign->def().device() != fwdn->def().device()) {
+ orig_node->assigned_device_name() !=
+ fwd_node->assigned_device_name() ||
+ orig_node->def().device() != fwd_node->def().device()) {
return Status(
error::Code::INVALID_ARGUMENT,
"data_format or T attribute or devices of BiasAddGrad and "
@@ -1049,18 +1554,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
}
// Get all inputs.
- const int num = orign->num_inputs();
- CHECK_EQ(num, ri->numins);
+ const int num = orig_node->in_edges().size();
+ // Check the number of inputs against the user-specified value for non-vararg
+ // nodes.
+ if (!IsVarArgNode(orig_node)) {
+ CHECK_EQ(num, ri->num_ins);
+ }
gtl::InlinedVector<Node*, 4> control_edges;
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
- FillInputs(orign, &control_edges, &inputs);
+ FillInputs(orig_node, &control_edges, &inputs);
// Build new node. We use same name as original node, but change the op name.
- NodeBuilder nb(orign->name().c_str(), ri->newname.c_str());
+ NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
// Copy user-specified device assigned to original node to new node.
- nb.Device(orign->def().device());
+ nb.Device(orig_node->def().device());
// Set up new inputs to the rewritten node.
- Status s = SetUpInputs(g, inputs, &nb, orign);
+ Status s = SetUpInputs(g, inputs, &nb, orig_node);
if (s != Status::OK()) {
return s;
}
@@ -1068,62 +1577,63 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
// Copy attributes from original node to new node (for scenario 1).
// For context-based rewrite, we use context to copy the attributes.
if (is_context_based_rewrite) {
- if (orign->type_string() == csinfo_.biasaddgrad &&
- ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) {
- CHECK_NOTNULL(fwdn);
- ri->copyattrs(fwdn, &nb);
+ if (orig_node->type_string() == csinfo_.bias_add_grad &&
+ ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
+ CHECK_NOTNULL(fwd_node);
+ ri->copy_attrs(fwd_node, &nb);
} else {
return Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for node rewrite optimization.");
}
} else {
- ri->copyattrs(const_cast<const Node*>(orign), &nb);
+ ri->copy_attrs(const_cast<const Node*>(orig_node), &nb);
}
// Set the Mkl layer label for this op.
- nb.Attr("_kernel", mkl_layer_registry::kMklLayerLabel);
-
- // Add workspace edge to this node if needed.
- // We add workspace edge only for MaxPool, LRN and BatchNorm.
- AddWorkSpaceEdgeIfNeeded(g, orign, &nb);
+ nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
// Finalize graph and get new node.
- Node* newn = nullptr;
- TF_CHECK_OK(nb.Finalize(&**g, &newn));
- CHECK_NOTNULL(newn);
-
- // Incoming edges from 'orign' node to new 'newn' node are already copied
- // in BuildNode. Copy outgoing edges from 'orign' node to new 'newn' node.
- // Since the output also follows same ordering among Tensorflow tensors and
- // Mkl tensors. We need to connect Tensorflow tensors appropriately.
- // Specifically, nth output of original node will become 2*nth output of
- // Mkl node. GetTensorDataIndex provides this mapping function.
- for (const Edge* e : orign->out_edges()) {
+ Node* new_node = nullptr;
+ TF_CHECK_OK(nb.Finalize(&**g, &new_node));
+ CHECK_NOTNULL(new_node);
+
+ // Incoming edges from 'orig_node' node to new 'new_node' node are already
+ // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new
+ // 'new_node' node, since the output also follows same ordering among
+ // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
+ // tensors appropriately. Specifically, nth output of the original node
+ // will become 2*nth output of the Mkl node for the interleaved ordering
+ // of the tensors. For the contiguous ordering of the tensors, it will be n.
+ // GetTensorDataIndex provides this mapping function.
+ for (const Edge* e : orig_node->out_edges()) {
// We need to handle control-edges by using their original slot number.
// Generally, -1 is reserved for control slot.
if (e->src_output() < 0) {
- (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+ (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
} else {
- (*g)->AddEdge(newn, GetTensorDataIndex(e->src_output()), e->dst(),
- e->dst_input());
+ (*g)->AddEdge(
+ new_node,
+ GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+ e->dst(), e->dst_input());
}
}
// Copy the runtime device assigned from original code to new node.
- newn->set_assigned_device_name(orign->assigned_device_name());
+ new_node->set_assigned_device_name(orig_node->assigned_device_name());
// Delete original node and mark new node as rewritten.
- (*g)->RemoveNode(orign);
- MarkRewrittenNode(newn);
+ (*g)->RemoveNode(orig_node);
+ MarkRewrittenNode(new_node);
- VLOG(1) << "MklLayoutRewritePass: New node:" << newn->DebugString();
+ VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
return Status::OK();
}
const MklLayoutRewritePass::ContextInfo*
-MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
+MklLayoutRewritePass::SearchMatchingContext(const Node* n,
+ const Node** fwd_node) {
CHECK_NOTNULL(n);
- CHECK_NOTNULL(fwdn);
- *fwdn = nullptr;
+ CHECK_NOTNULL(fwd_node);
+ *fwd_node = nullptr;
// Search for matching contextinfo based on node name.
// There could be more than one matching contextinfos.
@@ -1171,7 +1681,7 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
// If we find a match, we return immediately.
for (const ContextInfo* ci : mci) {
if (curr_node->type_string() == ci->fwd) {
- *fwdn = curr_node;
+ *fwd_node = curr_node;
return ci;
}
}
@@ -1192,8 +1702,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
}
bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
- const Node* fwdn = nullptr;
- return SearchMatchingContext(n, &fwdn) != nullptr;
+ const Node* fwd_node = nullptr;
+ return SearchMatchingContext(n, &fwd_node) != nullptr;
}
const MklLayoutRewritePass::RewriteInfo*
@@ -1208,7 +1718,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
if (!GetNodeAttr(n->def(), "T", &T).ok()) {
return nullptr;
}
- if (!mkl_layer_registry::IsMklLayer(GetMklOpName(n->type_string()), T)) {
+
+ if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
return nullptr;
}
@@ -1219,7 +1730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// Find matching RewriteInfo and then check that rewrite rule applies.
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
- if (n->type_string().compare(ri->name) == 0 && ri->rewriterule(n)) {
+ if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
return &*ri;
}
}
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 142d60d611..6e72baf84e 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -110,9 +110,11 @@ class MklLayoutPassTest : public ::testing::Test {
};
REGISTER_OP("Input").Output("o: float").SetIsStateful();
+REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
-REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
+REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
+REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optiimization
@@ -133,20 +135,22 @@ TEST_F(MklLayoutPassTest, Basic) {
// Test set 1: Conv2D + AddBias
-// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y)
+// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
+// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
+ " input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -157,26 +161,28 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
- "M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
- "DMT/_0->E:5;E->Z;M->E:1;N->E:3;Y->Z:1");
+ "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
+ "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
+ "DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
}
-// C=MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y)
+// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
+// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
// Test for correct output slots selected
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput2'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput2'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput2'}"
+ "node { name: 'N' op: '_MklInput2'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M:1', 'B', 'N:1']}"
+ " input: ['A', 'B', 'M:1', 'N:1']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -187,16 +193,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
- "M(MklInput2);N(MklInput2);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
- "DMT/_0->E:5;E->Z;M:1->E:1;N:1->E:3;Y->Z:1");
+ "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
+ "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
+ "DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
// This is a case of node rewrite followed by node merge.
-// We will first rewrite Conv2D to MklConv2D, and then merge MklConv2D
-// with BiasAdd to produce MklConv2DWithBias.
+// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
+// with BiasAdd to produce _MklConv2DWithBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
@@ -218,70 +225,70 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(MklConv2DWithBias);Y(Input);Z(Sub)|"
- "A->E;B->E:2;D->E:4;DMT/_0->E:1;DMT/_1->E:3;DMT/_2->E:5;"
+ "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
+ "A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
"E->Z;Y->Z:1");
}
-// Graph contains only MklConv2D, no AddBias.
+// Graph contains only _MklConv2D, no AddBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}");
+ " input: ['A', 'B', 'M', 'N']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
- "A->C;B->C:2;M->C:1;N->C:3");
+ "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
+ "A->C;B->C:1;M->C:2;N->C:3");
}
-// MklConv2D output does not go to BiasAdd.
+// _MklConv2D output does not go to BiasAdd.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
+ " input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd.
+ " input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd.
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "M(MklInput);N(MklInput)|A->C;B->C:2;D->F;E->F:1;M->C:1;N->C:3");
+ "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
+ "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
}
-// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
+// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// Merge should not be done in such case.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
+ " input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'BiasAdd'"
@@ -293,9 +300,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "G(Add);M(MklInput);N(MklInput)|A->C;B->C:2;C->G;D->F;"
- "E->F:1;E->G:1;M->C:1;N->C:3");
+ "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
+ "G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
+ "E->F:1;E->G:1;M->C:2;N->C:3");
}
// data_format attribute value mismatch. Merge should not be done
@@ -303,43 +310,81 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
+ " input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHCW' } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
- "N(MklInput)|A->C;B->C:2;C->E;D->E:1;M->C:1;N->C:3");
+ "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
+ "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
}
-// No MklConv2D in context, but Conv2D in context.
-// Only Conv2D would be rewritten to MklConv2D, but no rewrite
+// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
+// of BiasAddGrad into BackpropBias
+#if 0
+// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
+// rewrite tests
+
+// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
+ "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
+ "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
+ "M->D:3;N->D:4;O->D:5");
+}
+#endif
+
+// No _MklConv2D in context, but Conv2D in context.
+// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
// for BiasAddGrad should happen.
-// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
+// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
+// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
+ " input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
@@ -348,9 +393,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);D(Sub);E(BiasAddGrad);"
- "M(MklInput);N(MklInput)|A->C;A->D:1;B->C:2;C->D;D->E;"
- "M->C:1;N->C:3");
+ "A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);"
+ "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
+ "M->C:2;N->C:3");
}
// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
@@ -462,8 +507,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
- "A->C;B->C:2;B->D;C->D:1;DMT/_0->C:1;DMT/_1->C:3");
+ "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
+ "A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
@@ -489,9 +534,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:2;C->D:2;C->E;"
- "C:1->D:3;D->E:1;DMT/_0->C:1;DMT/_1->C:3;DMT/_2->D:1");
+ "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
+ "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
+ "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
// Conv2D with INT32 which is not supported by Mkl
@@ -513,10 +558,374 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
"A->C;B->C:1;B->D;C->D:1");
}
+// Concat Op test: Concat with no Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'InputList'"
+ " attr { key: 'N' value { i: 2 } }}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Concat'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['A', 'B']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
+ "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
+ "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
+// Concat with 2 Mkl layers feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B']}"
+ "node { name: 'F' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['C', 'D']}"
+ "node { name: 'G' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'H' op: 'Concat'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['G', 'E', 'F']}"
+ "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'H'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
+ "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
+ "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
+ "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
+}
+
+// Concat with 1 Mkl and 1 non-Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B']}"
+ "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D']}"
+ "node { name: 'G' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'H' op: 'Concat'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['G', 'E', 'F']}"
+ "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'H'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
+ "H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+ "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
+ "G->H;H->I:1");
+}
+
+#if 0
+// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'InputList'"
+ " attr { key: 'N' value { i: 2 } }}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'ConcatV2'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Tidx' value { type: DT_INT32 } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['B:0', 'B:1', 'A']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
+ "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
+ "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+#endif
+
+// ConcatV2 with 2 Mkl layers feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B']}"
+ "node { name: 'F' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['C', 'D']}"
+ "node { name: 'G' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'H' op: 'ConcatV2'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Tidx' value { type: DT_INT32 } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['E', 'F', 'G']}"
+ "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'H'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
+ "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
+ "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
+ "DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
+}
+
+// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B']}"
+ "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D']}"
+ "node { name: 'G' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'H' op: 'ConcatV2'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Tidx' value { type: DT_INT32 } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['E', 'F', 'G']}"
+ "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'H'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
+ "H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+ "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
+ "G->H:2;H->I:1");
+}
+
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node for workspace edges
/////////////////////////////////////////////////////////////////////
+/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
+TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'LRN'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+ " input: ['B'] }"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'MaxPoolGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+ " input: ['B', 'C', 'D'] }"
+ "node { name: 'F' op: 'Input'}"
+ "node { name: 'G' op: 'LRNGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['E', 'F', 'B'] }"
+ "node { name: 'H' op: 'Input'}"
+ "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['H', 'G'] }");
+ EXPECT_EQ(
+ DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
+ "A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
+ "C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
+ "DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
+}
+
+/* Test LRN->LRNGrad replacement by workspace nodes. */
+TEST_F(MklLayoutPassTest, LRN_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'LRN'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'LRNGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['C', 'D', 'B'] }"
+ "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'E'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
+ "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
+ "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
+}
+
+/* Test LRN->LRNGrad replacement when only one of them is present. */
+TEST_F(MklLayoutPassTest, LRN_Negative1) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'LRN'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
+ "A->B;A->C;B->C:1;DMT/_0->B:1");
+}
+
+/* Test LRN->LRNGrad replacement when only one of them is present. */
+TEST_F(MklLayoutPassTest, LRN_Negative2) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'LRNGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
+ "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
+ "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
+ "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+}
+
+/* Test LRN->LRNGrad negative case, where single LRN feeds
+ 2 LRNGrad nodes at different slots. */
+TEST_F(MklLayoutPassTest, LRN_Negative3) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'LRN'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'LRNGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['C', 'D', 'B'] }"
+ "node { name: 'F' op: 'LRNGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'alpha' value { f: 0.001 } }"
+ " attr { key: 'beta' value { f: 0.75 } }"
+ " attr { key: 'bias' value { f: 1.0 } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'depth_radius' value { i: 2 } }"
+ " input: ['C', 'B', 'D'] }"
+ "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['E', 'F'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
+ "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
+ "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
+ "D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
+ "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
+}
+
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
InitGraph(
@@ -540,10 +949,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(MklMaxPoolGrad);F(Mul)|"
- "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:4;"
- "DMT/_0->B:1;DMT/_1->E:1;DMT/_2->E:5;E->F:1");
+ "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
+ "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
+ "A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
+ "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
// Test MaxPool>MaxPoolGrad replacement when only one of them is present.
@@ -562,11 +971,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MklMaxPool);C(Mul);DMT/_0(Const)|"
+ "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
}
-// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
+// Test MaxPoolGrad replacement when only one of them is present.
// In this case, we will rewrite MaxPoolGrad and for workspace tensor and
// its Mkl part, we will generate dummy tensor.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
@@ -584,10 +993,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(MklMaxPoolGrad);DMT/_0(Const);"
+ "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
- "A->D;A->E;B->D:2;C->D:4;D->E:1;DMT/_0->D:1;DMT/_1->D:3;"
- "DMT/_2->D:5;DMT/_3->D:6;DMT/_4->D:7");
+ "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
+ "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
/////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/graph/mkl_optimizer_merge.cc b/tensorflow/core/graph/mkl_optimizer_merge.cc
deleted file mode 100644
index a171a27d8f..0000000000
--- a/tensorflow/core/graph/mkl_optimizer_merge.cc
+++ /dev/null
@@ -1,651 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef INTEL_MKL
-// This module implements node merging optimization on the graph.
-// We process the nodes in the graph in reverse postorder
-// (i.e. inputs before their downstream dependencies).
-//
-#include <memory>
-#include <queue>
-#include <set>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/graph/mkl_optimizer_merge.h"
-
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-
-// How many hops do we search for matching node in the backward dataflow graph?
-// We use maxhop of 10 based on empirical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D)
-// directly goes to backward nodes, we do not expect the hop-distance
-// would be more than few nodes.
-static size_t kNodeMergeContextMaxDepth = 10;
-
-// This optimization pass performs two tasks: merge
-// nodes in the forward pass, and rewrite the gradient ops
-// corresponding to merged forward ops.
-//
-// Merging nodes in the graph: Currently, it merges Conv2D+AddBias together.
-//
-// Rewriting nodes in the graph: This is neded in order to optimize
-// gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and
-// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
-// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
-// This is context-specific optimization, where the context is the
-// forward operator that the BiasAddGrad corresponds to.
-class NodeMergeRewritePass : public GraphOptimizationPass {
- public:
- NodeMergeRewritePass() {
- csinfo_.conv2d = "MklConv2D";
- csinfo_.conv2dwithbias = "MklConv2DWithBias";
- csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
- csinfo_.biasadd = "BiasAdd";
- csinfo_.matmul = "MatMul";
- csinfo_.biasaddgrad = "BiasAddGrad";
-
- minfo_.push_back(
- {csinfo_.conv2d, csinfo_.biasadd, 0, csinfo_.conv2dwithbias});
-
-// We use maxhop of 10 based on emperical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes
-// (Conv2D) directly goes to backward nodes, we do not expect the
-// hop-distance would be more than few nodes.
-// TODO(nhasabni) Temporarily disabling rewrite of BiasAddGrad.
-// Will enable it once we support Conv2DWithBiasBackpropBias op.
-#if 0
- rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
- {csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
- rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
- {csinfo_.conv2d, kNodeMergeContextMaxDepth}});
- // For now, we are rewriting BiasAddGrad to BiasAddGrad for MatMul. This is
- // because we do not have a separate Op for MatMulwithBias.
- rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
- {csinfo_.matmul, kNodeMergeContextMaxDepth}});
-#endif
- }
-
- // Standard interface to run optimization pass
- Status Run(const GraphOptimizationPassOptions& options);
-
- // Helper function which does most of heavy lifting for node merge
- //
- // Extracts common functionality between Run public interface and
- // test interface.
- //
- // @return true, if and only if graph is mutated; false otherwise.
- bool RunPass(std::unique_ptr<Graph>* g);
-
- private:
- /// Structure to specify information used in node merge
- typedef struct {
- string pred; // Predecessor node string
- string succ; // Successor node string
- int op; // What operand no the predecessor node corresponds
- // to successor node?
- string newnode; // Name of the node after merge
- } MergeInfo;
-
- /// Structure to specify information used in node rewrite
- typedef struct {
- string node; // Name of the node to be rewritten
- string rewrite; // New name of the node after rewrite
- typedef struct {
- string fwd; // Node name in forward pass that this node
- // corresponds to
- size_t maxhop; // Maximum number of hops the mfwd_ is located
- // from this node. If mfwd_ is farther than mmaxhop_
- // then we do not rewrite the node.
- } ContextInfo;
- ContextInfo cinfo; // Context for rewrite
- } RewriteInfo;
-
- /// Structure to store all constant strings
- typedef struct {
- string conv2d;
- string conv2dwithbias;
- string conv2dwithbiasbackpropbias;
- string biasadd;
- string matmul;
- string biasaddgrad;
- } ConstStringInfo;
-
- ConstStringInfo csinfo_;
- std::vector<MergeInfo> minfo_;
- std::vector<RewriteInfo> rinfo_;
-
- private:
- // Return a node that can be merged with input node
- //
- // @return pointer to the node if we can find such a
- // node. Otherwise, it returns nullptr.
- Node* FindNodeForMerge(const Node* a) const;
-
- // Merge predecessor node with its successor.
- // Currently, we merge Conv2D with AddBias only.
- //
- // Input nodes succ and pred may be deleted if the call to
- // this function is successful. Attempt to use the pointers
- // after the call to function may result is undefined behaviors.
- //
- // @input g - input graph, succ - successor node, pred - predecessor node
- // @return Status::OK(), if merging is successful and supported.
- // Returns appropriate Status error code otherwise.
- // Graph is updated in case nodes are merged. Otherwise, it is
- // not updated.
- Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);
-
- // Is input node (n) a candidate for rewrite?
- //
- // @return true, if it can be rewritten; false, otherwise.
- bool IsApplicableRewriteNode(const Node* n) const;
-
- // Rewrites input node to a new node specified by its matching rewrite info.
- //
- // Method first searches matching rewrite info for input node and then
- // uses that info to rewrite.
- //
- // Input node may be deleted in case of rewrite. Attempt to use the node
- // after the call can result in undefined behaviors.
- //
- // @input g - input graph, n - Node to be rewritten
- // @return Status::OK(), if the input node is rewritten;
- // Returns appropriate Status error code otherwise.
- // Graph is updated in case the input node is rewritten.
- // Otherwise, it is not updated.
- Status RewriteNode(std::unique_ptr<Graph>* g, Node* n);
-
- // Helper function that searches the matching rewriteinfo for the node.
- // Implements depth-first search in the data dependence graph for the
- // gradient op in backward direction.
- //
- // @input n - Node (gradient op) whose rewriteinfo is to be searched,
- // fwdn - pointer to node from the forward pass that this node
- // belongs to
- // @return Matching rewriteinfo in case a match is found; null otherwise.
- const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
- const Node** fwdn) const;
-
- // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
- // and return it in '*out'.
- // TODO(nhasabni) We should move this to mkl_util.h
- void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out);
-};
-
-// We register merge optimizer for phase 2 in pre-placement group.
-// Do not change the ordering of the Mkl passes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 2,
- NodeMergeRewritePass);
-
-static void FillInputs(const Node* n,
- gtl::InlinedVector<Node*, 4>* control_edges,
- gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
- DCHECK_EQ(in->size(), n->num_inputs());
- control_edges->clear();
- for (const Edge* e : n->in_edges()) {
- if (e->IsControlEdge()) {
- control_edges->push_back(e->src());
- } else {
- (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
- }
- }
- std::sort(control_edges->begin(), control_edges->end());
- if (n->op_def().is_commutative()) {
- // For commutative inputs, we sort the input by the input Node*
- // to get a canonical ordering (so that add(a,b) and add(b, a) will
- // hash to the same value if is_commutative is true for 'add').
- std::sort(in->begin(), in->end());
- }
-}
-
-Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
- // Search for all matching mergeinfo.
- // We allow more than one match for extensibility.
- std::vector<const MergeInfo*> matching_mi;
- for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
- if (a->type_string() == mi->succ) {
- matching_mi.push_back(&*mi);
- }
- }
-
- for (const MergeInfo* mi : matching_mi) {
- const int N_in = a->num_inputs();
- if (mi->op >= N_in) {
- continue;
- }
-
- // Get the control edges and input of node
- gtl::InlinedVector<Node*, 4> a_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
- FillInputs(a, &a_control_edges, &a_in);
-
- // Get operand op of the operator
- Node* b = nullptr;
- b = a_in[mi->op].first;
- if (b == nullptr || (b->type_string() != mi->pred)) {
- // NOTE: Should the first check be assert?
- continue;
- }
-
- gtl::InlinedVector<Node*, 4> b_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
- FillInputs(b, &b_control_edges, &b_in);
-
- // Shouldn't merge if a and b have different control edges.
- if (a_control_edges != b_control_edges) {
- continue;
- } else {
- // We found a match.
- return b;
- }
- }
-
- return nullptr;
-}
-
-void NodeMergeRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
- Node** out) {
- const DataType dt = DataTypeToEnum<uint8>::v();
- TensorProto proto;
- proto.set_dtype(dt);
- uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
- proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
- 8);
- TensorShape dummy_shape({8});
- dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Finalize(&**g, out));
-}
-
-Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
- Node* pred) {
- CHECK_NOTNULL(succ);
- CHECK_NOTNULL(pred);
-
- if (succ->type_string() == csinfo_.biasadd &&
- pred->type_string() == csinfo_.conv2d) {
- // 1. Get all attributes from input nodes.
- DataType T_pred, T_succ;
- string padding;
- std::vector<int32> strides;
- string data_format_pred, data_format_succ;
- bool use_cudnn_on_gnu;
- TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
- TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
- TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
- TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
- TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
- TF_CHECK_OK(
- GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
- // We check to ensure that data formats of both succ and pred are same.
- // We expect them to be same, so we can enforce this as assert.
- // But assert can be too strict, so we enforce this as a check.
- // If the check fails, then we do not merge two nodes.
- // We also do same check for devices.
- if (data_format_pred != data_format_succ || T_pred != T_succ ||
- pred->assigned_device_name() != succ->assigned_device_name() ||
- pred->def().device() != succ->def().device()) {
- return Status(error::Code::INVALID_ARGUMENT,
- "data_format or T attribute or devices of Conv2D and "
- "BiasAdd do not match. Will skip node merge optimization");
- }
-
- // 2. Get inputs from both the nodes.
- // Find the 2 inputs from the conv and the bias from the add Bias.
- Node* oper1 = nullptr;
- Node* oper1_mkl = nullptr; // Mkl tensor corresponding to oper1
- Node* oper2 = nullptr;
- Node* oper2_mkl = nullptr; // Mkl tensor corresponding to oper2
- Node* oper3 = nullptr;
- Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
-
- const int succ_num = succ->num_inputs();
- gtl::InlinedVector<Node*, 4> succ_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
- FillInputs(succ, &succ_control_edges, &succ_in);
-
- const int pred_num = pred->num_inputs();
- gtl::InlinedVector<Node*, 4> pred_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
- FillInputs(pred, &pred_control_edges, &pred_in);
-
- // We need to ensure that there is only 1 edge between Conv2D and AddBias.
- // Otherwise, merging is semantically incorrect.
- if (pred->out_edges().size() != 1) {
- return Status(error::Code::INVALID_ARGUMENT,
- "Conv2D has multiple outputs."
- "Will skip node merge optimization");
- }
-
- for (const Edge* e : pred->out_edges()) {
- if (e->dst() != succ) {
- return Status(error::Code::INVALID_ARGUMENT,
- "Conv2D does not feed to BiasAdd."
- "Will skip node merge optimization");
- }
- }
-
- // Get operand 0, 1 of conv2D and their Mkl tensors.
- CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs.
- oper1 = pred_in[0].first;
- oper1_mkl = pred_in[1].first;
- oper2 = pred_in[2].first;
- oper2_mkl = pred_in[3].first;
- // Get operand 1 of add_bias
- // BiasAdd must have 2 inputs: Conv, bias
- CHECK_EQ(succ->in_edges().size(), 2);
- oper3 = succ_in[1].first;
- GetDummyMklTensorNode(g, &oper3_mkl); // Get dummy Mkl tensor node
- // as BiasAdd does not have Mkl tensor as input.
- CHECK_NOTNULL(oper3_mkl);
-
- Node* ret;
- // We will use the node name of BiasAdd as the name of new node
- TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
- .Input(oper1)
- .Input(oper1_mkl)
- .Input(oper2)
- .Input(oper2_mkl)
- .Input(oper3)
- .Input(oper3_mkl)
- .Attr("T", T_pred)
- .Attr("strides", strides)
- .Attr("padding", padding)
- .Attr("data_format", data_format_pred)
- .Attr("use_cudnn_on_gpu", use_cudnn_on_gnu)
- .Device(succ->def().device())
- .Finalize(&**g, &ret));
- CHECK_NOTNULL(ret);
-
- // Incoming edges are fixed, we will fix the outgoing edges now.
- for (const Edge* e : succ->out_edges()) {
- (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
- }
-
- // Copy device assigned to old node to new node.
- // It's ok to use pred or succ as we have enforced a check that
- // both have same device assigned.
- ret->set_assigned_device_name(pred->assigned_device_name());
-
- VLOG(1) << "NodeMergeRewritePass: Merged old node:" << pred->DebugString()
- << ", and node: " << succ->DebugString()
- << ", into node:" << ret->DebugString();
-
- (*g)->RemoveNode(succ);
- (*g)->RemoveNode(pred);
-
- return Status::OK();
- }
-
- return Status(error::Code::UNIMPLEMENTED,
- "Unimplemented case for node merge optimization.");
-}
-
-Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* n) {
- CHECK_NOTNULL(n);
-
- // Get the matching rewriteinfo for the node
- const Node* fwdn = nullptr;
- const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
- if (ri == nullptr || fwdn == nullptr) {
- VLOG(2) << "NodeMergeRewritePass: Rewriteinfo not found for: "
- << n->type_string();
- return Status(error::Code::INVALID_ARGUMENT,
- "Rewrite info not found for the node."
- "Will skip node rewrite optimization");
- }
-
- VLOG(1) << "NodeMergeRewritePass: Rewrite called for: " << n->type_string();
-
- if (n->type_string() == csinfo_.biasaddgrad &&
- ri->node == csinfo_.biasaddgrad &&
- (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias ||
- ri->rewrite == csinfo_.biasaddgrad)) {
- DataType T;
- string data_format;
- TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format));
-
- int n_num = n->num_inputs(); // this must be 1.
- CHECK_EQ(n_num, 1);
-
- gtl::InlinedVector<Node*, 4> n_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> n_in(n_num);
- FillInputs(n, &n_control_edges, &n_in);
-
- Node *ret = nullptr, *op = n_in[0].first;
-
- if (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias) {
- // Get strides info from Conv2D (node in the forward pass that this
- // node corresponds to).
- std::vector<int32> strides;
- TF_CHECK_OK(GetNodeAttr(fwdn->def(), "strides", &strides));
-
- // We use same name as original node name as there may be fetchoutputs
- // associated with it.
- TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
- .Input(op)
- .Attr("T", T)
- .Attr("data_format", data_format)
- .Attr("strides", strides)
- .Device(n->def().device())
- .Finalize(&**g, &ret));
- } else {
- CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad);
- TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
- .Input(op)
- .Attr("T", T)
- .Attr("data_format", data_format)
- .Device(n->def().device())
- .Finalize(&**g, &ret));
- }
-
- CHECK_NOTNULL(ret);
-
- // Incoming edges are fixed, we will fix the outgoing edges now.
- for (const Edge* e : n->out_edges()) {
- (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
- }
-
- // Copy device assigned to old node to new node.
- ret->set_assigned_device_name(n->assigned_device_name());
-
- VLOG(1) << "MKLOptimizerMergePass: Rewrote old node:" << n->DebugString()
- << ", into node:" << ret->DebugString();
- (*g)->RemoveNode(n);
-
- return Status::OK();
- }
-
- return Status(error::Code::UNIMPLEMENTED,
- "Unimplemented case for node rewrite optimization.");
-}
-
-const NodeMergeRewritePass::RewriteInfo*
-NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
- const Node** fwdn) const {
- CHECK_NOTNULL(n);
- CHECK_NOTNULL(fwdn);
- *fwdn = nullptr;
-
- // Search for matching rewriteinfo based on node name.
- // There could be more than one matching rewriteinfos.
- std::vector<const RewriteInfo*> matching_ri;
- for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
- if (n->type_string() == ri->node) {
- matching_ri.push_back(&*ri);
- }
- }
-
- VLOG(1) << "NodeMergeRewritePass: Searching graph for: " << n->type_string()
- << " in backwards.";
-
- // Now we will check for forward op name for rewrite info in data
- // flow graph. Get the max hops we should search for the fwd node
- // We are now going to search (breadth-first) backwards in data
- // dependence graph (for up to max hops) from n for the node
- // specified in fwd.
- // queue to maintain nodes to be visited and depth info for
- // breadth-first search
- std::queue<std::pair<const Node*, int>> nqueue;
- const Node* curr_node = n;
- size_t curr_depth = 0;
- nqueue.push(std::make_pair(curr_node, curr_depth));
-
- while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
- std::pair<const Node*, int> curr_pair = nqueue.front();
- nqueue.pop();
-
- std::set<const Node*> visited_nodes;
- curr_node = curr_pair.first;
- curr_depth = curr_pair.second;
- CHECK_NOTNULL(curr_node);
-
- VLOG(1) << "NodeMergeRewritePass: Visiting node: "
- << curr_node->type_string() << " at depth: " << curr_depth
- << " for node: " << n->type_string();
-
- // If we find a match, we return immediately with the matching rewrite
- // info.
- for (const RewriteInfo* ri : matching_ri) {
- if (curr_node->type_string() == ri->cinfo.fwd) {
- *fwdn = curr_node;
- return ri;
- }
- }
-
- // Else we explore backward edges from current node.
- // Add the source nodes of all incoming edges of the node to the queue.
- for (const Edge* e : curr_node->in_edges()) {
- // We do not visit already visited node.
- if (visited_nodes.find(e->src()) == visited_nodes.end()) {
- // Depth of these nodes is 1 more than the depth of current node.
- nqueue.push(std::make_pair(e->src(), curr_depth + 1));
- visited_nodes.insert(e->src());
- }
- }
- } /* while */
-
- return nullptr;
-}
-
-bool NodeMergeRewritePass::IsApplicableRewriteNode(const Node* n) const {
- CHECK_NOTNULL(n);
-
- // Search for matching rewriteinfo
- // Even if we find one match, we return true.
- bool match_found = false;
- for (const RewriteInfo& ri : rinfo_) {
- if (n->type_string() == ri.node) {
- match_found = true;
- break;
- }
- }
-
- return match_found;
-}
-
-bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
- bool result = false;
- CHECK_NOTNULL(g);
-
- DumpGraph("Before OptimizeMerge", &**g);
-
- std::vector<Node*> order;
- GetReversePostOrder(**g, &order);
- std::vector<std::pair<Node*, Node*>> nodes_to_be_merged;
- std::vector<Node*> nodes_to_be_rewritten;
-
- for (Node* n : order) {
- if (!n->IsOp()) continue;
- Node* n1 = nullptr;
- if ((n1 = FindNodeForMerge(n)) != nullptr) {
- VLOG(1) << "NodeMergeRewritePass: Scheduled nodes " << n->name()
- << " and " << n1->name() << " for merging";
- nodes_to_be_merged.push_back(std::make_pair(n, n1));
- } else if (IsApplicableRewriteNode(n)) {
- VLOG(1) << "NodeMergeRewritePass: Scheduled node " << n->name()
- << " for rewrite";
- nodes_to_be_rewritten.push_back(n);
- }
- }
-
- for (std::pair<Node*, Node*> i : nodes_to_be_merged) {
- // Even if MergeNode merges single pair of nodes, we
- // need to return true.
- string n1_name = i.first->name();
- string n2_name = i.second->name();
- if (MergeNode(g, i.first, i.second) == Status::OK()) {
- VLOG(1) << "NodeMergeRewritePass: Merged nodes " << n1_name << " and "
- << n2_name;
- result = true;
- }
- }
-
- DumpGraph("After OptimizeMerge(nodemerge)", &**g);
-
- for (Node* i : nodes_to_be_rewritten) {
- string name = i->name();
- if (RewriteNode(g, i) == Status::OK()) {
- VLOG(1) << "NodeMergeRewritePass: Rewrite node: " << name
- << " successful.";
- result = true;
- }
- }
-
- DumpGraph("After OptimizeMerge(noderewrite)", &**g);
-
- return result;
-}
-
-bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
- return NodeMergeRewritePass().RunPass(g);
-}
-
-Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
- if (options.graph == nullptr) {
- return Status::OK();
- }
-
- // Get the ownership of graph
- std::unique_ptr<Graph>* g = std::move(options.graph);
-
- RunPass(g);
-
- // Return the ownership of graph back
- options.graph->reset(g->release());
-
- return Status::OK();
-}
-
-} // namespace tensorflow
-
-#endif
diff --git a/tensorflow/core/graph/mkl_optimizer_merge.h b/tensorflow/core/graph/mkl_optimizer_merge.h
deleted file mode 100644
index b2caec58af..0000000000
--- a/tensorflow/core/graph/mkl_optimizer_merge.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// An optimization pass that performs node merging and rewrite on graph nodes
-
-#ifndef TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
-#define TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
-
-#ifdef INTEL_MKL
-
-#include <sys/types.h>
-#include <memory>
-#include "tensorflow/core/graph/graph.h"
-
-namespace tensorflow {
-// Interface to invoke the pass for unit test
-//
-// Returns true if and only if 'g' is mutated.
-extern bool OptimizeNodeMerge(std::unique_ptr<Graph>* g);
-} // namespace tensorflow
-
-#endif // INTEL_MKL
-
-#endif // TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
diff --git a/tensorflow/core/graph/mkl_optimizer_merge_test.cc b/tensorflow/core/graph/mkl_optimizer_merge_test.cc
deleted file mode 100644
index f752721d6e..0000000000
--- a/tensorflow/core/graph/mkl_optimizer_merge_test.cc
+++ /dev/null
@@ -1,470 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef INTEL_MKL
-
-#include "tensorflow/core/graph/mkl_optimizer_merge.h"
-
-#include <vector>
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/testlib.h"
-#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/lib/random/simple_philox.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-
-namespace tensorflow {
-namespace {
-
-class OptimizerMergeTest : public ::testing::Test {
- public:
- OptimizerMergeTest() : graph_(OpRegistry::Global()) {}
-
- static void InitGraph(const string& s, Graph* graph) {
- GraphDef graph_def;
-
- auto parser = protobuf::TextFormat::Parser();
- CHECK(parser.MergeFromString(s, &graph_def)) << s;
- GraphConstructorOptions opts;
- TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
- }
-
- void InitGraph(const string& s) {
- InitGraph(s, &graph_);
- original_ = CanonicalGraphString(&graph_);
- }
-
- static bool IncludeNode(const Node* n) { return n->IsOp(); }
-
- static string EdgeId(const Node* n, int index) {
- if (index == 0) {
- return n->name();
- } else if (index == Graph::kControlSlot) {
- return strings::StrCat(n->name(), ":control");
- } else {
- return strings::StrCat(n->name(), ":", index);
- }
- }
-
- string CanonicalGraphString(Graph* g) {
- std::vector<string> nodes;
- std::vector<string> edges;
- for (const Node* n : g->nodes()) {
- if (IncludeNode(n)) {
- nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
- }
- }
- for (const Edge* e : g->edges()) {
- if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
- edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
- EdgeId(e->dst(), e->dst_input())));
- }
- }
- // Canonicalize
- std::sort(nodes.begin(), nodes.end());
- std::sort(edges.begin(), edges.end());
- return strings::StrCat(str_util::Join(nodes, ";"), "|",
- str_util::Join(edges, ";"));
- }
-
- string DoNodeMerge() {
- string before = CanonicalGraphString(&graph_);
- LOG(ERROR) << "Before node merge optimize: " << before;
-
- std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
- OptimizeNodeMerge(ug);
-
- string result = CanonicalGraphString(&graph_);
- LOG(ERROR) << "After node merge optimize: " << result;
- return result;
- }
-
- const string& OriginalGraph() const { return original_; }
-
- Graph graph_;
- string original_;
-};
-
-REGISTER_OP("Input").Output("o: float").SetIsStateful();
-REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
-
-TEST_F(OptimizerMergeTest, Basic) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }"
- "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(Mul);D(Mul)|"
- "A->C;A->D;B->C:1;B->D:1");
-}
-
-// Test set 1: Conv2D + AddBias
-
-// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y)
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C', 'D'] }"
- "node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'Y']}");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
- "M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
- "DMT/_0->E:5;E->Z;M->E:1;N->E:3;Y->Z:1");
-}
-
-// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
-// We do not merge in this case as op is Conv2D and not MklConv2D.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoMklConv2D) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C', 'D'] }"
- "node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'Y']}");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|"
- "A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
-}
-
-// Graph contains only MklConv2D, no AddBias.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoAddBias) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
- "A->C;B->C:2;M->C:1;N->C:3");
-}
-
-// MklConv2D output does not go to BiasAdd.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd.
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "M(MklInput);N(MklInput)|A->C;B->C:2;D->F;E->F:1;M->C:1;N->C:3");
-}
-
-// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
-// Merge should not be done in such case.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow2) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D', 'E'] }" // Conv2D has two outputs.
- // No merge should happen.
- "node { name: 'G' op: 'Add'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'E'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "G(Add);M(MklInput);N(MklInput)|A->C;B->C:2;C->G;D->F;"
- "E->F:1;E->G:1;M->C:1;N->C:3");
-}
-
-// data_format attribute value mismatch. Merge should not be done
-// in such case.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_AttrMismatch) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHCW' } }"
- " input: ['C', 'D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
- "N(MklInput)|A->C;B->C:2;C->E;D->E:1;M->C:1;N->C:3");
-}
-
-#if 0
-// This test set is disabled temporarily as we do not enable node rewrite.
-// This test set will be enabled when we support Mkl-specific kernels for
-// backward bias.
-//
-// Test set 2: MklConv2D..BiasAddGrad -> Conv2DWithBiasBackpropBias
-// rewrite tests
-
-// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MklConv2D);D(Sub);E(Conv2DWithBiasBackpropBias);"
- "M(MklInput);N(MklInput)|A->C;A->D:1;B->C:2;C->D;D->E;M->C:1;N->C:3");
-}
-
-// No MklConv2D in context, but Conv2D in context. No rewrite should happen.
-// C=Conv2D(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoMklConv2D) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(Conv2D);D(Sub);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
-// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Add'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No Conv2D in the context for BiasAddGrad, but MatMul in context.
-// Rewrite should happen, but name of BiasAddGrad does not change.
-// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D_MatMul) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'MatMul'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'transpose_a' value { b: false } }"
- " attr { key: 'transpose_b' value { b: false } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
-// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'MatMul'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'transpose_a' value { b: false } }"
- " attr { key: 'transpose_b' value { b: false } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No MatMul in the context for BiasAddGrad. No rewrite should happen.
-// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Negative_NoMatMul) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Add'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoNodeMerge(),
- "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-#endif
-
-static void BM_NodeMerge(int iters, int op_nodes) {
- testing::StopTiming();
- string s;
- for (int in = 0; in < 10; in++) {
- s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
- }
- random::PhiloxRandom philox(301, 17);
- random::SimplePhilox rnd(&philox);
- for (int op = 0; op < op_nodes; op++) {
- s += strings::Printf(
- "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
- "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
- op, rnd.Uniform(10), rnd.Uniform(10));
- }
-
- bool first = true;
- while (iters > 0) {
- Graph* graph = new Graph(OpRegistry::Global());
- OptimizerMergeTest::InitGraph(s, graph);
- int N = graph->num_node_ids();
- if (first) {
- testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
- first = false;
- }
- {
- testing::StartTiming();
- std::unique_ptr<Graph> ug(graph);
- OptimizeNodeMerge(&ug);
- testing::StopTiming();
- }
- iters -= N; // Our benchmark units are individual graph nodes,
- // not whole graphs
- // delete graph;
- }
-}
-BENCHMARK(BM_NodeMerge)->Arg(1000)->Arg(10000);
-
-} // namespace
-} // namespace tensorflow
-
-#endif /* INTEL_MKL */
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 7c3836b308..55c280719c 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -40,16 +40,16 @@ namespace tensorflow {
// This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
// in the graph in between A and B, where A and B match any one
-// of the following
-// cases:
-// 1) A = layer/Op that generates output in Mkl format and,
-// B = layer/Op that does not accept input in Mkl format and,
+// of the following cases:
+//
+// 1) A = a node that generates output in the Mkl format and,
+// B = a node that does not accept input in the Mkl format and,
// A -> B (there is a direct edge between A and B, then
// We will insert C such that A->C->B.
//
-// 2) A = layer/Op that generates output in Mkl format and,
-// B = NULL (in other words, A is the last layer in the graph), then
-// We will insert C such that A->C->B. (C will be the last layer.)
+// 2) A = a node that generates output in the Mkl format and,
+// B = NULL (in other words, A is the last node in the graph), then
+// We will insert C such that A->C->B. (C will be the last node.)
//
// Note that case 1 applies to all outputs of A that are input to B.
// In other words, the conversions will be required for every output
@@ -59,9 +59,9 @@ namespace tensorflow {
// do the conversion for A1 and A2 only. We do not need to do any conversion
// for A3.
//
-// This pass relies on layers registering themselves about their Mkl compliant.
-// Mkl compliant layer can accept inputs in Mkl format, and produce output in
-// Mkl format. Non-compliant layer accepts inputs and outputs in
+// This pass relies on ops registering themselves about their Mkl compliance.
+// An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs
+// in the Mkl format. Non-compliant ops accept inputs and outputs in the
// TensorFlow format.
//
class MklToTfConversionPass : public GraphOptimizationPass {
@@ -84,7 +84,7 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// @input T Datatype to use for checking input op
// @return true if op is Mkl supported; false, otherwise.
inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
- return mkl_layer_registry::IsMklLayer(op_name, T);
+ return mkl_op_registry::IsMklOp(op_name, T);
}
// Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
@@ -129,14 +129,16 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
}
- // Lets build the conversion node and specify src as input.
+ // Build the conversion node and specify src as input.
TF_CHECK_OK(
- NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf")
+ NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
.Input(src, e->src_output())
- .Input(src, e->src_output() + 1) // Mkl tensor immediately
- // follows Tf tensor.
- .Device(src->def().device()) // We want to get conversion node
- // on same device as source node.
+ .Input(src, DataIndexToMetaDataIndex(
+ e->src_output(),
+ src->num_outputs())) // Get an Mkl tensor slot
+ // from the Tf tensor slot.
+ .Device(src->def().device()) // We want to get conversion node
+ // on same device as source node.
.Attr("T", src_datatype)
.Finalize(&**g, &conversion_node));
@@ -149,8 +151,8 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
// We want conversion node to be on the same device as the source node.
conversion_node->set_assigned_device_name(src->assigned_device_name());
- // Set the Mkl layer label for this op.
- conversion_node->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
+ // Set the Mkl op label for this op.
+ conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
// Now that we have added edge from src->conversion_node, let's add edge from
// output of conversion_node to the dest node. Since conversion_node
@@ -173,11 +175,11 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
DumpGraph("Before MklToTfConversionPass", &**g);
- // Since we are looking for mkl-supported op node immediately
- // followed by non-mkl op node, we will just iterate over edge
+ // Since we are looking for an Mkl-supported op node immediately
+ // followed by a non-Mkl op node, we will just iterate over edge
// set of the graph.
- // vector to maintain candiadate edges whose source and destination
- // are candidate for inserting conversion node
+ // edge set whose source and destination are candidates for
+ // inserting conversion node
std::vector<Edge*> candidate_edges;
for (const Edge* e : (*g)->edges()) {
@@ -190,9 +192,9 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
}
// We skip adding MklToTf on an edge between X->MklToTf or
- // MklToTf->X, where X is any layer.
- if (src->type_string().compare("MklToTf") == 0 ||
- dst->type_string().compare("MklToTf") == 0) {
+ // MklToTf->X, where X is any node.
+ if (src->type_string().compare("_MklToTf") == 0 ||
+ dst->type_string().compare("_MklToTf") == 0) {
continue;
}
@@ -210,7 +212,6 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
GetNodeAttr(dst->def(), "T", &dst_datatype);
// Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
-
if (IsMklSupportedOp(src->type_string(), src_datatype) &&
!IsMklSupportedOp(dst->type_string(), dst_datatype)) {
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index 7d9237f845..bd2cb0989c 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
+#include "tensorflow/core/util/mkl_util.h"
#include <algorithm>
#include <string>
@@ -109,7 +110,7 @@ class MklToTfConversionPass : public ::testing::Test {
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
-REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
+REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
TEST_F(MklToTfConversionPass, Basic) {
InitGraph(
@@ -125,58 +126,116 @@ TEST_F(MklToTfConversionPass, Basic) {
}
// MklConv2D followed by Non-Mkl layer
-// C=MklConv2D(A,M,B,N); E=Sub(C,D)
+// C=MklConv2D(A,M,B,N); E=Sub(C,D) (for interleaved ordering)
+// C=MklConv2D(A,B,M,N); E=Sub(C,D) (for contiguous ordering)
TEST_F(MklToTfConversionPass, Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D']}");
- EXPECT_EQ(DoRunMklToTfConversionPass(),
- "A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
- "Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
- "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
+ "}"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'M', 'B', 'N']}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D']}");
+ EXPECT_EQ(DoRunMklToTfConversionPass(),
+ "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
+ "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
+ "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
+ "}"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'M', 'N']}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D']}");
+ EXPECT_EQ(DoRunMklToTfConversionPass(),
+ "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
+ "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
+ "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
+ }
}
// MklConv2D followed by MklToTf op followed by Non-Mkl layer.
-// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E)
+// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved)
+// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for contiguous)
// MklToTf node should not be inserted again.
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'N' op: 'MklInput'}"
- "node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'MklToTf'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C:0', 'C:1']}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'Sub'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'E']}");
- EXPECT_EQ(DoRunMklToTfConversionPass(),
- "A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
- "F(Sub);M(MklInput);N(MklInput)|"
- "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
+ "}"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'M', 'B', 'N']}"
+ "node { name: 'D' op: '_MklToTf'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['C:0', 'C:1']}"
+ "node { name: 'E' op: 'Input'}"
+ "node { name: 'F' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'E']}");
+ EXPECT_EQ(DoRunMklToTfConversionPass(),
+ "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
+ "F(Sub);M(_MklInput);N(_MklInput)|"
+ "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
+ "}"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'M', 'N']}"
+ "node { name: 'D' op: '_MklToTf'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['C:0', 'C:1']}"
+ "node { name: 'E' op: 'Input'}"
+ "node { name: 'F' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'E']}");
+ EXPECT_EQ(DoRunMklToTfConversionPass(),
+ "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
+ "F(Sub);M(_MklInput);N(_MklInput)|"
+ "A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
+ }
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc
index b5497d3594..078fb10bc9 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel.cc
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc
@@ -230,7 +230,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
AddOneReplica(graph, i);
}
std::set<string> fetches;
- for (int i = 0; i < item_->fetch.size(); i++) {
+ for (size_t i = 0; i < item_->fetch.size(); i++) {
for (int j = 0; j < num_replicas_; j++) {
string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e32f51a3a2..49b12df7aa 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4828,6 +4828,38 @@ tf_mkl_kernel_library(
],
)
+tf_mkl_kernel_library(
+ name = "mkl_fused_batch_norm_op",
+ srcs = ["mkl_fused_batch_norm_op.cc"],
+ deps = NN_DEPS + [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
+tf_mkl_kernel_library(
+ name = "mkl_concat_op",
+ prefix = "mkl_concat_op",
+ deps = ARRAY_DEPS + [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
+tf_mkl_kernel_library(
+ name = "mkl_reshape_op",
+ prefix = "mkl_reshape_op",
+ deps = ARRAY_DEPS + [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
+tf_mkl_kernel_library(
+ name = "mkl_lrn_op",
+ prefix = "mkl_lrn_op",
+ deps = NN_DEPS + [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc
index 637a6cef95..ce7fb9c332 100644
--- a/tensorflow/core/kernels/fixed_length_record_reader_op.cc
+++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc
@@ -28,12 +28,14 @@ namespace tensorflow {
class FixedLengthRecordReader : public ReaderBase {
public:
FixedLengthRecordReader(const string& node_name, int64 header_bytes,
- int64 record_bytes, int64 footer_bytes, Env* env)
+ int64 record_bytes, int64 footer_bytes,
+ int64 hop_bytes, Env* env)
: ReaderBase(
strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
header_bytes_(header_bytes),
record_bytes_(record_bytes),
footer_bytes_(footer_bytes),
+ hop_bytes_(hop_bytes),
env_(env),
file_pos_limit_(-1),
record_number_(0) {}
@@ -62,14 +64,31 @@ class FixedLengthRecordReader : public ReaderBase {
Status ReadLocked(string* key, string* value, bool* produced,
bool* at_end) override {
- if (input_buffer_->Tell() >= file_pos_limit_) {
+ // The condition `input_buffer_->Tell() + record_bytes_ > file_pos_limit_`
+ // is to confirm that none of record bytes is out of the range of
+ // file_pos_limit_.
+ // This is necessary for the condition `hop_bytes > 0`. For example.
+ // File: "0123456"
+ // Reader setting: `record_bytes=3`, `hop_bytes=2`, `footer_bytes=0`,
+ // `header_bytes=0`
+ // Without this checking condition, the forth time the reader will at
+ // this position: "012345|6" and the reading operation will result in
+ // an error.
+ if (input_buffer_->Tell() >= file_pos_limit_ ||
+ input_buffer_->Tell() + record_bytes_ > file_pos_limit_) {
*at_end = true;
return Status::OK();
}
+ const int64 pos_before_read = input_buffer_->Tell();
TF_RETURN_IF_ERROR(input_buffer_->ReadNBytes(record_bytes_, value));
*key = strings::StrCat(current_work(), ":", record_number_);
*produced = true;
++record_number_;
+
+ if (hop_bytes_ > 0) {
+ input_buffer_->Seek(pos_before_read + hop_bytes_).IgnoreError();
+ }
+
return Status::OK();
}
@@ -87,6 +106,7 @@ class FixedLengthRecordReader : public ReaderBase {
const int64 header_bytes_;
const int64 record_bytes_;
const int64 footer_bytes_;
+ const int64 hop_bytes_;
Env* const env_;
int64 file_pos_limit_;
int64 record_number_;
@@ -98,10 +118,12 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
public:
explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
: ReaderOpKernel(context) {
- int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1;
+ int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1,
+ hop_bytes = -1;
OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
+ OP_REQUIRES_OK(context, context->GetAttr("hop_bytes", &hop_bytes));
OP_REQUIRES(context, header_bytes >= 0,
errors::InvalidArgument("header_bytes must be >= 0 not ",
header_bytes));
@@ -111,11 +133,15 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
OP_REQUIRES(context, footer_bytes >= 0,
errors::InvalidArgument("footer_bytes must be >= 0 not ",
footer_bytes));
+ OP_REQUIRES(
+ context, hop_bytes >= 0,
+ errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes));
Env* env = context->env();
- SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, env]() {
- return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
- footer_bytes, env);
- });
+ SetReaderFactory(
+ [this, header_bytes, record_bytes, footer_bytes, hop_bytes, env]() {
+ return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
+ footer_bytes, hop_bytes, env);
+ });
}
};
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index 71918fe269..8bd1724e32 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -29,10 +29,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename T>
-class MklAvgPoolingOp : public UnaryOp<T> {
+class MklAvgPoolingOp : public OpKernel {
public:
- explicit MklAvgPoolingOp(OpKernelConstruction* context)
- : UnaryOp<T>(context) {
+ explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
@@ -78,6 +77,7 @@ class MklAvgPoolingOp : public UnaryOp<T> {
Tensor mkl_tmp_input_buf_tensor_;
mkl_context.MklCreateLayoutsAndPrimitives(context,
&mkl_tmp_input_buf_tensor_);
+ OP_REQUIRES_OK(context, context->status());
Tensor workspace_tensor;
void* workspace_buf;
@@ -120,7 +120,7 @@ class MklAvgPoolingOp : public UnaryOp<T> {
mkl_out_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output, tensor_out_shape,
+ AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
mkl_out_shape);
mkl_context.pooling_res[dnnResourceDst] =
static_cast<void*>(output->flat<T>().data());
@@ -138,9 +138,10 @@ class MklAvgPoolingOp : public UnaryOp<T> {
typedef struct {
MklPoolingOpParams params;
MklShape input_shape;
- dnnPrimitive_t prim_pooling_fwd, convert_input;
- dnnLayout_t lt_user_input, lt_prim_input, lt_workspace;
- void* input_buf;
+ dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr;
+ dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr,
+ lt_workspace = nullptr;
+ void* input_buf = nullptr;
void* pooling_res[dnnResourceNumber];
void MklCreateLayoutsAndPrimitives(OpKernelContext* context,
@@ -243,6 +244,11 @@ class MklAvgPoolingGradOp : public OpKernel {
pool_params.Init(context, ksize_, stride_, padding_, data_format_,
output_shape);
+ if (outbackprop_in_mkl_format == false)
+ mkl_context.params.in_dim = out_backprop.dims();
+ else
+ mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension();
+
// Extract the parameters for the op from the pooling specs
ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
@@ -250,6 +256,7 @@ class MklAvgPoolingGradOp : public OpKernel {
Tensor outbackprop_buf_tensor;
void* outbackprop_buf;
mkl_context.MklCreateLayoutsAndPrimitives(context);
+ OP_REQUIRES_OK(context, context->status());
// Check if outbackprop layout requires conversion.
if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop,
@@ -304,7 +311,7 @@ class MklAvgPoolingGradOp : public OpKernel {
mkl_out_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output, tensor_out_shape,
+ AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
mkl_out_shape);
// Set output tensor.
@@ -323,10 +330,10 @@ class MklAvgPoolingGradOp : public OpKernel {
typedef struct {
MklPoolingOpParams params;
MklShape out_backprop_shape;
- dnnPrimitive_t prim_pooling_bwd, convert_outbackprop;
+ dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr;
void* pooling_res[dnnResourceNumber];
- dnnLayout_t lt_user_input, lt_user_outbackprop, lt_prim_outbackprop,
- lt_workspace;
+ dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr,
+ lt_prim_outbackprop = nullptr, lt_workspace = nullptr;
void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
const Tensor& tensor_in_shape = MklGetInput(context, 0);
@@ -348,11 +355,6 @@ class MklAvgPoolingGradOp : public OpKernel {
"4-dimensional"));
} else {
// Input in MKL format.
- OP_REQUIRES(
- context, out_backprop.dims() == 2,
- errors::InvalidArgument("out_backprop in MKL format must be "
- "2-dimensional"));
-
// For avgpooling, out_backprop should have 4 dimensions.
OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4,
errors::InvalidArgument("out_backprop must be "
@@ -412,16 +414,16 @@ class MklAvgPoolingGradOp : public OpKernel {
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(Name("MklAvgPool")
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
- .Label(mkl_layer_registry::kMklLayerLabel),
+ .Label(mkl_op_registry::kMklOpLabel),
MklAvgPoolingOp<CPUDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("MklAvgPoolGrad")
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
- .Label(mkl_layer_registry::kMklLayerLabel),
+ .Label(mkl_op_registry::kMklOpLabel),
MklAvgPoolingGradOp<CPUDevice, float>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
new file mode 100644
index 0000000000..27930c44a6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -0,0 +1,458 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include <limits>
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
+
+// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
+// reference inputs.
+// --------------------------------------------------------------------------
+// Eigen Concat Op
+// --------------------------------------------------------------------------
+template <typename Device, typename T, AxisArgumentName AxisArgName>
+class EigenConcatBaseOp : public OpKernel {
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+
+ explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+ // Although, we modify Compute for this call to accept one extra param,
+ // we need to have empty Compute because Compute is pure virtual function.
+ void Compute(OpKernelContext* c) {}
+
+ void Compute(OpKernelContext* c, const std::vector<Tensor>& values) {
+ const Tensor* concat_dim_tensor;
+ const char* axis_attribute_name =
+ AxisArgName == NAME_IS_AXIS
+ ? "axis"
+ : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
+ OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+ OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
+ errors::InvalidArgument(
+ axis_attribute_name,
+ " tensor should be a scalar integer, but got shape ",
+ concat_dim_tensor->shape().DebugString()));
+ const int32 concat_dim =
+ internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
+ // Instead of accessing values from context, we use input to Compute.
+ const int N = values.size();
+ const int input_dims = values[0].dims();
+ const TensorShape& input_shape = values[0].shape();
+
+ int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
+ OP_REQUIRES(c,
+ (0 <= axis && axis < input_dims) ||
+ (allow_legacy_scalars() && concat_dim == 0),
+ errors::InvalidArgument(
+ "ConcatOp : Expected concatenating dimensions in the range "
+ "[",
+ -input_dims, ", ", input_dims, "), but got ", concat_dim));
+ // Note that we reduce the concat of n-dimensional tensors into a two
+ // dimensional concat. Assuming the dimensions of any input/output
+ // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
+ // the dimension indicated with size y0, we flatten it to {x, y}, where y =
+ // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
+ ConstMatrixVector inputs_flat;
+ inputs_flat.reserve(N);
+ int64 inputs_flat_dim0 = 1;
+ for (int d = 0; d < axis; ++d) {
+ inputs_flat_dim0 *= input_shape.dim_size(d);
+ }
+ int64 output_concat_dim = 0;
+ const bool input_is_scalar = IsLegacyScalar(input_shape);
+ for (int i = 0; i < N; ++i) {
+ const auto in = values[i];
+ const bool in_is_scalar = IsLegacyScalar(in.shape());
+ OP_REQUIRES(
+ c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
+ errors::InvalidArgument(
+ "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
+ input_shape.DebugString(), " vs. shape[", i,
+ "] = ", in.shape().DebugString()));
+ for (int j = 0; j < input_dims; ++j) {
+ if (j == axis) {
+ continue;
+ }
+ OP_REQUIRES(
+ c, in.dim_size(j) == input_shape.dim_size(j),
+ errors::InvalidArgument(
+ "ConcatOp : Dimensions of inputs should match: shape[0] = ",
+ input_shape.DebugString(), " vs. shape[", i,
+ "] = ", in.shape().DebugString()));
+ }
+ if (in.NumElements() > 0) {
+ int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+ }
+ // TODO(irving): Remove check once !allow_legacy_scalars().
+ output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
+ }
+
+ TensorShape output_shape(input_shape);
+ // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
+ if (output_shape.dims() == 0) {
+ output_shape.AddDim(output_concat_dim);
+ } else {
+ output_shape.set_dim(axis, output_concat_dim);
+ }
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
+ if (output->NumElements() > 0) {
+ int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
+ auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
+ ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+ }
+ }
+};
+
+// --------------------------------------------------------------------------
+// Mkl Concat Op
+// --------------------------------------------------------------------------
+
+template <typename Device, typename T, AxisArgumentName AxisArgName>
+class MklConcatOp : public OpKernel {
+ private:
+ TensorFormat data_format_;
+ EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;
+
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+
+ explicit MklConcatOp(OpKernelConstruction* c)
+ : OpKernel(c), eigen_concat_op_(c) {}
+
+ void Compute(OpKernelContext* context) override {
+ MklConcatOpContext mkl_context;
+
+ // Get input tensors.
+ OpInputList input_tensors;
+ GetMklInputList(context, "values", &input_tensors);
+ const int N = input_tensors.size();
+ // Get MKL shapes.
+ MklShapeList input_shapes(N);
+ GetMklShapeList(context, "values", &input_shapes);
+
+ // If this is Concat, then concat_dim is 0th input.
+ // If this is ConcatV2, then axis is Nth input.
+ const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM
+ ? MklGetInput(context, 0)
+ : MklGetInput(context, N);
+
+ // Sanity checks
+ OP_REQUIRES(
+ context, IsLegacyScalar(concat_dim_tensor.shape()),
+ errors::InvalidArgument(
+ "Concat dim tensor should be a scalar integer, but got shape ",
+ concat_dim_tensor.shape().DebugString()));
+ int32 concat_dim =
+ internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
+
+ MklShape& inpshape0 = input_shapes[0];
+
+ // Check that all tensors are Mkl, if not we call Eigen version.
+ bool invoke_eigen = false;
+ bool is_concat_dim_channel = true;
+ if (!AreAllMklTensors(input_shapes)) {
+ invoke_eigen = true;
+ }
+
+ // Check that total number of dimensions is 4, if not call Eigen.
+ if (!invoke_eigen) {
+ for (auto& s : input_shapes) {
+ if (s.GetDimension() != 4) {
+ invoke_eigen = true;
+ break;
+ }
+ }
+ }
+
+ // check that concat_dim is channel, if not call Eigen version.
+ if (!invoke_eigen) {
+ for (auto& s : input_shapes) {
+ if (!s.IsMklChannelDim(concat_dim)) {
+ invoke_eigen = true;
+ is_concat_dim_channel = false;
+ break;
+ }
+ }
+ }
+
+ if (invoke_eigen) {
+ string msg = std::string("Invoking Eigen version of Concat. Reason:") +
+ (!is_concat_dim_channel
+ ? std::string("Concat dimension is not channel")
+ : std::string("Not all tensors are in Mkl layout"));
+ VLOG(1) << "_MklConcatOp: " << msg;
+ CallEigenVersion(context, input_tensors, input_shapes);
+ return;
+ }
+
+ // For MKL format, the channel is dimension number 2.
+ // So if we are concating over channel and _all_ inputs are in MKL
+ // format, then we set concat_dim to 2.
+ // Since we have reached till here, it means we are concating
+ // over channel.
+ concat_dim = MklDims::C;
+
+ // One more sanity check: check that ranks of all tensors match
+ // and that their shapes match except for concat_dim.
+ int i = 0;
+ for (auto& s : input_shapes) {
+ size_t exp_dims = inpshape0.GetDimension();
+ OP_REQUIRES(context, s.GetDimension() == exp_dims,
+ errors::InvalidArgument(
+ "_MklConcatOp : Ranks of all input tensors should match:"
+ " input dimensions = ",
+ s.GetDimension(), " vs. expected rank = ", exp_dims));
+
+ for (int d = 0; d < exp_dims; ++d) {
+ if (d == concat_dim) {
+ continue;
+ }
+
+ size_t exp_size = inpshape0.GetSizes()[d];
+ OP_REQUIRES(
+ context, exp_size == s.GetSizes()[d],
+ errors::InvalidArgument("_MklConcatOp : Dimensions of inputs"
+ "should match: shape[0][",
+ d, "]= ", exp_size, " vs. shape[", i, "][",
+ d, "] = ", s.GetSizes()[d]));
+ }
+ ++i;
+ }
+
+ // Use input MKL layout instead of creating new layouts.
+ int64 output_concat_dim_size = 0;
+ for (auto& s : input_shapes) {
+ output_concat_dim_size +=
+ s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
+ }
+ mkl_context.MklCreateInputLayouts(context, input_shapes);
+
+ CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
+ &mkl_context.lt_inputs[0]),
+ E_SUCCESS);
+
+ // Calculate output sizes and strides
+ TensorFormat data_format;
+ if (inpshape0.IsTensorInNHWCFormat()) {
+ data_format = FORMAT_NHWC;
+ } else {
+ OP_REQUIRES(
+ context, inpshape0.IsTensorInNCHWFormat(),
+ errors::InvalidArgument(
+ "_MklConcat only supports all inputs in NCHW or NHWC format "));
+ data_format = FORMAT_NCHW;
+ }
+
+ // Since all tensors are in Mkl layout, we copy sizes from input tensor.
+ mkl_context.out_sizes[MklDims::W] = inpshape0.GetSizes()[MklDims::W];
+ mkl_context.out_sizes[MklDims::H] = inpshape0.GetSizes()[MklDims::H];
+ mkl_context.out_sizes[MklDims::C] = output_concat_dim_size;
+ mkl_context.out_sizes[MklDims::N] = inpshape0.GetSizes()[MklDims::N];
+ GetStridesFromSizes(data_format, mkl_context.out_strides,
+ mkl_context.out_sizes);
+
+ // Set output Mkl shape.
+ int64 dim = 4;
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(true);
+ mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_concat, dnnResourceDst);
+ mkl_output_mkl_shape.SetTfLayout(dim, mkl_context.out_sizes,
+ mkl_context.out_strides);
+ mkl_output_mkl_shape.SetTfDimOrder(dim, inpshape0.GetTfToMklDimMap());
+
+ TensorShape mkl_output_tf_shape;
+ mkl_output_tf_shape.AddDim(1);
+ mkl_output_tf_shape.AddDim(
+ dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
+ sizeof(T));
+
+ Tensor* output = nullptr;
+ AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
+ mkl_output_mkl_shape);
+
+ // Set destination resource.
+ mkl_context.concat_res[dnnResourceDst] =
+ const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
+
+ mkl_context.mkl_tmp_tensors.resize(N);
+ mkl_context.MklPrepareConcatInputs(context, input_tensors);
+
+ // Execute primitive.
+ CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
+ E_SUCCESS);
+
+ mkl_context.MklCleanup();
+ }
+
+ private:
+ typedef struct {
+ TensorFormat data_format;
+ size_t out_sizes[4];
+ size_t out_strides[4];
+ dnnPrimitive_t prim_concat;
+ void* concat_res[dnnResourceNumber];
+ std::vector<dnnLayout_t> lt_inputs;
+ std::vector<Tensor> mkl_tmp_tensors;
+
+ // Create MKL dnnLayout_t objects for tensors coming into the layer
+ // We only support case where input tensors are all in Mkl layout.
+ void MklCreateInputLayouts(OpKernelContext* context,
+ MklShapeList& input_shapes) {
+ for (auto& is : input_shapes) {
+ CHECK_EQ(is.IsMklTensor(), true);
+ lt_inputs.push_back((dnnLayout_t)is.GetCurLayout());
+ }
+ }
+
+ void MklPrepareConcatInputs(OpKernelContext* context,
+ OpInputList& input_tensors) {
+ CHECK_EQ(lt_inputs.size(), mkl_tmp_tensors.size());
+
+ for (int i = 0; i < lt_inputs.size(); ++i) {
+ dnnPrimitive_t mkl_prim_convert_input;
+ dnnLayout_t mkl_lt_internal_input;
+ void* mkl_buf_convert_input = nullptr;
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_input, prim_concat,
+ (dnnResourceType_t)(dnnResourceMultipleSrc + i)),
+ E_SUCCESS);
+
+ if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) {
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
+ lt_inputs[i], mkl_lt_internal_input),
+ E_SUCCESS);
+
+ AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input,
+ &mkl_buf_convert_input);
+
+ CHECK_EQ(dnnConversionExecute_F32(
+ mkl_prim_convert_input,
+ const_cast<void*>(static_cast<const void*>(
+ input_tensors[i].flat<T>().data())),
+ mkl_buf_convert_input),
+ E_SUCCESS);
+
+ concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input;
+ CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS);
+ } else {
+ concat_res[dnnResourceMultipleSrc + i] = const_cast<void*>(
+ static_cast<const void*>(input_tensors[i].flat<T>().data()));
+ }
+
+ CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS);
+ }
+ }
+
+ void MklCleanup() {
+ for (auto& lt : lt_inputs) {
+ lt = nullptr;
+ }
+ CHECK_EQ(dnnDelete_F32(prim_concat), E_SUCCESS);
+ }
+ } MklConcatOpContext;
+
+ void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
+ const MklShapeList& input_shapes) {
+ // Before calling Eigen version, we need to convert Mkl tensors to TF.
+ // First check that the number of input tensors and the number of Mkl
+ // shapes match.
+ CHECK_EQ(values.size(), input_shapes.size());
+
+ std::vector<Tensor> converted_values;
+ for (int i = 0; i < input_shapes.size(); i++) {
+ if (input_shapes[i].IsMklTensor()) {
+ // If input tensor is Mkl, then do the conversion.
+ Tensor tmp_tensor =
+ ConvertMklToTF<T>(context, values[i], input_shapes[i]);
+ converted_values.push_back(tmp_tensor);
+ } else {
+ // If input tensor is TF already, then we do not need any conversion.
+ converted_values.push_back(values[i]);
+ }
+ }
+
+ // Call Eigen concat.
+ eigen_concat_op_.Compute(context, converted_values);
+
+ // Set dummy Mkl tensor as output Mkl tensor for this op.
+ MklShape mkl_tensor_mkl_shape;
+ mkl_tensor_mkl_shape.SetMklTensor(false);
+ mkl_tensor_mkl_shape.SetDimensions(4);
+ mkl_tensor_mkl_shape.SetTfDimOrder(4); // Dimensions
+ Tensor* mkl_tensor = nullptr;
+ TensorShape mkl_tensor_tf_shape;
+ mkl_tensor_tf_shape.AddDim(
+ SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
+ int tf_output_index = 0;
+ context->allocate_output(
+ GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
+ mkl_tensor_tf_shape, &mkl_tensor);
+ mkl_tensor_mkl_shape.SerializeMklShape(
+ mkl_tensor->flat<uint8>().data(),
+ mkl_tensor->flat<uint8>().size() * sizeof(uint8));
+ }
+};
+
+/* Use optimized concat for float type only */
+#define REGISTER_MKL_CPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConcat") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("concat_dim") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConcatV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("axis") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)
+
+TF_CALL_float(REGISTER_MKL_CPU);
+
+#undef REGISTER_CONCAT_MKL
+} // namespace tensorflow
+
+#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
index 627fd83b0d..8a1006a8e9 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
@@ -87,7 +87,7 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
Tensor* bias_backprop = nullptr;
MklShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
- AllocateOutputSetMklshape(context, 0, &bias_backprop, output_shape,
+ AllocateOutputSetMklShape(context, 0, &bias_backprop, output_shape,
output_mkl_shape);
mkl_context.in_dims = 4;
@@ -251,11 +251,11 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(MklConv2DCustomBackpropBiasOp);
};
-#define REGISTER_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBiasBackpropBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklConv2DCustomBackpropBiasOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 85198d89f5..6381b527a1 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -217,7 +217,7 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
mkl_context.grad_filter_shape.SetTfLayout(mkl_context.filter_dims,
mkl_context.filter_sizes,
mkl_context.filter_strides);
- AllocateOutputSetMklshape(context, 0, &grad_filter, filter_shape,
+ AllocateOutputSetMklShape(context, 0, &grad_filter, filter_shape,
mkl_context.grad_filter_shape);
// Need to set member variable for TF layout
@@ -408,11 +408,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
TensorFormat data_format_;
};
-#define REGISTER_MKL_FILTER_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropFilter") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_MKL_FILTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklConv2DCustomBackpropFilterOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index c7d95c86bc..638ce4c024 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -202,7 +202,7 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
mkl_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
mklOutputShape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &in_backprop, mkl_out_shape,
+ AllocateOutputSetMklShape(context, 0, &in_backprop, mkl_out_shape,
mklOutputShape);
mkl_context.conv_res[dnnResourceDiffSrc] =
@@ -341,11 +341,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
-#define REGISTER_MKL_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropInput") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklConv2DCustomBackpropInputOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index e5c4c21a10..b818819b02 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -178,7 +178,7 @@ class MklConv2DOp : public OpKernel {
// Nothing to do, allocate output tensor and return
MklShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
- AllocateOutputSetMklshape(context, 0, &output, input.shape(),
+ AllocateOutputSetMklShape(context, 0, &output, input.shape(),
mkl_output_mkl_shape);
return;
}
@@ -264,7 +264,7 @@ class MklConv2DOp : public OpKernel {
dnnLayoutGetMemorySize_F32(
static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output, mkl_output_tf_shape,
+ AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
mkl_output_mkl_shape);
mkl_context.conv_res[dnnResourceDst] =
static_cast<void*>(output->flat<T>().data());
@@ -437,16 +437,16 @@ class MklConv2DOp : public OpKernel {
TensorFormat data_format_;
};
-#define REGISTER_MKL_CPU(T) \
- REGISTER_KERNEL_BUILDER(Name("MklConv2D") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
- MklConv2DOp<CPUDevice, T, false>); \
- REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DOp<CPUDevice, T, false>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklConv2DOp<CPUDevice, T, true>);
TF_CALL_float(REGISTER_MKL_CPU);
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
new file mode 100644
index 0000000000..512e799d15
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -0,0 +1,689 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifdef INTEL_MKL
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+// TODO(inteltf) Address comments from PR 8968.
+
+namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+template <typename Device, typename T>
+class MklFusedBatchNormOp : public OpKernel {
+ public:
+ explicit MklFusedBatchNormOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ float epsilon;
+ OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
+ epsilon_ = T(epsilon);
+ string tensor_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
+ OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ MklFusedBatchNormOpContext mkl_context;
+
+ const Tensor& input = MklGetInput(context, 0);
+ const Tensor& scale = MklGetInput(context, 1);
+ const Tensor& shift = MklGetInput(context, 2);
+ const Tensor& est_mean = MklGetInput(context, 3);
+ const Tensor& est_variance = MklGetInput(context, 4);
+
+ GetMklShape(context, 0, &(mkl_context.mkl_shape_input_shape));
+ bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
+ if (!input_in_mkl_format) {
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ }
+ OP_REQUIRES(context, scale.dims() == 1,
+ errors::InvalidArgument("scale must be 1-dimensional",
+ scale.shape().DebugString()));
+ OP_REQUIRES(context, shift.dims() == 1,
+ errors::InvalidArgument("offset must be 1-dimensional",
+ shift.shape().DebugString()));
+ OP_REQUIRES(context, est_mean.dims() == 1,
+ errors::InvalidArgument("estimated_mean must be 1-dimensional",
+ est_mean.shape().DebugString()));
+ OP_REQUIRES(
+ context, est_variance.dims() == 1,
+ errors::InvalidArgument("estimated_variance must be 1-dimensional",
+ est_variance.shape().DebugString()));
+ if (is_training_) {
+ OP_REQUIRES(context, est_mean.dim_size(0) == 0,
+ errors::InvalidArgument("estimated_mean empty for training",
+ est_mean.shape().DebugString()));
+ OP_REQUIRES(context, est_variance.dim_size(0) == 0,
+ errors::InvalidArgument(
+ "estimated_variance must be empty for training",
+ est_variance.shape().DebugString()));
+ }
+
+ unsigned int flag_batch_norm =
+ is_training_ ? dnnUseScaleShift
+ : (dnnUseInputMeanVariance | dnnUseScaleShift);
+
+ mkl_context.MklExtractParams(context, tensor_format_);
+
+ // Create layout only for input data as it is used in Op primitive.
+ mkl_context.MklCreateInputLayout(context);
+
+ // Create Op primitive.
+ CHECK_EQ(dnnBatchNormalizationCreateForward_v2_F32(
+ &(mkl_context.mkl_prim_batchnorm), nullptr,
+ mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
+ flag_batch_norm),
+ E_SUCCESS);
+
+ // Temporary tensors with buffers for the context inputs, if
+ // conversion to MKL-Op specific layouts are required. It is assumed here
+ // that TF's 1D tensors (scale, shift, est_mean, and est_variance) won't
+ // require any conversion.
+ // Since scale-shift is combined in MKL, a buffer is required.
+ Tensor mkl_tmp_input_buf_tensor, mkl_tmp_scale_shift_buf_tensor;
+ mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
+ &mkl_tmp_scale_shift_buf_tensor);
+
+ // Output data in MKL layout
+ Tensor* output = nullptr;
+ TensorShape tf_shape_output;
+ MklShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(true);
+ mkl_shape_output.SetMklLayout(mkl_context.mkl_prim_batchnorm,
+ dnnResourceDst);
+ mkl_shape_output.SetTfLayout(mkl_context.mkl_params.in_dim,
+ mkl_context.mkl_params.in_sizes,
+ mkl_context.mkl_params.in_strides);
+ mkl_shape_output.SetTfDimOrder(mkl_context.mkl_params.in_dim,
+ tensor_format_);
+ tf_shape_output.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_shape_output.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklShape(context, 0, &output, tf_shape_output,
+ mkl_shape_output);
+ mkl_context.mkl_res_batchnorm[dnnResourceDst] =
+ static_cast<void*>(output->flat<T>().data());
+
+ // Batch mean in TF layout
+ Tensor* batch_mean = nullptr;
+ MklShape mkl_shape_batch_mean;
+ mkl_shape_batch_mean.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 1, &batch_mean, scale.shape(),
+ mkl_shape_batch_mean);
+ // Batch variance in TF layout
+ Tensor* batch_variance = nullptr;
+ MklShape mkl_shape_batch_variance;
+ mkl_shape_batch_variance.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 2, &batch_variance, scale.shape(),
+ mkl_shape_batch_variance);
+ // If training mode, set dnnResourceMean and dnnResourceVariance to
+ // output tensors for batch mean and variance.
+ // Otherwise, set dnnResourceMean and dnnResourceVariance to
+ // estimated mean and variance.
+ if (is_training_)
+ mkl_context.MklSetMeanVariance(*batch_mean, *batch_variance);
+ else
+ mkl_context.MklSetMeanVariance(est_mean, est_variance);
+
+ // Now that all resources are set, it is ready for dnnExecute
+ CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm,
+ mkl_context.mkl_res_batchnorm),
+ E_SUCCESS);
+
+ // Mean and variance (without Bessel's correction) saved for backward
+ // computation to serve as pre-computed mean and variance.
+ Tensor* saved_mean = nullptr;
+ MklShape mkl_shape_saved_mean;
+ mkl_shape_saved_mean.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 3, &saved_mean, scale.shape(),
+ mkl_shape_saved_mean);
+ std::memcpy(
+ reinterpret_cast<char*>(saved_mean->flat<float>().data()),
+ reinterpret_cast<char*>(mkl_context.mkl_res_batchnorm[dnnResourceMean]),
+ scale.NumElements() * sizeof(float));
+ Tensor* saved_variance = nullptr;
+ MklShape mkl_shape_saved_variance;
+ mkl_shape_saved_variance.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 4, &saved_variance, scale.shape(),
+ mkl_shape_saved_variance);
+ std::memcpy(reinterpret_cast<char*>(saved_variance->flat<float>().data()),
+ reinterpret_cast<char*>(
+ mkl_context.mkl_res_batchnorm[dnnResourceVariance]),
+ scale.NumElements() * sizeof(float));
+
+ // Bessel's correction on variance, if training mode is on
+ if (is_training_) {
+ float* p_var = static_cast<float*>(batch_variance->flat<T>().data());
+ auto depth = mkl_context.mkl_params.depth;
+ size_t orig_size = mkl_context.mkl_params.in_sizes[0] *
+ mkl_context.mkl_params.in_sizes[1] *
+ mkl_context.mkl_params.in_sizes[3];
+ size_t adjust_size = orig_size - 1;
+ float adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
+ for (int i = 0; i < depth; i++) p_var[i] = adjust_factor * p_var[i];
+ }
+
+ mkl_context.MklCleanup();
+ }
+
+ private:
+ T epsilon_;
+ TensorFormat tensor_format_;
+ bool is_training_;
+
+ // Structure containing all info for MklOp
+ typedef struct {
+ // Parameters used for input and output layouts
+ struct MklBatchNormParams {
+ // BatchNormOp src and
+ size_t in_dim;
+ size_t in_sizes[4];
+ size_t in_strides[4];
+ size_t depth; // Batch normalization is done for per channel.
+ } mkl_params;
+
+ MklShape mkl_shape_input_shape;
+
+ // MKL primitive and resources for BatchNormOp
+ dnnPrimitive_t mkl_prim_batchnorm = nullptr;
+ void* mkl_res_batchnorm[dnnResourceNumber];
+
+ // MKL layouts for inputs in the context
+ dnnLayout_t mkl_lt_input = nullptr;
+
+ void MklCleanup() {
+ bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+ if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
+ if (mkl_prim_batchnorm != nullptr) dnnDelete_F32(mkl_prim_batchnorm);
+ }
+
+ void MklExtractParams(OpKernelContext* context,
+ const TensorFormat& tensor_format) {
+ const Tensor& input = MklGetInput(context, 0);
+ bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+ mkl_params.in_dim = input_in_mkl_format
+ ? mkl_shape_input_shape.GetDimension()
+ : input.dims();
+ mkl_params.in_sizes[0] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
+ : GetTensorDim(input, tensor_format, 'W'));
+ mkl_params.in_sizes[1] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
+ : GetTensorDim(input, tensor_format, 'H'));
+ mkl_params.in_sizes[2] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
+ : GetTensorDim(input, tensor_format, 'C'));
+ mkl_params.in_sizes[3] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
+ : GetTensorDim(input, tensor_format, 'N'));
+ mkl_params.depth = mkl_params.in_sizes[2];
+ GetStridesFromSizes(tensor_format, mkl_params.in_strides,
+ mkl_params.in_sizes);
+ }
+
+ void MklCreateInputLayout(OpKernelContext* context) {
+ const Tensor& input = MklGetInput(context, 0);
+ bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+ if (input_in_mkl_format) {
+ mkl_lt_input =
+ static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
+ } else {
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dim,
+ mkl_params.in_sizes, mkl_params.in_strides),
+ E_SUCCESS);
+ }
+ }
+
+ void MklPrepareContextInputs(OpKernelContext* context,
+ Tensor* mkl_tmp_input_buf_tensor,
+ Tensor* mkl_tmp_scale_shift_buf_tensor) {
+ bool mkl_convert_input;
+ dnnPrimitive_t mkl_prim_convert_input = nullptr;
+ dnnLayout_t mkl_lt_internal_input = nullptr;
+ void* mkl_buf_converted_input = nullptr;
+ // Compare with internal layouts and convert if needed
+ const Tensor& input = MklGetInput(context, 0);
+ void* mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_input, mkl_prim_batchnorm, dnnResourceSrc),
+ E_SUCCESS);
+ mkl_convert_input =
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
+ if (mkl_convert_input) {
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
+ mkl_lt_internal_input),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
+ &mkl_buf_converted_input);
+ CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
+ mkl_buf_converted_input),
+ E_SUCCESS);
+ dnnDelete_F32(mkl_prim_convert_input);
+ }
+ dnnLayoutDelete_F32(mkl_lt_internal_input);
+ mkl_res_batchnorm[dnnResourceSrc] =
+ (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
+
+ // scale-shift layout is created from primitive. So no conversion
+ // is needed, however, a buffer has to be allocated.
+ dnnLayout_t mkl_lt_scale_shift = nullptr;
+ void* mkl_buf_scale_shift = nullptr;
+ CHECK_EQ(
+ dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_scale_shift, mkl_prim_batchnorm, dnnResourceScaleShift),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_scale_shift_buf_tensor,
+ mkl_lt_scale_shift, &mkl_buf_scale_shift);
+ // Fill the scale-shift buffer with data, presumably buffer is 2D array
+ const Tensor& scale = MklGetInput(context, 1);
+ const Tensor& shift = MklGetInput(context, 2);
+ float* buf_scale_shift = static_cast<float*>(mkl_buf_scale_shift);
+ float* buf_scale = const_cast<float*>(
+ static_cast<const float*>(scale.flat<float>().data()));
+ float* buf_shift = const_cast<float*>(
+ static_cast<const float*>(shift.flat<float>().data()));
+ auto depth = mkl_params.depth;
+ for (int i = 0; i < depth; i++) {
+ buf_scale_shift[i] = buf_scale[i];
+ buf_scale_shift[i + depth] = buf_shift[i];
+ }
+ mkl_res_batchnorm[dnnResourceScaleShift] = mkl_buf_scale_shift;
+ }
+
+ inline void MklSetMeanVariance(const Tensor& mean, const Tensor& variance) {
+ mkl_res_batchnorm[dnnResourceMean] = const_cast<void*>(
+ static_cast<const void*>(mean.flat<float>().data()));
+ mkl_res_batchnorm[dnnResourceVariance] = const_cast<void*>(
+ static_cast<const void*>(variance.flat<float>().data()));
+ }
+ } MklFusedBatchNormOpContext;
+};
+
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklFusedBatchNormOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+
+template <typename Device, typename T>
+class MklFusedBatchNormGradOp : public OpKernel {
+ public:
+ explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ float epsilon;
+ OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
+ epsilon_ = T(epsilon);
+ string tensor_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
+ OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
+ errors::InvalidArgument("Invalid data format"));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ MklFusedBatchNormGradOpContext mkl_context;
+
+ const Tensor& out_backprop = MklGetInput(context, 0);
+ const Tensor& input = MklGetInput(context, 1);
+ const Tensor& scale = MklGetInput(context, 2);
+ const Tensor& saved_mean = MklGetInput(context, 3);
+ const Tensor& saved_var = MklGetInput(context, 4);
+
+ // Here scale, mean, and variance are 1D and considered
+ // those having same layout in MKL and TF
+ GetMklShape(context, 0, &(mkl_context.mkl_shape_out_backprop));
+ GetMklShape(context, 1, &(mkl_context.mkl_shape_input_shape));
+
+ bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
+ bool out_backprop_in_mkl_format =
+ mkl_context.mkl_shape_out_backprop.IsMklTensor();
+ if (!out_backprop_in_mkl_format) {
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ out_backprop.shape().DebugString()));
+ }
+ if (!input_in_mkl_format) {
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ }
+ OP_REQUIRES(context, scale.dims() == 1,
+ errors::InvalidArgument("scale must be 1-dimensional",
+ scale.shape().DebugString()));
+ OP_REQUIRES(context, saved_mean.dims() == 1,
+ errors::InvalidArgument("saved mean must be 1-dimensional",
+ saved_mean.shape().DebugString()));
+ OP_REQUIRES(context, saved_var.dims() == 1,
+ errors::InvalidArgument("saved variance must be 1-dimensional",
+ saved_var.shape().DebugString()));
+
+ mkl_context.MklExtractParams(context, tensor_format_);
+
+ mkl_context.MklCreateInputLayout(context);
+
+ unsigned int flag_batch_norm_grad = dnnUseScaleShift;
+
+ // Create Backward Op primitive.
+ CHECK_EQ(dnnBatchNormalizationCreateBackward_v2_F32(
+ &(mkl_context.mkl_prim_batchnorm_bwd), nullptr,
+ mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
+ flag_batch_norm_grad),
+ E_SUCCESS);
+
+ // Temporary tensors and their buffers if conversion is required
+ Tensor mkl_tmp_input_buf_tensor, mkl_tmp_outbackprop_buf_tensor,
+ mkl_tmp_scaleshift_buf_tensor;
+ mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
+ &mkl_tmp_outbackprop_buf_tensor,
+ &mkl_tmp_scaleshift_buf_tensor);
+
+ // Allocate tensor for grad w.r.t. input(x)
+ Tensor* in_backprop = nullptr;
+ TensorShape tf_shape_in_backprop;
+ MklShape mkl_shape_in_backprop;
+ mkl_shape_in_backprop.SetMklTensor(true);
+ mkl_shape_in_backprop.SetMklLayout(mkl_context.mkl_prim_batchnorm_bwd,
+ dnnResourceDiffSrc);
+ mkl_shape_in_backprop.SetTfLayout(mkl_context.mkl_params.in_dims,
+ mkl_context.mkl_params.in_sizes,
+ mkl_context.mkl_params.in_strides);
+ mkl_shape_in_backprop.SetTfDimOrder(mkl_context.mkl_params.in_dims,
+ tensor_format_);
+ tf_shape_in_backprop.AddDim(
+ dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mkl_shape_in_backprop.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklShape(context, 0, &in_backprop, tf_shape_in_backprop,
+ mkl_shape_in_backprop);
+ mkl_context.mkl_res_batchnorm_bwd[dnnResourceDiffSrc] =
+ static_cast<void*>(in_backprop->flat<T>().data());
+
+ // grad_scale and grad_shift are combined together in MKL
+ // So create a single temporary buffer for those.
+ // Also set dnnResourceDiffScaleShift to the temporary buffer
+ Tensor mkl_tmp_grad_scale_shift_buf_tensor;
+ mkl_context.MklPrepareGradScaleShift(context,
+ &mkl_tmp_grad_scale_shift_buf_tensor);
+
+ // All dnn resources are set now, ready to execute
+ CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm_bwd,
+ mkl_context.mkl_res_batchnorm_bwd),
+ E_SUCCESS);
+
+ // Now separate out scale and shift grad and copy to individual tensors
+ const TensorShape& tf_shape_scale_shift = scale.shape();
+ // Allocate tensor for grad w.r.t. scale (beta)
+ Tensor* scale_backprop = nullptr;
+ MklShape mkl_shape_scale_backprop;
+ AllocateOutputSetMklShape(context, 1, &scale_backprop, tf_shape_scale_shift,
+ mkl_shape_scale_backprop);
+
+ // Allocate tensor for grad w.r.t. shift(gamma)
+ Tensor* shift_backprop = nullptr;
+ MklShape mkl_shape_shift_backprop;
+ AllocateOutputSetMklShape(context, 2, &shift_backprop, tf_shape_scale_shift,
+ mkl_shape_shift_backprop);
+
+ // copy scale and shift grads to tensors
+ float* mkl_buf_scale_shift = const_cast<float*>(static_cast<const float*>(
+ mkl_tmp_grad_scale_shift_buf_tensor.flat<T>().data()));
+ float* tf_buf_scale = const_cast<float*>(
+ static_cast<const float*>(scale_backprop->flat<T>().data()));
+ float* tf_buf_shift = const_cast<float*>(
+ static_cast<const float*>(shift_backprop->flat<T>().data()));
+ auto depth = mkl_context.mkl_params.depth;
+ for (int i = 0; i < depth; i++) {
+ tf_buf_scale[i] = mkl_buf_scale_shift[i];
+ tf_buf_shift[i] = mkl_buf_scale_shift[i + depth];
+ }
+
+ // Two placeholders for estimated_mean and estimated_variance, which are
+ // used for inference and thus not needed here for gradient computation.
+ Tensor* placeholder_1 = nullptr;
+ MklShape mkl_shape_placeholder_1;
+ AllocateOutputSetMklShape(context, 3, &placeholder_1, TensorShape({}),
+ mkl_shape_placeholder_1);
+ Tensor* placeholder_2 = nullptr;
+ MklShape mkl_shape_placeholder_2;
+ AllocateOutputSetMklShape(context, 4, &placeholder_2, TensorShape({}),
+ mkl_shape_placeholder_2);
+
+ mkl_context.MklCleanup();
+ }
+
+ private:
+ T epsilon_;
+ TensorFormat tensor_format_;
+
+ // Structure containing all info for MklOp
+ typedef struct {
+ // Parameters used for input and output layouts
+ struct MklBatchNormParams {
+ // BatchNormOp src and
+ size_t in_dims;
+ size_t in_sizes[4];
+ size_t in_strides[4];
+ size_t depth; // Batch normalization is done for per channel.
+ } mkl_params;
+
+ MklShape mkl_shape_out_backprop;
+ MklShape mkl_shape_input_shape;
+
+ // MKL primitive and resources for BatchNormOp
+ dnnPrimitive_t mkl_prim_batchnorm_bwd = nullptr;
+ void* mkl_res_batchnorm_bwd[dnnResourceNumber];
+
+ // MKL layouts for inputs in the context
+ dnnLayout_t mkl_lt_out_backprop = nullptr;
+ dnnLayout_t mkl_lt_input = nullptr;
+
+ void MklCleanup() {
+ bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+ bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
+ if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
+ if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_out_backprop);
+
+ dnnDelete_F32(mkl_prim_batchnorm_bwd);
+ }
+
+ void MklExtractParams(OpKernelContext* context,
+ const TensorFormat& tensor_format) {
+ const Tensor& input = MklGetInput(context, 1);
+ bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+ mkl_params.in_dims = input_in_mkl_format
+ ? mkl_shape_input_shape.GetDimension()
+ : input.dims();
+ mkl_params.in_sizes[0] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
+ : GetTensorDim(input, tensor_format, 'W'));
+ mkl_params.in_sizes[1] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
+ : GetTensorDim(input, tensor_format, 'H'));
+ mkl_params.in_sizes[2] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
+ : GetTensorDim(input, tensor_format, 'C'));
+ mkl_params.in_sizes[3] = static_cast<size_t>(
+ input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
+ : GetTensorDim(input, tensor_format, 'N'));
+ mkl_params.depth = mkl_params.in_sizes[2];
+ GetStridesFromSizes(tensor_format, mkl_params.in_strides,
+ mkl_params.in_sizes);
+ }
+
+ void MklCreateInputLayout(OpKernelContext* context) {
+ bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+ if (input_in_mkl_format) {
+ mkl_lt_input =
+ static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
+ } else {
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dims,
+ mkl_params.in_sizes, mkl_params.in_strides),
+ E_SUCCESS);
+ }
+
+ bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
+ if (out_backprop_in_mkl_format) {
+ mkl_lt_out_backprop =
+ static_cast<dnnLayout_t>(mkl_shape_out_backprop.GetCurLayout());
+ } else {
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&mkl_lt_out_backprop, mkl_params.in_dims,
+ mkl_params.in_sizes, mkl_params.in_strides),
+ E_SUCCESS);
+ }
+ }
+
+ void MklPrepareContextInputs(OpKernelContext* context,
+ Tensor* mkl_tmp_input_buf_tensor,
+ Tensor* mkl_tmp_outbackprop_buf_tensor,
+ Tensor* mkl_tmp_scaleshift_buf_tensor) {
+ bool mkl_convert_input;
+ dnnPrimitive_t mkl_prim_convert_input = nullptr;
+ dnnLayout_t mkl_lt_internal_input = nullptr;
+ void* mkl_buf_converted_input = nullptr;
+ // Compare with internal layouts and convert if needed
+ const Tensor& input = MklGetInput(context, 1);
+ void* mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+ CHECK_EQ(
+ dnnLayoutCreateFromPrimitive_F32(
+ &mkl_lt_internal_input, mkl_prim_batchnorm_bwd, dnnResourceSrc),
+ E_SUCCESS);
+ mkl_convert_input =
+ !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
+ if (mkl_convert_input) {
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
+ mkl_lt_internal_input),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
+ &mkl_buf_converted_input);
+ CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
+ mkl_buf_converted_input),
+ E_SUCCESS);
+ dnnDelete_F32(mkl_prim_convert_input);
+ }
+ dnnLayoutDelete_F32(mkl_lt_internal_input);
+ mkl_res_batchnorm_bwd[dnnResourceSrc] =
+ (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
+
+ bool mkl_convert_out_backprop;
+ dnnPrimitive_t mkl_prim_convert_out_backprop = nullptr;
+ dnnLayout_t mkl_lt_internal_out_backprop = nullptr;
+ void* mkl_buf_converted_out_backprop = nullptr;
+ // Compare with internal layouts and convert if needed
+ const Tensor& out_backprop = MklGetInput(context, 0);
+ void* mkl_buf_out_backprop = const_cast<void*>(
+ static_cast<const void*>(out_backprop.flat<T>().data()));
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
+ mkl_prim_batchnorm_bwd,
+ dnnResourceDiffDst),
+ E_SUCCESS);
+ mkl_convert_out_backprop = !dnnLayoutCompare_F32(
+ mkl_lt_internal_out_backprop, mkl_lt_out_backprop);
+ if (mkl_convert_out_backprop) {
+ CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
+ mkl_lt_out_backprop,
+ mkl_lt_internal_out_backprop),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
+ mkl_lt_internal_out_backprop,
+ &mkl_buf_converted_out_backprop);
+ CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
+ mkl_buf_out_backprop,
+ mkl_buf_converted_out_backprop),
+ E_SUCCESS);
+ dnnDelete_F32(mkl_prim_convert_out_backprop);
+ }
+ dnnLayoutDelete_F32(mkl_lt_internal_out_backprop);
+ mkl_res_batchnorm_bwd[dnnResourceDiffDst] =
+ (mkl_convert_out_backprop) ? mkl_buf_converted_out_backprop
+ : mkl_buf_out_backprop;
+
+ // Set dnnResourceMean and dnnResourceVariance
+ const Tensor& saved_mean = MklGetInput(context, 3);
+ const Tensor& saved_var = MklGetInput(context, 4);
+ void* mkl_buf_saved_mean = const_cast<void*>(
+ static_cast<const void*>(saved_mean.flat<T>().data()));
+ void* mkl_buf_saved_var = const_cast<void*>(
+ static_cast<const void*>(saved_var.flat<T>().data()));
+ mkl_res_batchnorm_bwd[dnnResourceMean] = mkl_buf_saved_mean;
+ mkl_res_batchnorm_bwd[dnnResourceVariance] = mkl_buf_saved_var;
+
+ // Set dnnResourceScaleShift
+ // Note backward Op needs only current values of scale parameters,
+ // shift parameters could be garbage and won't be used
+ const Tensor& scale = MklGetInput(context, 2);
+ dnnLayout_t mkl_lt_scale_shift = nullptr;
+ void* mkl_buf_scale_shift = nullptr;
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_scale_shift,
+ mkl_prim_batchnorm_bwd,
+ dnnResourceScaleShift),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_scaleshift_buf_tensor, mkl_lt_scale_shift,
+ &mkl_buf_scale_shift);
+ float* pscale =
+ const_cast<float*>(static_cast<const float*>(scale.flat<T>().data()));
+ float* pscale_shift = static_cast<float*>(mkl_buf_scale_shift);
+ auto depth = mkl_params.depth;
+ for (int i = 0; i < depth; i++) pscale_shift[i] = pscale[i];
+ mkl_res_batchnorm_bwd[dnnResourceScaleShift] = mkl_buf_scale_shift;
+ dnnLayoutDelete_F32(mkl_lt_scale_shift);
+ }
+
+ void MklPrepareGradScaleShift(OpKernelContext* context,
+ Tensor* mkl_tmp_grad_scale_shift_buf_tensor) {
+ dnnLayout_t mkl_lt_grad_scaleshift = nullptr;
+ void* mkl_buf_grad_scaleshift = nullptr;
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_grad_scaleshift,
+ mkl_prim_batchnorm_bwd,
+ dnnResourceDiffScaleShift),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_grad_scale_shift_buf_tensor,
+ mkl_lt_grad_scaleshift, &mkl_buf_grad_scaleshift);
+ mkl_res_batchnorm_bwd[dnnResourceDiffScaleShift] =
+ mkl_buf_grad_scaleshift;
+ dnnLayoutDelete_F32(mkl_lt_grad_scaleshift);
+ }
+ } MklFusedBatchNormGradOpContext;
+};
+
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklFusedBatchNormGradOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+} // namespace tensorflow
+
+#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
new file mode 100644
index 0000000000..edca8e2553
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -0,0 +1,722 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// LRN = Local Response Normalization
+// See docs in ../ops/nn_ops.cc. This opkernel uses MKL library, create MKL
+// layout and primitives, use MKL dnn primitives to compute local
+// response normalization
+
+#ifdef INTEL_MKL
+
+#define EIGEN_USE_THREADS
+#include <vector>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/util/work_sharder.h"
+#endif
+
+namespace tensorflow {
+
+namespace {
+// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
+// depth_radius + 1) around the diagonal.
+template <typename T>
+void GetBandMatrix(int depth, int depth_radius,
+ Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
+ result->setZero();
+ for (int row = 0; row < depth; ++row) {
+ const int begin = std::max<int>(0, row - depth_radius);
+ const int end = std::min<int>(depth, row + depth_radius + 1);
+ Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
+ Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
+ result->slice(start, sizes).setConstant(T(1));
+ }
+}
+
+} // namespace
+
+template <typename T>
+class MklLRNOp : public OpKernel {
+ public:
+ ~MklLRNOp() {}
+
+ explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+ int64 depth_radius64;
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
+ depth_radius_ = static_cast<size_t>(depth_radius64);
+
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ workspace_enabled_ = false;
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ MklLRNOpContext mkl_context;
+
+ const Tensor& input = MklGetInput(context, 0);
+ GetMklShape(context, 0, &mkl_context.input_shape);
+ bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
+
+ // Sanity checks
+ mkl_context.in_dims = input_in_mkl_format
+ ? mkl_context.input_shape.GetDimension()
+ : input.dims();
+ OP_REQUIRES(context, mkl_context.in_dims == 4,
+ errors::InvalidArgument("input must be 4-dimensional"));
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(input.NumElements(), std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
+
+ if (!input_in_mkl_format) {
+ mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
+ beta_, input);
+ return;
+ }
+
+ if (input_in_mkl_format) {
+ // MKL supports normalization over channel dimension only
+ if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
+ MklDims::C) {
+ mkl_context.lt_input =
+ static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
+ workspace_enabled_ = true;
+ } else {
+ mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
+ beta_, input);
+ return;
+ }
+ }
+
+ int kernel_size = 2 * depth_radius_ + 1;
+
+ CHECK_EQ(dnnLRNCreateForward_F32(
+ &mkl_context.lrn_fwd, NULL, mkl_context.lt_input, kernel_size,
+ static_cast<float>(alpha_ * kernel_size), beta_, bias_),
+ E_SUCCESS);
+
+ // Allocate output tensor and shape
+ Tensor* output = nullptr;
+ Tensor* workspace = nullptr;
+
+ // Convert Inputs if needed
+ Tensor mkl_tmp_input_buf_tensor;
+ mkl_context.MklPrepareLRNInputs(context, &mkl_tmp_input_buf_tensor);
+
+ // Allocate Layer Outputs
+ mkl_context.MklAllocateOutputs(context, &output, &workspace,
+ workspace_enabled_);
+
+ Tensor mkl_tmp_workspace_buf_tensor;
+ mkl_context.MklPrepareLRNOutputs(context, output, workspace,
+ &mkl_tmp_workspace_buf_tensor,
+ workspace_enabled_);
+
+ // Execute LRN.
+ CHECK_EQ(dnnExecute_F32(mkl_context.lrn_fwd, mkl_context.lrn_res),
+ E_SUCCESS);
+
+ // Release MKL resources.
+ mkl_context.MklCleanup();
+ }
+
+ private:
+ typedef struct {
+ size_t in_dims;
+ size_t in_sizes[4];
+ size_t in_strides[4];
+ size_t out_sizes[4];
+ size_t out_strides[4];
+ MklShape input_shape;
+ dnnPrimitive_t lrn_fwd = nullptr;
+ dnnPrimitive_t convert_input = nullptr;
+ /* dnnPrimitive_t convert_output; */
+ dnnLayout_t lt_input = nullptr;
+ /* dnnLayout_t lt_output; */
+ dnnLayout_t lt_internal_input = nullptr;
+ dnnLayout_t lt_internal_workspace = nullptr;
+ dnnLayout_t lt_internal_output = nullptr;
+ void* lrn_res[dnnResourceNumber];
+
+ // Convert Inputs if needed
+ void MklPrepareLRNInputs(OpKernelContext* context,
+ Tensor* mkl_tmp_input_buf_tensor) {
+ const Tensor& input = MklGetInput(context, 0);
+ void* mkl_buf_input =
+ const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_input, lrn_fwd,
+ dnnResourceSrc),
+ E_SUCCESS);
+
+ void* mkl_buf_convert_input = nullptr;
+ bool mkl_convert_input = false;
+ mkl_convert_input = !dnnLayoutCompare_F32(lt_internal_input, lt_input);
+
+ if (mkl_convert_input) {
+ CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input,
+ lt_internal_input),
+ E_SUCCESS);
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_internal_input,
+ &mkl_buf_convert_input);
+ CHECK_EQ(dnnConversionExecute_F32(convert_input, mkl_buf_input,
+ mkl_buf_convert_input),
+ E_SUCCESS);
+ dnnDelete_F32(convert_input);
+ }
+
+ lrn_res[dnnResourceSrc] =
+ (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
+ }
+
+ // Allocate Layer Outputs
+ void MklAllocateOutputs(OpKernelContext* context, Tensor** output,
+ Tensor** workspace, bool workspace_enabled_) {
+ TensorShape mkl_output_tf_shape; /* First tensor */
+ MklShape mkl_output_mkl_shape; /* Second tensor */
+
+ mkl_output_mkl_shape.SetMklTensor(true);
+ mkl_output_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceDst);
+ mkl_output_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
+ input_shape.GetStrides());
+ mkl_output_mkl_shape.SetTfDimOrder(in_dims,
+ input_shape.GetTfToMklDimMap());
+ mkl_output_tf_shape.AddDim(
+ dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklShape(context, 0, output,
+ mkl_output_tf_shape /* First tensor */,
+ mkl_output_mkl_shape /* Second Tensor */);
+
+ if (workspace_enabled_) {
+ TensorShape mkl_workspace_tf_shape; /* First tensor */
+ MklShape mkl_workspace_mkl_shape; /* Second tensor */
+ mkl_workspace_mkl_shape.SetMklTensor(false);
+ mkl_workspace_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceWorkspace);
+ // Assumes workspace has same TF layout and TF dim order as input
+ mkl_workspace_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
+ input_shape.GetStrides());
+ mkl_workspace_mkl_shape.SetTfDimOrder(in_dims,
+ input_shape.GetTfToMklDimMap());
+ mkl_workspace_tf_shape.AddDim(
+ dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_workspace_mkl_shape.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklShape(context, 1, workspace,
+ mkl_workspace_tf_shape /* First tensor */,
+ mkl_workspace_mkl_shape /* Second Tensor */);
+ }
+ }
+
+ void MklPrepareLRNOutputs(OpKernelContext* context, Tensor* output,
+ Tensor* workspace,
+ Tensor* mkl_tmp_workspace_buf_tensor,
+ bool workspace_enabled_) {
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_workspace, lrn_fwd,
+ dnnResourceWorkspace),
+ E_SUCCESS);
+
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_output, lrn_fwd,
+ dnnResourceDst),
+ E_SUCCESS);
+
+ void* mkl_buf_output =
+ const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
+ lrn_res[dnnResourceDst] = mkl_buf_output;
+
+ void* mkl_buf_workspace = nullptr;
+ if (workspace_enabled_) {
+ mkl_buf_workspace = const_cast<void*>(
+ static_cast<const void*>(workspace->flat<T>().data()));
+ } else {
+ AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor,
+ lt_internal_workspace, &mkl_buf_workspace);
+ }
+ lrn_res[dnnResourceWorkspace] = mkl_buf_workspace;
+ }
+
+ // Fallback implementation - Taken from lrn_op.cc
+ // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+ // copy.
+ void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
+ float bias_, float alpha_, float beta_,
+ const Tensor& input) {
+ const int batch = static_cast<int>(input.dim_size(0));
+ const int rows = static_cast<int>(input.dim_size(1));
+ const int cols = static_cast<int>(input.dim_size(2));
+ const int depth = static_cast<int>(input.dim_size(3));
+ const int nodes = cols * rows;
+
+ auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
+ // Multiplying the input with the band matrix has the effect of reducing
+ // the
+ // correct patch along the depth.
+ Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
+ GetBandMatrix<T>(depth, depth_radius_, &multiplier);
+
+ Tensor *output, *workspace;
+ MklShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ mkl_output_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, 0, &output, input.shape(),
+ mkl_output_mkl_shape);
+
+ mkl_workspace_mkl_shape.SetMklTensor(false);
+ mkl_workspace_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, 1, &workspace, input.shape(),
+ mkl_workspace_mkl_shape);
+
+ auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
+ Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+ auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
+ if (beta_ == T(1)) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.inverse();
+ } else if (beta_ == T(0.5)) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.rsqrt();
+ } else {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * (tmp.log() * -beta_).exp();
+ }
+ }
+
+ // Release MKL resources.
+ void MklCleanup() {
+ dnnDelete_F32(lrn_fwd);
+ dnnLayoutDelete_F32(lt_internal_input);
+ dnnLayoutDelete_F32(lt_internal_workspace);
+ dnnLayoutDelete_F32(lt_internal_output);
+ }
+ } MklLRNOpContext;
+
+ typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+
+ bool workspace_enabled_;
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+template <typename T>
+class MklLRNGradOp : public OpKernel {
+ public:
+ explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ int64 depth_radius64;
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
+ depth_radius_ = static_cast<int>(depth_radius64);
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ workspace_enabled_ = false;
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ MklLRNGradOpContext mkl_context;
+ mkl_context.depth_radius_ = depth_radius_;
+ mkl_context.bias_ = bias_;
+ mkl_context.alpha_ = alpha_;
+ mkl_context.beta_ = beta_;
+
+ const Tensor& in_grads = MklGetInput(context, 0);
+ const Tensor& in_image = MklGetInput(context, 1);
+ const Tensor& out_image = MklGetInput(context, 2);
+
+ GetMklShape(context, 0, &mkl_context.ingrad_shape);
+ GetMklShape(context, 1, &mkl_context.inimage_shape);
+ GetMklShape(context, 2, &mkl_context.outimage_shape);
+
+ bool ingrad_in_mkl_format = mkl_context.ingrad_shape.IsMklTensor();
+ bool inimage_in_mkl_format = mkl_context.inimage_shape.IsMklTensor();
+ bool outimage_in_mkl_format = mkl_context.outimage_shape.IsMklTensor();
+
+ mkl_context.in_dims = inimage_in_mkl_format
+ ? mkl_context.inimage_shape.GetDimension()
+ : in_image.dims();
+ OP_REQUIRES(context, mkl_context.in_dims == 4,
+ errors::InvalidArgument("input images must be 4-dimensional"));
+
+ if (!workspace_enabled_) {
+ mkl_context.MklDefaultToEigen(context);
+ return;
+ }
+ if (ingrad_in_mkl_format || inimage_in_mkl_format) {
+ const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
+ ? &mkl_context.ingrad_shape
+ : &mkl_context.inimage_shape;
+ if (tmp_mkl_shape->tf_dim_idx(mkl_context.in_dims - 1) != MklDims::C) {
+ // Fallback to eigen
+ mkl_context.MklDefaultToEigen(context);
+ return;
+ } else { // MKL supports normalization over channel dimension only
+ for (int i = 0; i < mkl_context.in_dims; i++) {
+ mkl_context.in_sizes[i] = mkl_context.out_sizes[i] =
+ tmp_mkl_shape->GetSizes()[i];
+ mkl_context.in_strides[i] = mkl_context.out_strides[i] =
+ tmp_mkl_shape->GetStrides()[i];
+ }
+ }
+ } else {
+ // Fallback to eigen
+ mkl_context.MklDefaultToEigen(context);
+ return;
+ }
+
+ // Dimensions check for sanity purpose
+ if (ingrad_in_mkl_format) {
+ OP_REQUIRES(
+ context, mkl_context.ingrad_shape.GetDimension() == 4,
+ errors::InvalidArgument("input gradient must be 4-dimensional"));
+ } else {
+ OP_REQUIRES(
+ context, in_grads.dims() == 4,
+ errors::InvalidArgument("input gradient must be 4-dimensional"));
+ }
+
+ if (outimage_in_mkl_format) {
+ OP_REQUIRES(
+ context, mkl_context.outimage_shape.GetDimension() == 4,
+ errors::InvalidArgument("Output image must be 4-dimensional"));
+ } else {
+ OP_REQUIRES(
+ context, out_image.dims() == 4,
+ errors::InvalidArgument("Output image must be 4-dimensional"));
+ }
+
+ // Prepare mkl input layout
+ mkl_context.MklPrepareLRNInputsLayouts(context);
+ int ksize = 2 * depth_radius_ + 1;
+
+ CHECK_EQ(dnnLRNCreateBackward_F32(
+ &mkl_context.lrn_bwd, NULL, mkl_context.lt_input,
+ mkl_context.lt_output, ksize,
+ static_cast<float>(alpha_ * ksize), beta_, bias_),
+ E_SUCCESS);
+
+ // Allocate output tensor and shape.
+ TensorShape mkl_output_tf_shape; /* First tensor */
+ MklShape mkl_output_mkl_shape; /* Second tensor */
+ mkl_output_mkl_shape.SetMklTensor(true);
+ CHECK_NE(mkl_context.lrn_bwd, nullptr);
+ mkl_output_mkl_shape.SetMklLayout(mkl_context.lrn_bwd, dnnResourceDiffSrc);
+ mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
+ mkl_context.out_strides);
+ if (ingrad_in_mkl_format) {
+ mkl_output_mkl_shape.SetTfDimOrder(
+ mkl_context.in_dims, mkl_context.ingrad_shape.GetTfToMklDimMap());
+ } else {
+ mkl_output_mkl_shape.SetTfDimOrder(
+ mkl_context.in_dims, mkl_context.inimage_shape.GetTfToMklDimMap());
+ }
+ mkl_output_tf_shape.AddDim(
+ dnnLayoutGetMemorySize_F32(
+ static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
+ sizeof(T));
+ Tensor* output = nullptr;
+ AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
+ mkl_output_mkl_shape);
+
+ // Get pointers to output data.
+ void* user_output =
+ const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
+
+ Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor,
+ mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor;
+ // Convert Inputs if needed
+ mkl_context.MklPrepareLRNGradInput(
+ context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor,
+ &mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor);
+
+ // We do not do any conversion for output. But we simply emit it
+ // in MKL format.
+ mkl_context.res_lrn_bwd[dnnResourceDiffSrc] = user_output;
+ // Execute LRN backward using dnnExecute
+ CHECK_EQ(dnnExecute_F32(mkl_context.lrn_bwd, mkl_context.res_lrn_bwd),
+ E_SUCCESS);
+ // Release MKL resources.
+ mkl_context.Mklcleanup();
+ }
+
+ private:
+ typedef struct {
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+ size_t in_dims;
+ size_t in_sizes[4];
+ size_t in_strides[4];
+ size_t out_sizes[4];
+ size_t out_strides[4];
+ MklShape ingrad_shape, inimage_shape, outimage_shape;
+ dnnPrimitive_t lrn_bwd = nullptr;
+ dnnPrimitive_t convert_input = nullptr;
+ /* dnnPrimitive_t convert_output; */
+ dnnLayout_t lt_input = nullptr;
+ dnnLayout_t lt_output = nullptr;
+ dnnLayout_t lt_bdw_input = nullptr;
+ dnnLayout_t lt_workspace = nullptr;
+ dnnLayout_t lt_internal_input = nullptr;
+ /* dnnLayout_t lt_internal_workspace;
+ dnnLayout_t lt_internal_output; */
+ void* res_lrn_bwd[dnnResourceNumber];
+
+ // prepare mkl input
+ void MklPrepareLRNInputsLayouts(OpKernelContext* context) {
+ bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
+ bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
+ if (!ingrad_in_mkl_format) {
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
+ E_SUCCESS);
+ } else {
+ lt_input = static_cast<dnnLayout_t>(ingrad_shape.GetCurLayout());
+ }
+
+ if (!inimage_in_mkl_format) {
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&lt_output, in_dims, out_sizes, out_strides),
+ E_SUCCESS);
+ } else {
+ lt_output = static_cast<dnnLayout_t>(inimage_shape.GetCurLayout());
+ }
+ }
+
+ // convert input if needed
+ void MklPrepareLRNGradInput(OpKernelContext* context,
+ Tensor* mkl_tmp_input_buf_tensor,
+ Tensor* mkl_tmp_image_buf_tensor,
+ Tensor* mkl_tmp_outimage_buf_tensor,
+ Tensor* mkl_tmp_workspace_buf_tensor) {
+ const Tensor& in_grads = MklGetInput(context, 0);
+ const Tensor& in_image = MklGetInput(context, 1);
+ const Tensor& out_image = MklGetInput(context, 2);
+
+ void* user_input = const_cast<void*>(
+ static_cast<const void*>(in_grads.flat<T>().data()));
+ void* user_fwd_input = const_cast<void*>(
+ static_cast<const void*>(in_image.flat<T>().data()));
+ void* user_fwd_output = const_cast<void*>(
+ static_cast<const void*>(out_image.flat<T>().data()));
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, lrn_bwd,
+ dnnResourceWorkspace),
+ E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_bdw_input, lrn_bwd,
+ dnnResourceDiffDst),
+ E_SUCCESS);
+
+ bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
+ if (ingrad_in_mkl_format) {
+ if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
+ &res_lrn_bwd[dnnResourceDiffDst]);
+ ingrad_shape.GetConvertedFlatData(lt_bdw_input, user_input,
+ res_lrn_bwd[dnnResourceDiffDst]);
+ } else {
+ res_lrn_bwd[dnnResourceDiffDst] = user_input;
+ }
+ } else {
+ if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
+ CHECK_EQ(
+ dnnConversionCreate_F32(&convert_input, lt_input, lt_bdw_input),
+ E_SUCCESS);
+
+ AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
+ &res_lrn_bwd[dnnResourceDiffDst]);
+ CHECK_EQ(dnnConversionExecute_F32(convert_input, user_input,
+ res_lrn_bwd[dnnResourceDiffDst]),
+ E_SUCCESS);
+ dnnDelete_F32(convert_input);
+ } else {
+ res_lrn_bwd[dnnResourceDiffDst] = user_input;
+ }
+ }
+
+// Although MKL documentation for LRN does not specify setting/getting
+// of dnnResourceSrc and dnnResourceDst, Caffe code sets dnnResourceSrc.
+// So we set dnnResourceSrc here. But we do not know why we are setting
+// dnnResourceDst.
+#if 0
+ // NOTE: The code below is kept just so that we know how we should handle
+ // dnnResourceSrc if the primitive layout for dnnResourceSrc was supported.
+
+ if (!dnnLayoutCompare_F32(lt_internal_input,
+ static_cast<dnnLayout_t>inimage_shape.GetCurLayout())) {
+ AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
+ &res_lrn_bwd[dnnResourceSrc]);
+ inimage_shape.GetConvertedFlatData(lt_internal_input,
+ user_fwd_input,
+ res_lrn_bwd[dnnResourceSrc]);
+ } else {
+ res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+ }
+#endif
+
+ // Since we cannot get expected layout for dnnResourceSrc, we construct
+ // buffer using
+ // MKL format if input is in MKL format.
+ if (inimage_shape.IsMklTensor()) {
+ AllocTmpBuffer(context, mkl_tmp_image_buf_tensor,
+ (dnnLayout_t)inimage_shape.GetCurLayout(),
+ &res_lrn_bwd[dnnResourceSrc]);
+ } else {
+ res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+ }
+
+ // Same comment as above.
+ if (outimage_shape.IsMklTensor()) {
+ AllocTmpBuffer(context, mkl_tmp_outimage_buf_tensor,
+ (dnnLayout_t)outimage_shape.GetCurLayout(),
+ &res_lrn_bwd[dnnResourceDst]);
+ } else {
+ res_lrn_bwd[dnnResourceDst] = user_fwd_output;
+ }
+
+ // Allocate buffer for workspace.
+ AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace,
+ &res_lrn_bwd[dnnResourceWorkspace]);
+ }
+
+ // Fallback implementation - Taken from lrn_op.cc
+ // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+ // copy.
+ void MklDefaultToEigen(OpKernelContext* context) {
+ // CHECK(false);
+ Tensor in_grads = MklGetInput(context, 0);
+ Tensor in_image = MklGetInput(context, 1);
+ Tensor out_image = MklGetInput(context, 2);
+
+ GetMklShape(context, 0, &ingrad_shape);
+ GetMklShape(context, 1, &inimage_shape);
+ GetMklShape(context, 2, &outimage_shape);
+
+ const int64 batch = static_cast<int64>(in_grads.dim_size(0));
+ const int64 rows = static_cast<int64>(in_grads.dim_size(1));
+ const int64 cols = static_cast<int64>(in_grads.dim_size(2));
+ const int64 depth = static_cast<int64>(in_grads.dim_size(3));
+ const auto nodes = cols * rows;
+
+ auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
+ auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
+ auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
+
+ Tensor* output;
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ mkl_output_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, 0, &output, in_grads.shape(),
+ mkl_output_mkl_shape);
+
+ auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
+ out_shaped.setZero();
+ auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
+ depth](int64 begin, int64 end) {
+ for (int64 i = begin; i < end; ++i) {
+ for (int64 j = 0; j < depth; ++j) {
+ int64 depth_begin = std::max<int64>(0, j - depth_radius_);
+ int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
+
+ T norm(0);
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ norm += in_shaped(i, k) * in_shaped(i, k);
+ }
+ norm = alpha_ * norm + bias_;
+ DCHECK_GT(norm, T(1e-6));
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
+ activations(i, j) / norm;
+ if (k == j) {
+ dyi += Eigen::numext::pow(norm, -beta_);
+ }
+ dyi *= grads_shaped(i, j);
+ const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
+ dyi;
+ }
+ }
+ }
+ };
+ auto worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
+ depth * depth, shard);
+ }
+
+ // release mkl resources
+ void Mklcleanup() {
+ bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
+ bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
+ if (!ingrad_in_mkl_format) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_input), E_SUCCESS);
+ }
+
+ if (!inimage_in_mkl_format) {
+ CHECK_EQ(dnnLayoutDelete_F32(lt_output), E_SUCCESS);
+ }
+ dnnDelete_F32(lrn_bwd);
+ dnnLayoutDelete_F32(lt_bdw_input);
+ dnnLayoutDelete_F32(lt_workspace);
+ }
+ } MklLRNGradOpContext;
+
+ typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+ bool workspace_enabled_;
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+#define REGISTER_MKL_LRN_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklLRN") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklLRNOp<T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklLRNGradOp<T>);
+
+TF_CALL_float(REGISTER_MKL_LRN_CPU);
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index 9d6cfb0c97..e27881f882 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -83,10 +83,11 @@ class MklMaxPoolingOp : public OpKernel {
ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
mkl_context.MklCreateLayoutsAndPrimitives(context);
+ OP_REQUIRES_OK(context, context->status());
// Declare output tensor
TensorShape tensor_out_shape;
- MklShape mkl_out_shape;
+ MklShape mkl_out_shape, mkl_workspace_shape;
mkl_out_shape.SetMklTensor(true);
mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
@@ -98,31 +99,22 @@ class MklMaxPoolingOp : public OpKernel {
tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
mkl_out_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output_tensor, tensor_out_shape,
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
mkl_out_shape);
- if (!workspace_enabled_) {
- mkl_out_shape.SetMklTensor(false);
- }
-
Tensor* workspace_tensor;
void* workspace_buf = nullptr;
- if (workspace_enabled_) {
- TensorShape workspace_shape;
- workspace_shape.AddDim(
- dnnLayoutGetMemorySize_F32(
- static_cast<dnnLayout_t>(mkl_context.lt_workspace)) /
- sizeof(T));
- AllocateOutputSetMklshape(context, 1, &workspace_tensor, workspace_shape,
- mkl_out_shape);
- mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
- static_cast<const void*>(workspace_tensor->flat<T>().data()));
- } else {
- AllocTmpBuffer(context, workspace_tensor, mkl_context.lt_workspace,
- &workspace_buf);
- mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
- }
+ TensorShape workspace_shape;
+ mkl_workspace_shape.SetMklTensor(false);
+ workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_context.lt_workspace)) /
+ sizeof(T));
+ AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
+ mkl_workspace_shape);
+
+ mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
+ static_cast<const void*>(workspace_tensor->flat<T>().data()));
mkl_context.pooling_res[dnnResourceSrc] =
const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
@@ -140,8 +132,8 @@ class MklMaxPoolingOp : public OpKernel {
MklPoolingOpParams params;
MklShape input_shape;
void* pooling_res[dnnResourceNumber];
- dnnPrimitive_t prim_pooling_fwd;
- dnnLayout_t lt_user_input, lt_workspace;
+ dnnPrimitive_t prim_pooling_fwd = nullptr;
+ dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
bool input_in_mkl_format = input_shape.IsMklTensor();
@@ -256,8 +248,13 @@ class MklMaxPoolingGradOp : public OpKernel {
ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
mkl_context.MklCreateLayouts(context);
+ OP_REQUIRES_OK(context, context->status());
+
mkl_context.MklCreatePrimitives(context, workspace_enabled_);
+ OP_REQUIRES_OK(context, context->status());
+
mkl_context.MklPrepareInputs(context, workspace_enabled_);
+ OP_REQUIRES_OK(context, context->status());
// Create shape for the input back prop output
TensorShape mkl_input_backprop;
@@ -274,7 +271,7 @@ class MklMaxPoolingGradOp : public OpKernel {
dnnLayoutGetMemorySize_F32(
static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output_tensor, mkl_input_backprop,
+ AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
mkl_output_shape);
mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
static_cast<const void*>(output_tensor->flat<T>().data()));
@@ -297,12 +294,15 @@ class MklMaxPoolingGradOp : public OpKernel {
MklShape input_shape, output_backprop_shape;
void* pooling_resfwd[dnnResourceNumber];
void* pooling_res[dnnResourceNumber];
- dnnPrimitive_t prim_pooling_fwd, prim_pooling_bwd, convert_input,
- convert_outbackprop;
- dnnLayout_t lt_outbackprop_user, lt_outbackprop_prim, lt_input_user,
- lt_input_prim;
+ dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
+ convert_input = nullptr, convert_outbackprop = nullptr;
+ dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
+ lt_input_user = nullptr, lt_input_prim = nullptr;
void* input_buf;
void* outbackprop_buf;
+ Tensor tmp_output_buf_tensor;
+ Tensor workspace_buf_tensor;
+ Tensor input_buf_tensor, outbackprop_buf_tensor;
void MklCreateLayouts(OpKernelContext* context) {
bool input_in_mkl_format = input_shape.IsMklTensor();
@@ -351,9 +351,6 @@ class MklMaxPoolingGradOp : public OpKernel {
&lt_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
E_SUCCESS);
- // Tensors needed to create temporary buffers
- Tensor input_buf_tensor, outbackprop_buf_tensor;
-
if (workspace_enabled == false) {
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
&lt_input_prim, prim_pooling_fwd, dnnResourceSrc),
@@ -384,11 +381,8 @@ class MklMaxPoolingGradOp : public OpKernel {
bool input_in_mkl_format = input_shape.IsMklTensor();
bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
- void* tmp_output_buf;
- Tensor tmp_output_buf_tensor;
-
- void* workspace_buf;
- Tensor workspace_buf_tensor;
+ void* tmp_output_buf = nullptr;
+ void* workspace_buf = nullptr;
if (workspace_enabled == false) {
if (convert_input != nullptr) {
@@ -490,16 +484,16 @@ class MklMaxPoolingGradOp : public OpKernel {
bool workspace_enabled_;
};
-REGISTER_KERNEL_BUILDER(Name("MklMaxPool")
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
- .Label(mkl_layer_registry::kMklLayerLabel),
+ .Label(mkl_op_registry::kMklOpLabel),
MklMaxPoolingOp<CPUDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("MklMaxPoolGrad")
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
- .Label(mkl_layer_registry::kMklLayerLabel),
+ .Label(mkl_op_registry::kMklOpLabel),
MklMaxPoolingGradOp<CPUDevice, float>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 7809711524..25c8359cc5 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -63,7 +63,7 @@ class MklReluOp : public OpKernel {
const TensorShape& o_shape = input.shape();
Tensor* out_tensor = nullptr;
mkl_context.output_shape.SetMklTensor(false);
- AllocateOutputSetMklshape(context, 0, &out_tensor, o_shape,
+ AllocateOutputSetMklShape(context, 0, &out_tensor, o_shape,
mkl_context.output_shape);
void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
(static_cast<T*>(out_o))[0] =
@@ -114,12 +114,12 @@ class MklReluOp : public OpKernel {
tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
mkl_context.output_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output, tf_shape,
+ AllocateOutputSetMklShape(context, 0, &output, tf_shape,
mkl_context.output_shape);
} else {
const TensorShape& o_shape = input.shape();
mkl_context.output_shape.SetMklTensor(false);
- AllocateOutputSetMklshape(context, 0, &output, o_shape,
+ AllocateOutputSetMklShape(context, 0, &output, o_shape,
mkl_context.output_shape);
}
@@ -293,7 +293,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
// Allocate space for g and
const TensorShape& g_shape = g.shape();
mkl_context.output_shape.SetMklTensor(false);
- AllocateOutputSetMklshape(context, 0, &output, g_shape,
+ AllocateOutputSetMklShape(context, 0, &output, g_shape,
mkl_context.output_shape);
void* out_o = static_cast<void*>(output->flat<T>().data());
(static_cast<T*>(out_o))[0] =
@@ -359,13 +359,13 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
mkl_context.output_shape.GetMklLayout())) /
sizeof(T));
- AllocateOutputSetMklshape(context, 0, &output, tf_shape,
+ AllocateOutputSetMklShape(context, 0, &output, tf_shape,
mkl_context.output_shape);
} else {
const TensorShape& o_shape = g.shape();
mkl_context.output_shape.SetMklTensor(false);
- AllocateOutputSetMklshape(context, 0, &output, o_shape,
+ AllocateOutputSetMklShape(context, 0, &output, o_shape,
mkl_context.output_shape);
}
@@ -379,16 +379,16 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
/* Register DNN kernels for supported operations and supported types - right now
* it is only Relu and f32*/
-#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
- REGISTER_KERNEL_BUILDER(Name("MklRelu") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
- MklReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER(Name("MklReluGrad") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklRelu") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
new file mode 100644
index 0000000000..753a8b52b4
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -0,0 +1,149 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+template <typename Device, typename T>
+class MklReshapeOp : public OpKernel {
+ public:
+ explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = MklGetInput(context, 0);
+ const Tensor& sizes = MklGetInput(context, 1);
+
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
+ errors::InvalidArgument("sizes input must be 1-D, not shape ",
+ sizes.shape().DebugString()));
+ const int64 num_dims = sizes.NumElements();
+
+ // Compute the output shape. Determine product of specified
+ // dimensions, and find the index of the unspecified one.
+ TensorShape shape;
+ int64 product = 1;
+ int unknown_index = -1;
+ auto vec_size = sizes.flat<int32>();
+ for (int d = 0; d < num_dims; ++d) {
+ const int32 size = vec_size(d);
+ if (size == -1) {
+ OP_REQUIRES(
+ context, unknown_index == -1,
+ errors::InvalidArgument("only one input size may be -1, not both ",
+ unknown_index, " and ", d));
+ unknown_index = d;
+ shape.AddDim(1);
+ } else {
+ OP_REQUIRES(context, size >= 0,
+ errors::InvalidArgument(
+ "size ", d, " must be non-negative, not ", size));
+ shape.AddDim(size);
+ product *= size;
+ }
+ }
+ if (unknown_index != -1) {
+ OP_REQUIRES(
+ context, product > 0,
+ errors::InvalidArgument("Reshape cannot infer the missing input size "
+ "for an empty tensor unless all specified "
+ "input sizes are non-zero"));
+ const int64 missing = input.NumElements() / product;
+ OP_REQUIRES(
+ context, product * missing == input.NumElements(),
+ errors::InvalidArgument(
+ "Input to reshape is a tensor with ", input.NumElements(),
+ " values, but the requested shape requires a multiple of ",
+ product));
+ shape.set_dim(unknown_index, missing);
+ }
+ OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
+ errors::InvalidArgument("Input to reshape is a tensor with ",
+ input.NumElements(),
+ " values, but the requested shape has ",
+ shape.num_elements()));
+
+ MklShape mkl_shape_input;
+ GetMklShape(context, 0, &mkl_shape_input);
+ bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
+ if (input_in_mkl_format) {
+ TensorShape& shape_to = shape;
+ TensorShape shape_from;
+ for (size_t i = 0; i < mkl_shape_input.GetDimension(); i++) {
+ // Outermost to innermost dimension
+ shape_from.AddDim(
+ mkl_shape_input.GetSizes()[mkl_shape_input.tf_dim_idx(i)]);
+ }
+
+ if (shape_from == shape_to) {
+ CopyMklTensorInToOut(context, 0, 0);
+ return;
+ } else {
+ // Allocate output tensor.
+ Tensor* output_tensor = NULL;
+ MklShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ mkl_shape_output);
+
+ // Get output layout pointer.
+ dnnLayout_t output_layout =
+ static_cast<dnnLayout_t>(mkl_shape_input.GetTfLayout());
+
+ // Execute DNNConversion.
+ // Note: we assume an MKL tensor always have float as its data type.
+ void* input_buffer =
+ static_cast<void*>(const_cast<float*>(input.flat<float>().data()));
+ void* output_buffer = static_cast<void*>(
+ const_cast<float*>(output_tensor->flat<float>().data()));
+ mkl_shape_input.GetConvertedFlatData(output_layout, input_buffer,
+ output_buffer);
+
+ VLOG(1) << "MKLToTFConversion complete successfully.";
+ return;
+ }
+ } else {
+ CopyTFTensorInToOut(context, 0, 0, shape);
+ }
+ }
+};
+
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklReshape") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tshape") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklReshapeOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+} // namespace tensorflow
+
+#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc
index 51f90b3f90..c31ef5c255 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.cc
+++ b/tensorflow/core/kernels/mkl_tfconv_op.cc
@@ -105,11 +105,11 @@ class MklToTfOp : public OpKernel {
// Register kernel
///////////////////////////////////////////////////////////
-#define REGISTER_CPU(T) \
- REGISTER_KERNEL_BUILDER(Name("MklToTf") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("MklToTf") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklToTfOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index afa1f65aef..56a7e161df 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -233,9 +233,9 @@ class Im2ColConvFunctor {
int filter_top_offset;
if (padding == VALID) {
filter_left_offset =
- ((output_width - 1) * stride + filter_width - input_width) / 2;
+ ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
filter_top_offset =
- ((output_height - 1) * stride + filter_height - input_height) / 2;
+ ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
} else {
filter_left_offset =
((output_width - 1) * stride + filter_width - input_width) / 2;
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 2bcc7f407d..30026f222a 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -29,7 +29,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tindices>
class SparseTensorDenseMatMulOp : public OpKernel {
public:
explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx)
@@ -139,15 +139,14 @@ class SparseTensorDenseMatMulOp : public OpKernel {
TensorShape({0}), &scratch));
}
-#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
- if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
- Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
- Device, T, ADJ_A, ADJ_B>::Compute(ctx->eigen_device<Device>(), \
- out->matrix<T>(), \
- a_indices->matrix<int64>(), \
- a_values->vec<T>(), b->matrix<T>(), \
- scratch.vec<T>()); \
- OP_REQUIRES_OK(ctx, functor_status); \
+#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
+ if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
+ Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
+ Device, T, Tindices, ADJ_A, \
+ ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(), \
+ a_indices->matrix<Tindices>(), a_values->vec<T>(), \
+ b->matrix<T>(), scratch.vec<T>()); \
+ OP_REQUIRES_OK(ctx, functor_status); \
}
MAYBE_ADJOINT(false, false);
@@ -163,53 +162,73 @@ class SparseTensorDenseMatMulOp : public OpKernel {
bool adjoint_b_;
};
-#define REGISTER_CPU(T) \
- REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("a_shape"), \
- SparseTensorDenseMatMulOp<CPUDevice, T>);
-
-REGISTER_CPU(float);
-REGISTER_CPU(double);
-REGISTER_CPU(int32);
-REGISTER_CPU(complex64);
-REGISTER_CPU(complex128);
+#define REGISTER_CPU(TypeT, TypeIndex) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseTensorDenseMatMul") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TypeT>("T") \
+ .TypeConstraint<TypeIndex>("Tindices") \
+ .HostMemory("a_shape"), \
+ SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>);
+
+#define REGISTER_KERNELS_CPU(T) \
+ REGISTER_CPU(T, int64); \
+ REGISTER_CPU(T, int32)
+
+REGISTER_KERNELS_CPU(float);
+REGISTER_KERNELS_CPU(double);
+REGISTER_KERNELS_CPU(int32);
+REGISTER_KERNELS_CPU(complex64);
+REGISTER_KERNELS_CPU(complex128);
#if GOOGLE_CUDA
namespace functor {
-#define DECLARE_GPU_SPEC(T, ADJ_A, ADJ_B) \
- template <> \
- Status SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B>::Compute( \
- const GPUDevice& d, typename TTypes<T>::Matrix out, \
- TTypes<int64>::ConstMatrix a_indices, \
- typename TTypes<T>::ConstVec a_values, \
- typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch); \
- extern template struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, \
- ADJ_B>;
-
-#define DECLARE_ADJOINT_GPU_SPEC(T) \
- DECLARE_GPU_SPEC(T, false, false) \
- DECLARE_GPU_SPEC(T, false, true) \
- DECLARE_GPU_SPEC(T, true, false) \
- DECLARE_GPU_SPEC(T, true, true)
+#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B) \
+ template <> \
+ Status SparseTensorDenseMatMulFunctor< \
+ GPUDevice, T, Tindices, ADJ_A, \
+ ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
+ typename TTypes<Tindices>::ConstMatrix a_indices, \
+ typename TTypes<T>::ConstVec a_values, \
+ typename TTypes<T>::ConstMatrix b, \
+ typename TTypes<T>::Vec scratch); \
+ extern template struct SparseTensorDenseMatMulFunctor< \
+ GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
+
+#define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B) \
+ DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \
+ DECLARE_GPU_SPEC(T, int64, ADJ_A, ADJ_B)
+
+#define DECLARE_ADJOINT_GPU_SPEC(T) \
+ REGISTER_GPU_SPEC(T, false, false) \
+ REGISTER_GPU_SPEC(T, false, true) \
+ REGISTER_GPU_SPEC(T, true, false) \
+ REGISTER_GPU_SPEC(T, true, true)
DECLARE_ADJOINT_GPU_SPEC(float);
#undef DECLARE_ADJOINT_GPU_SPEC
#undef DECLARE_GPU_SPEC
+#undef REGISTER_GPU_SPEC
} // namespace functor
-#define REGISTER_GPU(T) \
- REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("a_shape"), \
- SparseTensorDenseMatMulOp<GPUDevice, T>);
+#define REGISTER_GPU(TypeT, TypeIndex) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseTensorDenseMatMul") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TypeT>("T") \
+ .TypeConstraint<TypeIndex>("Tindices") \
+ .HostMemory("a_shape"), \
+ SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>);
+
+#define REGISTER_KERNELS_GPU(T) \
+ REGISTER_GPU(T, int64); \
+ REGISTER_GPU(T, int32)
-REGISTER_GPU(float);
+REGISTER_KERNELS_GPU(float);
#undef REGISTER_GPU
+#undef REGISTER_KERNELS_GPU
#endif // GOOGLE_CUDA
namespace functor {
@@ -228,13 +247,13 @@ Status MOutOfBoundsError(int64 m, std::size_t i, int lhs_index_a,
}
} // namespace
-template <typename T, bool ADJ_A, bool ADJ_B>
-struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
// Vectorize certain operations above this size.
static const std::size_t kNumVectorize = 32;
static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
- TTypes<int64>::ConstMatrix a_indices,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
typename TTypes<T>::ConstMatrix b,
typename TTypes<T>::Vec scratch) {
@@ -255,8 +274,8 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
for (std::size_t i = 0; i < nnz; ++i) {
- const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
- const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
+ const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
+ const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
if (!FastBoundsCheck(k, lhs_right)) {
return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
}
@@ -273,19 +292,19 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
// Vectorization via Eigen.
const int b_chip_index = ADJ_B ? 1 : 0;
-#define LOOP_NNZ(b_passed) \
- for (std::size_t i = 0; i < nnz; ++i) { \
- const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
- const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
- const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i); \
- if (!FastBoundsCheck(k, lhs_right)) { \
- return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \
- } \
- if (!FastBoundsCheck(m, out.dimension(0))) { \
- return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \
- } \
- out.template chip<0>(m) += \
- b_passed.template chip<b_chip_index>(k) * a_value; \
+#define LOOP_NNZ(b_passed) \
+ for (std::size_t i = 0; i < nnz; ++i) { \
+ const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
+ const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
+ const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i); \
+ if (!FastBoundsCheck(k, lhs_right)) { \
+ return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \
+ } \
+ if (!FastBoundsCheck(m, out.dimension(0))) { \
+ return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \
+ } \
+ out.template chip<0>(m) += \
+ b_passed.template chip<b_chip_index>(k) * a_value; \
}
if (ADJ_B) {
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index bcb836367b..e707743f78 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -25,11 +25,12 @@ namespace tensorflow {
namespace functor {
-template <typename Device, typename T, bool ADJ_A, bool ADJ_B>
+template <typename Device, typename T, typename Tindices, bool ADJ_A,
+ bool ADJ_B>
struct SparseTensorDenseMatMulFunctor {
static EIGEN_ALWAYS_INLINE Status
Compute(const Device& d, typename TTypes<T>::Matrix out,
- TTypes<int64>::ConstMatrix a_indices,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
};
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index 07d218311e..7266e0cf81 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -27,12 +27,12 @@ typedef Eigen::GpuDevice GPUDevice;
namespace generator {
-template <typename T, bool ADJ_A, bool ADJ_B>
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
class SparseTensorDenseMatMulGPUGenerator {
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
typename TTypes<T, 2>::Tensor32Bit out,
- TTypes<const int64, 2>::Tensor32Bit a_indices,
+ typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
typename TTypes<const T, 1>::Tensor32Bit a_values,
typename TTypes<const T, 2>::Tensor32Bit b)
: out_(out),
@@ -77,7 +77,7 @@ class SparseTensorDenseMatMulGPUGenerator {
mutable typename TTypes<T, 2>::Tensor32Bit out_;
const int lhs_index_a_;
const int rhs_index_a_;
- TTypes<const int64, 2>::Tensor32Bit a_indices_;
+ typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
typename TTypes<const T, 1>::Tensor32Bit a_values_;
const int lhs_right_size;
functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
@@ -88,14 +88,14 @@ class SparseTensorDenseMatMulGPUGenerator {
namespace functor {
-template <typename T, bool ADJ_A, bool ADJ_B>
-struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B> {
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
static EIGEN_ALWAYS_INLINE Status
Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
- TTypes<int64>::ConstMatrix a_indices,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
typename TTypes<T>::ConstVec a_values,
typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
- generator::SparseTensorDenseMatMulGPUGenerator<T, ADJ_A, ADJ_B>
+ generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
To32Bit(a_values), To32Bit(b));
To32Bit(out).device(d) = To32Bit(out).constant(T(0));
@@ -146,17 +146,18 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B> {
} // namespace functor
-#define DEFINE(T) \
- template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, false, \
- false>; \
- template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, false, \
- true>; \
- template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, true, \
- false>; \
- template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, true, \
- true>;
-
-DEFINE(float);
+#define DEFINE(T, Tindices) \
+ template struct functor::SparseTensorDenseMatMulFunctor< \
+ GPUDevice, T, Tindices, false, false>; \
+ template struct functor::SparseTensorDenseMatMulFunctor< \
+ GPUDevice, T, Tindices, false, true>; \
+ template struct functor::SparseTensorDenseMatMulFunctor< \
+ GPUDevice, T, Tindices, true, false>; \
+ template struct functor::SparseTensorDenseMatMulFunctor< \
+ GPUDevice, T, Tindices, true, true>;
+
+DEFINE(float, int32);
+DEFINE(float, int64);
#undef DEFINE
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 11df3c43c7..e540ecfa8d 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -394,6 +394,28 @@ output: A `Tensor` with the concatenation of values stacked along the
in `concat_dim` where it has the sum of the sizes.
)doc");
+// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
+// are not to be made user-accessible.
+#ifdef INTEL_MKL
+REGISTER_OP("_MklConcatV2")
+ .Input("values: N * T")
+ .Input("axis: Tidx")
+ .Input("mkl_values: N * uint8")
+ .Input("mkl_axis: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .SetShapeFn(shape_inference::ConcatV2Shape)
+ .Doc(R"doc(
+MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+#endif
+
REGISTER_OP("ConcatOffset")
.Input("concat_dim: int32")
.Input("shape: N * int32")
@@ -1638,6 +1660,21 @@ reshape(t, []) ==> 7
shape: Defines the shape of the output tensor.
)Doc");
+#ifdef INTEL_MKL
+REGISTER_OP("_MklReshape")
+ .Input("tensor: T")
+ .Input("shape: Tshape")
+ .Input("mkl_tensor: uint8")
+ .Input("mkl_shape: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
+ .Doc(R"Doc( MKL implementation of ReshapeOp.
+)Doc");
+#endif // INTEL_MKL
+
// --------------------------------------------------------------------------
REGISTER_OP("InvertPermutation")
.Input("x: T")
@@ -4965,6 +5002,27 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
`sum_per_d(gradients * (inputs > max))`.
)doc");
+#ifdef INTEL_MKL
+REGISTER_OP("_MklConcat")
+ .Input("concat_dim: int32")
+ .Input("values: N * T")
+ .Input("mkl_concat_dim: uint8")
+ .Input("mkl_values: N * uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ return shape_inference::ConcatShape(c, c->num_inputs() - 3);
+ })
+ .Doc(R"doc(
+MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+#endif
+
// Deprecated op registrations:
// The following can be deleted after 10mar2017.
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 3e2583f706..0bce6fc0ea 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -440,6 +440,7 @@ REGISTER_OP("FixedLengthRecordReader")
.Attr("header_bytes: int = 0")
.Attr("record_bytes: int")
.Attr("footer_bytes: int = 0")
+ .Attr("hop_bytes: int = 0")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
@@ -448,6 +449,11 @@ REGISTER_OP("FixedLengthRecordReader")
A Reader that outputs fixed-length records from a file.
reader_handle: The handle to reference the Reader.
+header_bytes: Number of bytes in the header, defaults to 0.
+record_bytes: Number of bytes in the record.
+footer_bytes: Number of bytes in the footer, defaults to 0.
+hop_bytes: Number of bytes to hop before each read. Default of 0 means using
+ record_bytes.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
@@ -459,6 +465,7 @@ REGISTER_OP("FixedLengthRecordReaderV2")
.Attr("header_bytes: int = 0")
.Attr("record_bytes: int")
.Attr("footer_bytes: int = 0")
+ .Attr("hop_bytes: int = 0")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
@@ -467,6 +474,11 @@ REGISTER_OP("FixedLengthRecordReaderV2")
A Reader that outputs fixed-length records from a file.
reader_handle: The handle to reference the Reader.
+header_bytes: Number of bytes in the header, defaults to 0.
+record_bytes: Number of bytes in the record.
+footer_bytes: Number of bytes in the footer, defaults to 0.
+hop_bytes: Number of bytes to hop before each read. Default of 0 means using
+ record_bytes.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e9d5897af0..932113bf2c 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2612,10 +2612,10 @@ scale_after_normalization: A bool indicating whether the resulted tensor
)doc");
#ifdef INTEL_MKL
-REGISTER_OP("MklConv2D")
+REGISTER_OP("_MklConv2D")
.Input("input: T")
- .Input("mkl_input: uint8")
.Input("filter: T")
+ .Input("mkl_input: uint8")
.Input("mkl_filter: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
@@ -2632,12 +2632,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklConv2DWithBias")
+REGISTER_OP("_MklConv2DWithBias")
.Input("input: T")
- .Input("mkl_input: uint8")
.Input("filter: T")
- .Input("mkl_filter: uint8")
.Input("bias: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter: uint8")
.Input("mkl_bias: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
@@ -2654,12 +2654,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklConv2DBackpropFilter")
+REGISTER_OP("_MklConv2DBackpropFilter")
.Input("input: T")
- .Input("mkl_input: uint8")
.Input("filter_sizes: int32")
- .Input("mkl_filter_size: uint8")
.Input("out_backprop: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter_size: uint8")
.Input("mkl_out_backprop: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
@@ -2669,7 +2669,7 @@ REGISTER_OP("MklConv2DBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 2 /* input_idx */, 4 /* ndims */);
+ return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
})
.Doc(R"doc(
MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
@@ -2679,7 +2679,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklConv2DWithBiasBackpropBias")
+REGISTER_OP("_MklConv2DWithBiasBackpropBias")
.Input("out_backprop: T")
.Input("mkl_out_backprop: uint8")
.Output("output: T")
@@ -2695,12 +2695,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklConv2DBackpropInput")
+REGISTER_OP("_MklConv2DBackpropInput")
.Input("input_sizes: int32")
- .Input("mkl_input_sizes: uint8")
.Input("filter: T")
- .Input("mkl_filter: uint8")
.Input("out_backprop: T")
+ .Input("mkl_input_sizes: uint8")
+ .Input("mkl_filter: uint8")
.Input("mkl_out_backprop: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
@@ -2720,7 +2720,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklRelu")
+REGISTER_OP("_MklRelu")
.Input("features: T")
.Input("mkl_features: uint8")
.Output("activations: T")
@@ -2734,10 +2734,10 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklReluGrad")
+REGISTER_OP("_MklReluGrad")
.Input("gradients: T")
- .Input("mkl_gradients: uint8")
.Input("features: T")
+ .Input("mkl_gradients: uint8")
.Input("mkl_features: uint8")
.Output("backprops: T")
.Output("mkl_backprops: uint8")
@@ -2751,7 +2751,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklMaxPool")
+REGISTER_OP("_MklMaxPool")
.Attr("T: {float, half} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -2761,8 +2761,8 @@ REGISTER_OP("MklMaxPool")
.Input("input: T")
.Input("mkl_input: uint8")
.Output("output: T")
- .Output("mkl_output: uint8")
.Output("workspace: T")
+ .Output("mkl_output: uint8")
.Output("mkl_workspace: uint8")
.SetShapeFn(shape_inference::MaxPoolShape)
.Doc(R"doc(
@@ -2773,7 +2773,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklMaxPoolGrad")
+REGISTER_OP("_MklMaxPoolGrad")
.Attr("T: {float, half} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -2781,12 +2781,12 @@ REGISTER_OP("MklMaxPoolGrad")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Input("orig_input: T")
- .Input("mkl_orig_input: uint8")
.Input("orig_output: T")
- .Input("mkl_orig_output: uint8")
.Input("grad: T")
- .Input("mkl_grad: uint8")
.Input("workspace: T")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_orig_output: uint8")
+ .Input("mkl_grad: uint8")
.Input("mkl_workspace: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
@@ -2801,7 +2801,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklAvgPool")
+REGISTER_OP("_MklAvgPool")
.Input("value: T")
.Input("mkl_input: uint8")
.Output("output: T")
@@ -2820,10 +2820,10 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklAvgPoolGrad")
+REGISTER_OP("_MklAvgPoolGrad")
.Input("orig_input_shape: int32")
- .Input("mkl_orig_input: uint8")
.Input("grad: T")
+ .Input("mkl_orig_input: uint8")
.Input("mkl_grad: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
@@ -2843,7 +2843,212 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
-REGISTER_OP("MklToTf")
+REGISTER_OP("_MklLRN")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("workspace: T")
+ .Output("mkl_output: uint8")
+ .Output("mkl_workspace: uint8")
+ .Attr("depth_radius: int = 5")
+ .Attr("bias: float = 1.0")
+ .Attr("alpha: float = 1.0")
+ .Attr("beta: float = 0.5")
+ .Attr("workspace_enabled: bool = false")
+ .Attr("T: {float, half} = DT_FLOAT")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 4);
+ })
+ .Doc(R"doc(
+MKL version of LRN operator. Uses MKL DNN APIs to perform local response
+normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklLRNGrad")
+ .Input("input_grads: T")
+ .Input("input_image: T")
+ .Input("output_image: T")
+ .Input("workspace: T")
+ .Input("mkl_input_grads: uint8")
+ .Input("mkl_input_image: uint8")
+ .Input("mkl_output_image: uint8")
+ .Input("mkl_workspace: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("depth_radius: int = 5")
+ .Attr("bias: float = 1.0")
+ .Attr("alpha: float = 1.0")
+ .Attr("beta: float = 0.5")
+ .Attr("workspace_enabled: bool = false")
+ .Attr("T: {float, half} = DT_FLOAT")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
+ TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
+ TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
+local response normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklFusedBatchNorm")
+ .Input("x: T")
+ .Input("scale: T")
+ .Input("offset: T")
+ .Input("mean: T")
+ .Input("variance: T")
+ .Input("mkl_x: uint8")
+ .Input("mkl_scale: uint8")
+ .Input("mkl_offset: uint8")
+ .Input("mkl_mean: uint8")
+ .Input("mkl_variance: uint8")
+ .Output("y: T")
+ .Output("batch_mean: T")
+ .Output("batch_variance: T")
+ .Output("reserve_space_1: T")
+ .Output("reserve_space_2: T")
+ .Output("mkl_y: uint8")
+ .Output("mkl_batch_mean: uint8")
+ .Output("mkl_batch_variance: uint8")
+ .Output("mkl_reserve_space_1: uint8")
+ .Output("mkl_reserve_space_2: uint8")
+ .Attr("T: numbertype")
+ .Attr("epsilon: float = 0.0001")
+ .Attr("data_format: string = 'NHWC'")
+ .Attr("is_training: bool = true")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle x;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
+
+ bool is_training;
+ c->GetAttr("is_training", &is_training);
+ int number_inputs = (is_training) ? 3 : 5;
+ string data_format;
+ c->GetAttr("data_format", &data_format);
+ DimensionHandle channel_dim =
+ (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
+
+ // covers scale, offset, and if is_training is false, mean, variance
+ for (int i = 1; i < number_inputs; ++i) {
+ ShapeHandle vec;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
+ }
+
+ ShapeHandle y;
+ if (data_format == "NHWC") {
+ TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
+ } else {
+ TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
+ }
+ c->set_output(0, y);
+ ShapeHandle vector_shape = c->Vector(channel_dim);
+ c->set_output(1, vector_shape);
+ c->set_output(2, vector_shape);
+ c->set_output(3, vector_shape);
+ c->set_output(4, vector_shape);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused
+batch normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklFusedBatchNormGrad")
+ .Input("y_backprop: T")
+ .Input("x: T")
+ .Input("scale: T")
+ .Input("reserve_space_1: T")
+ .Input("reserve_space_2: T")
+ .Input("mkl_y_backprop: uint8")
+ .Input("mkl_x: uint8")
+ .Input("mkl_scale: uint8")
+ .Input("mkl_reserve_space_1: uint8")
+ .Input("mkl_reserve_space_2: uint8")
+ .Output("x_backprop: T")
+ .Output("scale_backprop: T")
+ .Output("offset_backprop: T")
+ .Output("reserve_space_3: T")
+ .Output("reserve_space_4: T")
+ .Output("mkl_x_backprop: uint8")
+ .Output("mkl_scale_backprop: uint8")
+ .Output("mkl_offset_backprop: uint8")
+ .Output("mkl_reserve_space_3: uint8")
+ .Output("mkl_reserve_space_4: uint8")
+ .Attr("T: numbertype")
+ .Attr("epsilon: float = 0.0001")
+ .Attr("data_format: string = 'NHWC'")
+ .Attr("is_training: bool = true")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle y_backprop;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
+ ShapeHandle x;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
+
+ bool is_training;
+ string data_format;
+ c->GetAttr("is_training", &is_training);
+ c->GetAttr("data_format", &data_format);
+ DimensionHandle channel_dim = (data_format == "NHWC")
+ ? c->Dim(y_backprop, 3)
+ : c->Dim(y_backprop, 1);
+ if (data_format == "NHWC") {
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
+ } else {
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
+ }
+
+ // covers scale, mean (reserve_space_1), variance (reserve_space_2)
+ for (int i = 2; i < 5; ++i) {
+ ShapeHandle vec;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
+ }
+
+ ShapeHandle x_backprop;
+ if (data_format == "NHWC") {
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
+ } else {
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
+ }
+ c->set_output(0, x_backprop);
+ c->set_output(1, c->Vector(channel_dim));
+ c->set_output(2, c->Vector(channel_dim));
+ // Set the correct shapes for reserve_spaces
+ // so that gradients can be performed when
+ // the op is in a symbolic condition.
+ if (is_training) {
+ c->set_output(3, c->Vector(0));
+ c->set_output(4, c->Vector(0));
+ } else {
+ c->set_output(3, c->Vector(channel_dim));
+ c->set_output(4, c->Vector(channel_dim));
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute
+gradients for fused batch normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklToTf")
.Input("input: T")
.Input("mkl_input: uint8")
.Output("output: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 6d28cb7e84..cbbabe0b87 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -26417,6 +26417,59 @@ op {
description: "Read @{$math_ops#segmentation$the section on segmentation} for an explanation of\nsegments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
+ name: "UnsortedSegmentSum"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ description: "A tensor whose shape is a prefix of `data.shape`."
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes the max along segments of a tensor."
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
+}
+op {
name: "Unstage"
output_arg {
name: "values"
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 860b3475e9..b90f7a5dfb 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -128,12 +128,13 @@ pair takes space.
)doc");
REGISTER_OP("SparseTensorDenseMatMul")
- .Input("a_indices: int64")
+ .Input("a_indices: Tindices")
.Input("a_values: T")
.Input("a_shape: int64")
.Input("b: T")
.Output("product: T")
.Attr("T: type")
+ .Attr("Tindices: {int32,int64} = DT_INT64")
.Attr("adjoint_a: bool = false")
.Attr("adjoint_b: bool = false")
.SetShapeFn([](InferenceContext* c) {
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 58ccda5c9b..10414cbca2 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -194,13 +194,15 @@ def tf_kernel_tests_linkstatic():
def tf_additional_lib_defines():
return select({
- "//tensorflow:with_jemalloc": ["TENSORFLOW_USE_JEMALLOC"],
+ "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
+ "//tensorflow:with_jemalloc_linux_ppc64le":["TENSORFLOW_USE_JEMALLOC"],
"//conditions:default": [],
})
def tf_additional_lib_deps():
return select({
- "//tensorflow:with_jemalloc": ["@jemalloc"],
+ "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc"],
"//conditions:default": [],
})
@@ -246,3 +248,9 @@ def tf_lib_proto_parsing_deps():
":protos_all_cc",
"//tensorflow/core/platform/default/build_config:proto_parsing",
]
+
+def tf_additional_verbs_lib_defines():
+ return select({
+ "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
+ "//conditions:default": [],
+ })
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 79f97c1234..eb804bfc78 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -22,3 +22,11 @@ def tf_additional_license_deps():
"//tensorflow:with_xla_support": ["@llvm//:LICENSE.TXT"],
"//conditions:default": [],
})
+
+def tf_additional_verbs_deps():
+ return select({
+ "//tensorflow:with_verbs_support": [
+ "//tensorflow/contrib/verbs:verbs_server_lib",
+ "//tensorflow/contrib/verbs:grpc_verbs_client"],
+ "//conditions:default": [],
+ })
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index dfffbfa396..df33cf38c9 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -20,11 +20,11 @@ limitations under the License.
#define TF_MAJOR_VERSION 1
#define TF_MINOR_VERSION 1
-#define TF_PATCH_VERSION 0-rc1
+#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc2"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index ebbe195bbc..897b174eff 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -75,7 +75,6 @@ class MklShape {
void SetTfLayout(const size_t dimension, const size_t* sizes,
const size_t* strides) {
dimension_ = dimension;
-
if (dimension > 0) { // MKl doesn't support zero dimension tensors
sizes_ = new size_t[dimension];
strides_ = new size_t[dimension];
@@ -140,6 +139,39 @@ class MklShape {
const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
+ // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ // corresponds to MKL's Channel dimension.
+ bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
+ // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ // corresponds to MKL's Batch dimension.
+ bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
+ // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ // corresponds to MKL's Width dimension.
+ bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
+ // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ // corresponds to MKL's Height dimension.
+ bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
+
+ // Check if the TF-Mkl dimension ordering map specifies if the input
+ // tensor is in NCHW format.
+ bool IsTensorInNCHWFormat() const {
+ TensorFormat data_format = FORMAT_NCHW;
+ return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
+ IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
+ IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
+ IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
+ }
+
+ // Check if the TF-Mkl dimension ordering map specifies if the input
+ // tensor is in NHWC format.
+ bool IsTensorInNHWCFormat() const {
+ TensorFormat data_format = FORMAT_NHWC;
+ return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
+ IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
+ IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
+ IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
+ }
+
void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
void* output) const {
dnnLayout_t curLayout;
@@ -194,9 +226,9 @@ class MklShape {
(STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_
#define TF_LAYOUT_OFFSET(dims) \
(MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_
-// Location of tf_to_mkl_dim_map_
#define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
- (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
+ (TF_LAYOUT_OFFSET(dims) + \
+ SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_
// TODO(agramesh1) make sure to create a const to share with rewrite pass
// for min size of MKL metadata tensor.
@@ -265,45 +297,166 @@ class MklShape {
size_t dimension_ = 0;
size_t* sizes_ = nullptr; // Required by MKL for conversions
size_t* strides_ = nullptr; // Required by MKL for conversions
- // TF dimension corresponding to this MKL dimension
- size_t* tf_to_mkl_dim_map_ = nullptr;
+ size_t* tf_to_mkl_dim_map_ =
+ nullptr; // TF dimension corresponding to this MKL dimension
};
-int inline GetTensorDataIndex(int n) {
- return 2 * n; // index corresponding to nth input/output tensor
+// List of MklShape objects. Used in Concat/Split layers.
+typedef std::vector<MklShape> MklShapeList;
+
+// Check if all tensors specified by MklShapes are MKL tensors.
+inline bool AreAllMklTensors(const MklShapeList& shapes) {
+ for (auto& s : shapes) {
+ if (!s.IsMklTensor()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+template <typename T>
+inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
+ const MklShape& mkl_shape) {
+ Tensor output_tensor;
+ TensorShape output_shape;
+
+ for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
+ // Outermost to innermost dimension
+ output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
+ }
+
+ // Allocate output tensor.
+ context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
+
+ dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
+ void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
+ void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
+
+ if (mkl_tensor.NumElements() != 0) {
+ mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
+ }
+
+ return output_tensor;
}
-int inline GetTensorMetaDataIndex(int n) {
- // index corresponding to meta data of nth input/output tensor
- return 2 * n + 1;
+// Since our ops are going to produce and also consume N addition tensors
+// (Mkl) for N Tensorflow tensors, we can have following different
+// orderings among these 2N tensors.
+//
+// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
+// consume A_m, B_m, and C_m additionally.
+//
+// INTERLEAVED: in this case 2N tensors are interleaved. So for above
+// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
+//
+// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
+// by N Mkl tensors. So for above example, the ordering looks
+// like: A, B, C, A_m, B_m, C_m
+//
+// Following APIs map index of original Tensorflow tensors to their appropriate
+// position based on selected ordering. For contiguous ordering, we need to know
+// the total number of tensors (parameter total).
+//
+typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
+// NOTE: Currently, we use contiguous ordering. If you change this, then you
+// would need to change Mkl op definitions in nn_ops.cc.
+static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+
+// Get index of MetaData tensor from index 'n' of Data tensor.
+inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ // For interleaved ordering, Mkl tensor follows immediately after
+ // Tensorflow tensor.
+ return n + 1;
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
+ return n + total_tensors / 2;
+ }
}
+
+int inline GetTensorDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ return 2 * n; // index corresponding to nth input/output tensor
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ return n;
+ }
+}
+
+int inline GetTensorMetaDataIndex(int n, int total_tensors) {
+ // Get index for TensorData first and then use mapping function
+ // to get TensorMetaData index from TensorData index.
+ int tidx = GetTensorDataIndex(n, total_tensors);
+ return DataIndexToMetaDataIndex(tidx, total_tensors);
+}
+
// Get the MKL shape from the second string tensor
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape(
- ctext->input(GetTensorMetaDataIndex(n)).flat<uint8>().data(),
- ctext->input(GetTensorMetaDataIndex(n)).flat<uint8>().size() *
+ ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+ .flat<uint8>()
+ .data(),
+ ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+ .flat<uint8>()
+ .size() *
sizeof(uint8));
}
// Gets the actual input
inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
- return ctext->input(GetTensorDataIndex(n));
+ return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
+}
+
+inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
+ OpInputList* input_tensors) {
+ CHECK_NOTNULL(input_tensors);
+ ctext->input_list(name, input_tensors);
+}
+
+inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
+ MklShapeList* mkl_shapes) {
+ OpInputList input_mkl_tensors;
+ GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
+
+ for (int i = 0; i < input_mkl_tensors.size(); i++) {
+ (*mkl_shapes)[i].DeSerializeMklShape(
+ input_mkl_tensors[i].flat<uint8>().data(),
+ input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
+ }
+}
+
+// Allocate the second output tensor that will contain
+// the MKL shape serialized
+inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
+ const MklShape& mkl_shape) {
+ Tensor* second_tensor = nullptr;
+ TensorShape second_shape;
+ second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
+ OP_REQUIRES_OK(ctext, ctext->allocate_output(
+ GetTensorMetaDataIndex(n, ctext->num_outputs()),
+ second_shape, &second_tensor));
+ mkl_shape.SerializeMklShape(
+ second_tensor->flat<uint8>().data(),
+ second_tensor->flat<uint8>().size() * sizeof(uint8));
}
// Allocate the output tensor, create a second output tensor that will contain
// the MKL shape serialized
-inline void AllocateOutputSetMklshape(OpKernelContext* ctext, int n,
+inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
Tensor** output,
- const TensorShape& tfshape,
- const MklShape& mklshape) {
+ const TensorShape& tf_shape,
+ const MklShape& mkl_shape) {
Tensor* second_tensor = nullptr;
TensorShape second_shape;
- second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mklshape.GetDimension()));
+ second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
OP_REQUIRES_OK(
- ctext, ctext->allocate_output(GetTensorDataIndex(n), tfshape, output));
- OP_REQUIRES_OK(ctext, ctext->allocate_output(GetTensorMetaDataIndex(n),
- second_shape, &second_tensor));
- mklshape.SerializeMklShape(
+ ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
+ tf_shape, output));
+ OP_REQUIRES_OK(ctext, ctext->allocate_output(
+ GetTensorMetaDataIndex(n, ctext->num_outputs()),
+ second_shape, &second_tensor));
+ mkl_shape.SerializeMklShape(
second_tensor->flat<uint8>().data(),
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
@@ -342,12 +495,11 @@ inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
inline void MklSizesToTFSizes(OpKernelContext* context,
TensorFormat data_format_,
- const MklShape& mklshape, TensorShape* tfshape) {
- size_t tf_dim = mklshape.GetDimension();
- const size_t* tf_sizes = mklshape.GetSizes();
+ const MklShape& mkl_shape,
+ TensorShape* tf_shape) {
+ size_t tf_dim = mkl_shape.GetDimension();
+ const size_t* tf_sizes = mkl_shape.GetSizes();
- // TODO(agramesh1): check if this constraint is applicable in other cases
- // (besides BackpropInput, BackpropFilter).
OP_REQUIRES(context, tf_dim == 4,
errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
std::vector<int32> sizes;
@@ -364,7 +516,7 @@ inline void MklSizesToTFSizes(OpKernelContext* context,
sizes.push_back(tf_sizes[0]);
}
- OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tfshape));
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
}
inline int32 GetMklTensorDimIndex(char dimension) {
@@ -383,38 +535,71 @@ inline int32 GetMklTensorDimIndex(char dimension) {
}
}
-inline int64 GetMklTensorDim(const MklShape& mklshape, char dimension) {
+inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
int index = GetMklTensorDimIndex(dimension);
- CHECK(index >= 0 && index < mklshape.GetDimension())
+ CHECK(index >= 0 && index < mkl_shape.GetDimension())
<< "Invalid index from the dimension: " << index << ", " << dimension;
- return mklshape.dim_size(index);
+ return mkl_shape.dim_size(index);
}
-namespace mkl_layer_registry {
+inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
+ int idx_out) {
+ int num_inputs = context->num_inputs();
+ int num_outputs = context->num_outputs();
+ int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+ int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
+ int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+ int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
+
+ const Tensor& data = context->input(idx_data_in);
+ const Tensor& meta = context->input(idx_meta_in);
+ Tensor output(data.dtype());
+ Tensor meta_output(meta.dtype());
+
+ // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
+ CHECK(output.CopyFrom(data, data.shape()));
+ CHECK(meta_output.CopyFrom(meta, meta.shape()));
+ context->set_output(idx_data_out, output);
+ context->set_output(idx_meta_out, meta_output);
+}
+
+inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
+ int idx_out, const TensorShape& shape) {
+ int num_inputs = context->num_inputs();
+ int num_outputs = context->num_outputs();
+ int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+ int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+
+ const Tensor& data = context->input(idx_data_in);
+ MklShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
+ Tensor output(data.dtype());
+ // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
+ CHECK(output.CopyFrom(data, shape));
+ context->set_output(idx_data_out, output);
+}
-static const char* kMklLayerLabel = "MklLayer";
-static const char* kMklLayerLabelPattern = "label='MklLayer'";
+namespace mkl_op_registry {
+static const char* kMklOpLabel = "MklOp";
+static const char* kMklOpLabelPattern = "label='MklOp'";
// Check whether opname with type T is registered as MKL-compliant.
//
// @input: name of the op
// @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl layer op
-static inline bool IsMklLayer(const std::string& op_name, DataType T) {
+// @return: true if opname is registered as Mkl op
+static inline bool IsMklOp(const std::string& op_name, DataType T) {
string kernel = KernelsRegisteredForOp(op_name);
- // Currently, MKL only supports float type for ops. So we check if
- // the type is float. Actually, we should query kernel registration and
- // find out if op is supported for type T. But there is no API to query
- // kernel registration using name and type.
bool result =
- (kernel.find(kMklLayerLabelPattern) != string::npos) && (T == DT_FLOAT);
- if (result == true) {
- VLOG(1) << "mkl_layer_registry::" << op_name << " is " << kMklLayerLabel;
+ kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
+ if (result) {
+ VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
}
return result;
}
-} // namespace mkl_layer_registry
+} // namespace mkl_op_registry
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/docs_src/community/style_guide.md b/tensorflow/docs_src/community/style_guide.md
index a2df61bc80..767e33c3d0 100644
--- a/tensorflow/docs_src/community/style_guide.md
+++ b/tensorflow/docs_src/community/style_guide.md
@@ -115,31 +115,31 @@ Example:
def my_op(tensor_in, other_tensor_in, my_param, other_param=0.5,
output_collections=(), name=None):
- """My operation that adds two tensors with given coefficients.
-
- Args:
- tensor_in: `Tensor`, input tensor.
- other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
- my_param: `float`, coefficient for `tensor_in`.
- other_param: `float`, coefficient for `other_tensor_in`.
- output_collections: `tuple` of `string`s, name of the collection to
- collect result of this op.
- name: `string`, name of the operation.
-
- Returns:
- `Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
-
- Example:
- >>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
- output_collections=['MY_OPS'], name='add_t1t2')
- [2.3, 3.4]
- """
- with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
- tensor_in = tf.convert_to_tensor(tensor_in)
- other_tensor_in = tf.convert_to_tensor(other_tensor_in)
- result = my_param * tensor_in + other_param * other_tensor_in
- tf.add_to_collections(output_collections, result)
- return result
+ """My operation that adds two tensors with given coefficients.
+
+ Args:
+ tensor_in: `Tensor`, input tensor.
+ other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
+ my_param: `float`, coefficient for `tensor_in`.
+ other_param: `float`, coefficient for `other_tensor_in`.
+ output_collections: `tuple` of `string`s, name of the collection to
+ collect result of this op.
+ name: `string`, name of the operation.
+
+ Returns:
+ `Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
+
+ Example:
+ >>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
+ output_collections=['MY_OPS'], name='add_t1t2')
+ [2.3, 3.4]
+ """
+ with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
+ tensor_in = tf.convert_to_tensor(tensor_in)
+ other_tensor_in = tf.convert_to_tensor(other_tensor_in)
+ result = my_param * tensor_in + other_param * other_tensor_in
+ tf.add_to_collection(output_collections, result)
+ return result
Usage:
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
index c75c7f111d..f54f79cbf4 100644
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ b/tensorflow/docs_src/extend/adding_an_op.md
@@ -121,16 +121,16 @@ class ZeroOutOp : public OpKernel {
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
- auto output = output_tensor->flat<int32>();
+ auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
- output(i) = 0;
+ output_flat(i) = 0;
}
// Preserve the first input value if possible.
- if (N > 0) output(0) = input(0);
+ if (N > 0) output_flat(0) = input(0);
}
};
```
diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md
index 33ff1d87d5..6bee7529d0 100644
--- a/tensorflow/docs_src/get_started/get_started.md
+++ b/tensorflow/docs_src/get_started/get_started.md
@@ -323,7 +323,7 @@ for i in range(1000):
sess.run(train, {x:x_train, y:y_train})
# evaluate training accuracy
-curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x:x_train, y:y_train})
+curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x:x_train, y:y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
```
When run, it produces
diff --git a/tensorflow/docs_src/get_started/monitors.md b/tensorflow/docs_src/get_started/monitors.md
index 99d583b23d..7db88c8981 100644
--- a/tensorflow/docs_src/get_started/monitors.md
+++ b/tensorflow/docs_src/get_started/monitors.md
@@ -282,18 +282,15 @@ validation_metrics = {
"accuracy":
tf.contrib.learn.MetricSpec(
metric_fn=tf.contrib.metrics.streaming_accuracy,
- prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
- CLASSES),
+ prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
"precision":
tf.contrib.learn.MetricSpec(
metric_fn=tf.contrib.metrics.streaming_precision,
- prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
- CLASSES),
+ prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
"recall":
tf.contrib.learn.MetricSpec(
metric_fn=tf.contrib.metrics.streaming_recall,
- prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
- CLASSES)
+ prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
}
```
diff --git a/tensorflow/docs_src/get_started/tflearn.md b/tensorflow/docs_src/get_started/tflearn.md
index 0912c7a5b4..4a893e4a45 100644
--- a/tensorflow/docs_src/get_started/tflearn.md
+++ b/tensorflow/docs_src/get_started/tflearn.md
@@ -282,7 +282,7 @@ enough that it can be stored in @{tf.constant TensorFlow constants}. The
following code produces the simplest possible input pipeline:
```python
-# Define the test inputs
+# Define the training inputs
def get_train_inputs():
x = tf.constant(training_set.data)
y = tf.constant(training_set.target)
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 0f3914d52d..c1581efb4f 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -35,7 +35,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for Mac OS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index 6874a1f03f..dd713e4786 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -35,7 +35,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc1.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc2.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 127d8fd029..1abf3b69f5 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -34,7 +34,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.1.0-rc1</version>
+ <version>1.1.0-rc2</version>
</dependency>
```
@@ -63,7 +63,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.1.0-rc1</version>
+ <version>1.1.0-rc2</version>
</dependency>
</dependencies>
</project>
@@ -122,7 +122,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or Mac OS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -141,7 +141,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -149,10 +149,10 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc1.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc2.zip).
3. Extract this .zip file.
@@ -200,7 +200,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.1.0-rc1.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.1.0-rc2.jar HelloTF.java</b></pre>
### Running
@@ -213,7 +213,7 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program:
-<pre><b>java -cp libtensorflow-1.1.0-rc1.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.1.0-rc2.jar:. -Djava.library.path=./jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 4a5d63f337..8ee31fe692 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -165,8 +165,8 @@ Take the following steps to install TensorFlow with Virtualenv:
issue the following command to install TensorFlow in the active
virtualenv environment:
- <pre> (tensorflow)$ <b>pip install --upgrade \\
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
+ <pre>(tensorflow)$ <b>pip3 install --upgrade \
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common_installation_problems).
@@ -269,8 +269,10 @@ take the following steps:
install TensorFlow for Linux, Python 2.7, and CPU-only support, issue
the following command:
- <pre> $ <b>sudo pip install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
+ <pre>
+ $ <b>sudo pip3 install --upgrade \
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b>
+ </pre>
If this step fails, see
[Common Installation Problems](#common_installation_problems).
@@ -456,7 +458,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -624,14 +626,14 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -643,14 +645,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -662,14 +664,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp35-cp35m-linux_x86_64.whl
</pre>
@@ -681,14 +683,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp36-cp36m-linux_x86_64.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index ccfe9ada6d..0882422e4d 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -163,7 +163,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#CommonInstallationProblems).
@@ -286,7 +286,7 @@ take the following steps:
support, issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b> </pre>
If the preceding command fails, see
[Common installation problems](#CommonInstallationProblems).
@@ -398,7 +398,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -604,13 +604,13 @@ This section documents the relevant values for Mac OS installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc1-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc2-py2-none-any.whl
</pre>
Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
@@ -622,13 +622,13 @@ Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py3-none-any.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc1-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc2-py3-none-any.whl
</pre>
Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 5f351c40b4..88268ba62f 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -319,10 +319,11 @@ $ <b>bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pk
Invoke `pip install` to install that pip package.
The filename of the `.whl` file depends on your platform.
For example, the following command will install the pip package
-for TensorFlow 1.1.0rc1 on Linux:
+
+for TensorFlow 1.1.0rc2 on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.1.0rc1-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.1.0rc2-py2-none-any.whl</b>
</pre>
## Validate your installation
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index 7d3d13c34a..5f7c27c028 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -114,12 +114,12 @@ Take the following steps to install TensorFlow in an Anaconda environment:
environment. To install the CPU-only version of TensorFlow, enter the
following command:
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.1.0rc1-cp35-cp35m-win_amd64.whl</b> </pre>
+ <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.1.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
To install the GPU version of TensorFlow, enter the following command
(on a single line):
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.1.0rc1-cp35-cp35m-win_amd64.whl</b> </pre>
+ <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.1.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
## Validate your installation
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
index 8cd296d752..e29387ab9d 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -88,7 +88,8 @@ def run_training():
saver = tf.train.Saver()
# Create the op for initializing variables.
- init_op = tf.global_variables_initializer()
+ init_op = tf.group(tf.global_variables_initializer(),
+ tf.local_variables_initializer())
# Create a session for running Ops on the Graph.
sess = tf.Session()
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 698c97ca1d..dc0d870315 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -135,7 +135,8 @@ def train():
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)
- # Merge all the summaries and write them out to /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
+ # Merge all the summaries and write them out to
+ # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
@@ -196,9 +197,15 @@ if __name__ == '__main__':
help='Initial learning rate')
parser.add_argument('--dropout', type=float, default=0.9,
help='Keep probability for training dropout.')
- parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
- help='Directory for storing input data')
- parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
- help='Summaries log directory')
+ parser.add_argument(
+ '--data_dir',
+ type=str,
+ default='/tmp/tensorflow/mnist/input_data',
+ help='Directory for storing input data')
+ parser.add_argument(
+ '--log_dir',
+ type=str,
+ default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
+ help='Summaries log directory')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index cad8ccaaad..c367d20f81 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -25,6 +25,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_verbs_deps")
py_library(
name = "python",
@@ -2610,7 +2611,9 @@ tf_py_wrap_cc(
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/tools/tfprof/internal:print_model_analysis",
"//util/python:python_headers",
- ] + tf_additional_lib_deps() + tf_additional_plugin_deps(),
+ ] + (tf_additional_lib_deps() +
+ tf_additional_plugin_deps() +
+ tf_additional_verbs_deps()),
)
py_library(
diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py
index fac2cf4def..f04f67ffed 100644
--- a/tensorflow/python/framework/dtypes_test.py
+++ b/tensorflow/python/framework/dtypes_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.python.framework.importer."""
+"""Tests for tensorflow.python.framework.dtypes."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 5e8f8e8673..10f34751d0 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -352,9 +352,20 @@ class FixedLengthRecordReaderTest(test.TestCase):
self._record_bytes = 3
self._footer_bytes = 2
+ self._hop_bytes = 2
+ self._num_overlapped_records = 3
+
def _Record(self, f, r):
return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
+ def _OverlappedRecord(self, f, r):
+ record_str = "".join([
+ str(i)[0]
+ for i in range(r * self._hop_bytes,
+ r * self._hop_bytes + self._record_bytes)
+ ])
+ return compat.as_bytes(record_str)
+
def _CreateFiles(self):
filenames = []
for i in range(self._num_files):
@@ -367,6 +378,23 @@ class FixedLengthRecordReaderTest(test.TestCase):
f.write(b"F" * self._footer_bytes)
return filenames
+ def _CreateOverlappedRecordFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(),
+ "fixed_length_overlapped_record.%d.txt" % i)
+ filenames.append(fn)
+ with open(fn, "wb") as f:
+ f.write(b"H" * self._header_bytes)
+ all_records_str = "".join([
+ str(i)[0]
+ for i in range(self._record_bytes + self._hop_bytes *
+ (self._num_overlapped_records - 1))
+ ])
+ f.write(compat.as_bytes(all_records_str))
+ f.write(b"F" * self._footer_bytes)
+ return filenames
+
def testOneEpoch(self):
files = self._CreateFiles()
with self.test_session() as sess:
@@ -374,6 +402,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
header_bytes=self._header_bytes,
record_bytes=self._record_bytes,
footer_bytes=self._footer_bytes,
+ hop_bytes=0,
name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -390,6 +419,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
"\\(requested 1, current size 0\\)"):
k, v = sess.run([key, value])
+ def testOneEpochWithHopBytes(self):
+ files = self._CreateOverlappedRecordFiles()
+ with self.test_session() as sess:
+ reader = io_ops.FixedLengthRecordReader(
+ header_bytes=self._header_bytes,
+ record_bytes=self._record_bytes,
+ footer_bytes=self._footer_bytes,
+ hop_bytes=self._hop_bytes,
+ name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_overlapped_records):
+ k, v = sess.run([key, value])
+ print(v)
+ self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
+ self.assertAllEqual(self._OverlappedRecord(i, j), v)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = sess.run([key, value])
+
class TFRecordReaderTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
index df5462dd2d..e8b94294b1 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
@@ -30,34 +30,44 @@ from tensorflow.python.platform import test
class SparseTensorDenseMatMulGradientTest(test.TestCase):
- def _sparsify(self, x):
+ def _sparsify(self, x, indices_dtype=np.int64):
x[x < 0.5] = 0
non_zero = np.where(x)
- x_indices = np.vstack(non_zero).astype(np.int64).T
+ x_indices = np.vstack(non_zero).astype(indices_dtype).T
x_values = x[non_zero]
x_shape = x.shape
return sparse_tensor.SparseTensor(
indices=x_indices, values=x_values, dense_shape=x_shape), len(x_values)
- def _randomTensor(self, size, np_dtype, adjoint=False, sparse=False):
+ def _randomTensor(self,
+ size,
+ values_dtype,
+ adjoint=False,
+ sparse=False,
+ indices_dtype=np.int64):
n, m = size
- x = np.random.randn(n, m).astype(np_dtype)
+ x = np.random.randn(n, m).astype(values_dtype)
if adjoint:
x = x.transpose()
if sparse:
- return self._sparsify(x)
+ return self._sparsify(x, indices_dtype=indices_dtype)
else:
- return constant_op.constant(x, dtype=np_dtype)
+ return constant_op.constant(x, dtype=values_dtype)
- def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype):
+ def _testGradients(self, adjoint_a, adjoint_b, name, values_dtype,
+ indices_dtype):
n, k, m = np.random.randint(1, 10, size=3)
sp_t, nnz = self._randomTensor(
- [n, k], np_dtype, adjoint=adjoint_a, sparse=True)
- dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)
+ [n, k],
+ values_dtype,
+ adjoint=adjoint_a,
+ sparse=True,
+ indices_dtype=indices_dtype)
+ dense_t = self._randomTensor([k, m], values_dtype, adjoint=adjoint_b)
matmul = sparse_ops.sparse_tensor_dense_matmul(
sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)
@@ -71,17 +81,19 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase):
print("%s gradient err = %s" % (name, err))
self.assertLess(err, 1e-3)
- def _testGradientsType(self, np_dtype):
+ def _testGradientsType(self, values_dtype, indices_dtype):
for adjoint_a in [True, False]:
for adjoint_b in [True, False]:
- name = "sparse_tensor_dense_matmul_%s_%s_%s" % (adjoint_a, adjoint_b,
- np_dtype.__name__)
- self._testGradients(adjoint_a, adjoint_b, name, np_dtype)
+ name = "sparse_tensor_dense_matmul_%s_%s_%s_%s" % (
+ adjoint_a, adjoint_b, values_dtype.__name__, indices_dtype.__name__)
+ self._testGradients(adjoint_a, adjoint_b, name, values_dtype,
+ indices_dtype)
def testGradients(self):
np.random.seed(5) # Fix seed to avoid flakiness
- self._testGradientsType(np.float32)
- self._testGradientsType(np.float64)
+ self._testGradientsType(np.float32, np.int64)
+ self._testGradientsType(np.float64, np.int64)
+ self._testGradientsType(np.float32, np.int32)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index da72803ee7..8099175186 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -45,7 +45,12 @@ def _maybe_complex(x):
class SparseTensorDenseMatMulTest(test.TestCase):
- def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False):
+ def _testMatmul(self,
+ x,
+ y,
+ adjoint_a=False,
+ adjoint_b=False,
+ indices_dtype=np.int64):
x_mat = np.matrix(x)
if adjoint_a:
x_mat = x_mat.H
@@ -55,7 +60,7 @@ class SparseTensorDenseMatMulTest(test.TestCase):
np_ans = x_mat * y_mat
- x_indices = np.vstack(np.where(x)).astype(np.int64).T
+ x_indices = np.vstack(np.where(x)).astype(indices_dtype).T
x_values = x[np.where(x)]
x_shape = x.shape
@@ -82,13 +87,13 @@ class SparseTensorDenseMatMulTest(test.TestCase):
else:
self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
- def _testBasic(self, np_dtype):
- x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype))
+ def _testBasic(self, value_dtype, indices_dtype=np.int64):
+ x = _maybe_complex(np.random.rand(10, 10).astype(value_dtype))
x[np.abs(x) < 0.5] = 0 # Make it sparse
- y = _maybe_complex(np.random.randn(10, 20).astype(np_dtype))
+ y = _maybe_complex(np.random.randn(10, 20).astype(value_dtype))
- self._testMatmul(x, y)
+ self._testMatmul(x, y, indices_dtype=indices_dtype)
def testBasic(self):
np.random.seed(127) # Repeatable results
@@ -97,6 +102,8 @@ class SparseTensorDenseMatMulTest(test.TestCase):
self._testBasic(np.float64)
self._testBasic(np.complex64)
self._testBasic(np.complex128)
+ self._testBasic(np.int32, indices_dtype=np.int32)
+ self._testBasic(np.float32, indices_dtype=np.int32)
def testShapeInference(self):
x = np.random.rand(10, 10)
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index c3e133d08b..da962b2f99 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tf.layers.core."""
+"""Tests for tf.layers.convolutional."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 0f82f73ea4..933f196e01 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tf.layers.core."""
+"""Tests for tf.layers.normalization."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py
index ace8046a0b..54e757c112 100644
--- a/tensorflow/python/layers/utils_test.py
+++ b/tensorflow/python/layers/utils_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tf.layers.core."""
+"""Tests for tf.layers.utils."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/ops/batch_norm_benchmark.py b/tensorflow/python/ops/batch_norm_benchmark.py
index 397ed91078..c2ee2b3832 100644
--- a/tensorflow/python/ops/batch_norm_benchmark.py
+++ b/tensorflow/python/ops/batch_norm_benchmark.py
@@ -198,7 +198,7 @@ class BatchNormBenchmark(test.Benchmark):
if FLAGS.use_gpu:
t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
- t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
+ t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
print_difference("op vs py", t1, t2)
print_difference("py vs slow", t2, t3)
print("Forward convolution (higher layers).")
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index ae45c40aec..68ecc219e4 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -391,7 +391,11 @@ class FixedLengthRecordReader(ReaderBase):
"""
# TODO(josh11b): Support serializing and restoring state.
- def __init__(self, record_bytes, header_bytes=None, footer_bytes=None,
+ def __init__(self,
+ record_bytes,
+ header_bytes=None,
+ footer_bytes=None,
+ hop_bytes=None,
name=None):
"""Create a FixedLengthRecordReader.
@@ -399,11 +403,15 @@ class FixedLengthRecordReader(ReaderBase):
record_bytes: An int.
header_bytes: An optional int. Defaults to 0.
footer_bytes: An optional int. Defaults to 0.
+ hop_bytes: An optional int. Defaults to 0.
name: A name for the operation (optional).
"""
rr = gen_io_ops._fixed_length_record_reader_v2(
- record_bytes=record_bytes, header_bytes=header_bytes,
- footer_bytes=footer_bytes, name=name)
+ record_bytes=record_bytes,
+ header_bytes=header_bytes,
+ footer_bytes=footer_bytes,
+ hop_bytes=hop_bytes,
+ name=name)
super(FixedLengthRecordReader, self).__init__(rr)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 118ade45ec..7c17cf2cb6 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -639,18 +639,22 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
math_ops.reduce_mean(y, axes, keep_dims=True))
else:
shift = math_ops.cast(shift, y.dtype)
- counts, m_ss, v_ss, shift = sufficient_statistics(
- y, axes, shift=shift, keep_dims=keep_dims, name=name)
- # Reshape shift as needed.
- shift = array_ops.reshape(shift, array_ops.shape(m_ss))
- shift.set_shape(m_ss.get_shape())
- with ops.control_dependencies([counts, m_ss, v_ss]):
- mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
- if x.dtype == dtypes.float16:
- return (math_ops.cast(mean, dtypes.float16),
- math_ops.cast(variance, dtypes.float16))
- else:
- return (mean, variance)
+ shifted_mean = math_ops.reduce_mean(
+ math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean")
+ variance = math_ops.subtract(
+ math_ops.reduce_mean(
+ math_ops.squared_difference(y, shift), axes, keep_dims=True),
+ math_ops.square(shifted_mean),
+ name="variance")
+ mean = math_ops.add(shifted_mean, shift, name="mean")
+ if not keep_dims:
+ mean = array_ops.squeeze(mean, axes)
+ variance = array_ops.squeeze(variance, axes)
+ if x.dtype == dtypes.float16:
+ return (math_ops.cast(mean, dtypes.float16), math_ops.cast(
+ variance, dtypes.float16))
+ else:
+ return (mean, variance)
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index fa015856ce..b8e356c78c 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -136,12 +136,13 @@ def _SparseTensorDenseMatMulGrad(op, grad):
Raises:
TypeError: When the two operands don't have the same type.
"""
- sp_t = sparse_tensor.SparseTensor(*op.inputs[:3])
+ a_indices, a_values, a_shape = op.inputs[:3]
+ b = op.inputs[3]
adj_a = op.get_attr("adjoint_a")
adj_b = op.get_attr("adjoint_b")
- a_type = sp_t.values.dtype.base_dtype
- b_type = op.inputs[3].dtype.base_dtype
+ a_type = a_values.dtype.base_dtype
+ b_type = b.dtype.base_dtype
if a_type != b_type:
raise TypeError("SparseTensorDenseMatMul op received operands with "
"different types: ", a_type, " and ", b_type)
@@ -150,15 +151,12 @@ def _SparseTensorDenseMatMulGrad(op, grad):
"complex gradients.")
# gradient w.r.t. dense
- b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
- adjoint_a=not adj_a)
+ b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( # pylint: disable=protected-access
+ a_indices, a_values, a_shape, grad, adjoint_a=not adj_a)
if adj_b:
b_grad = array_ops.transpose(b_grad)
# gradient w.r.t. sparse values
- a_indices = op.inputs[0]
- b = op.inputs[3]
-
rows = a_indices[:, 0]
cols = a_indices[:, 1]
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 9f4e6607d1..af7abf5251 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -1239,7 +1239,7 @@ def sparse_tensor_dense_matmul(sp_a,
A should be sorted in order of increasing dimension 1 (i.e., "column major"
order instead of "row major" order).
- Deciding when to use sparse_tensor_dense_matmul vs. matmul(sp_a=True):
+ Deciding when to use sparse_tensor_dense_matmul vs. matmul(a_is_sparse=True):
There are a number of questions to ask in the decision process, including:
@@ -1249,14 +1249,14 @@ def sparse_tensor_dense_matmul(sp_a,
If the answer to several of these questions is yes, consider
converting the `SparseTensor` to a dense one and using `tf.matmul` with
- `sp_a=True`.
+ `a_is_sparse=True`.
This operation tends to perform well when A is more sparse, if the column size
of the product is small (e.g. matrix-vector multiplication), if
`sp_a.dense_shape` takes on large values.
Below is a rough speed comparison between sparse_tensor_dense_matmul,
- labelled 'sparse', and matmul(sp_a=True), labelled 'dense'. For purposes of
+ labelled 'sparse', and matmul(a_is_sparse=True), labelled 'dense'. For purposes of
the comparison, the time spent converting from a SparseTensor to a dense
Tensor is not included, so it is overly conservative with respect to
the time ratio.
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 751ccd3d0e..9d3ac4ed9e 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -319,7 +319,7 @@ class StreamExecutorInterface {
// Creates a new DnnSupport object, ownership is transferred to the caller.
// If SupportsDnn() is false, this will always return null.
//
- // If SupportsDnn() is true, this may return null, for example, if the RNG
+ // If SupportsDnn() is true, this may return null, for example, if the DNN
// initialization fails.
virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
diff --git a/tensorflow/tensorboard/DEVELOPMENT.md b/tensorflow/tensorboard/DEVELOPMENT.md
index 0a35dec42f..3ff2c87dab 100644
--- a/tensorflow/tensorboard/DEVELOPMENT.md
+++ b/tensorflow/tensorboard/DEVELOPMENT.md
@@ -21,7 +21,7 @@ Then, cd into the TensorBoard directory:
and install dependencies:
-`npm run prepare`
+`npm run prep`
Then, run gulp: `gulp`
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index 2a0a029ffa..8610940ac3 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -3325,6 +3325,7 @@ var Categorizer;
// if undefined, default value (enable for first k runs, disable after).
type: Object,
value: TF.URIStorage.getObjectInitializer('runSelectionState', {}),
+ observer: "_storeRunToIsCheckedMapping",
},
// (Allows state to persist across regex filtering)
outSelected: {
@@ -3373,24 +3374,7 @@ var Categorizer;
},
observers: [
"_setIsolatorIcon(runSelectionState, names)",
- "_storeRunToIsCheckedMappingWithDefault(runSelectionState, namesMatchingRegex)",
],
- _storeRunToIsCheckedMappingWithDefault() {
- var runSelectionStateIsDefault = Object.keys(this.runSelectionState).length == 0;
- if (runSelectionStateIsDefault || this.namesMatchingRegex == null) {
- return;
- }
- var _this = this;
- var allToggledOn = this.namesMatchingRegex
- .every(function(n) {return _this.runSelectionState[n]});
- var allToggledOff = this.namesMatchingRegex
- .every(function(n) {return !_this.runSelectionState[n]});
- var defaultOff = this.namesMatchingRegex.length > this.maxRunsToEnableByDefault;
- if (defaultOff && allToggledOff || !defaultOff && allToggledOn) {
- this.runSelectionState = {};
- }
- this._storeRunToIsCheckedMapping(this.runSelectionState);
- },
_storeRunToIsCheckedMapping: TF.URIStorage.getObjectObserver('runSelectionState', {}),
_makeRegex: function(regex) {
try {
@@ -27156,4 +27140,4 @@ arguments[4][8][0].apply(exports,arguments)
},{"dup":8}]},{},[35,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34]);
</script>
</dom-module>
-</body></html> \ No newline at end of file
+</body></html>
diff --git a/tensorflow/tensorboard/gulp_tasks/bower.js b/tensorflow/tensorboard/gulp_tasks/bower.js
index 7c0e515c6c..8f4666a8c1 100644
--- a/tensorflow/tensorboard/gulp_tasks/bower.js
+++ b/tensorflow/tensorboard/gulp_tasks/bower.js
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-var gulp = require('gulp');
-var bower = require('gulp-bower');
+const gulp = require('gulp');
+const bower = require('gulp-bower');
module.exports = function() {
return function() {
diff --git a/tensorflow/tensorboard/gulp_tasks/compile.js b/tensorflow/tensorboard/gulp_tasks/compile.js
index 3d0d725cfb..01af60eba7 100644
--- a/tensorflow/tensorboard/gulp_tasks/compile.js
+++ b/tensorflow/tensorboard/gulp_tasks/compile.js
@@ -13,25 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-var gulp = require('gulp');
-var ts = require('gulp-typescript');
-var typescript = require('typescript');
-var gutil = require('gulp-util');
-var filter = require('gulp-filter');
-var merge = require('merge2');
-var browserify = require('browserify');
-var tsify = require('tsify');
-var source = require('vinyl-source-stream');
-var glob = require('glob').sync;
-var concat = require('gulp-concat');
+const gulp = require('gulp');
+const ts = require('gulp-typescript');
+const typescript = require('typescript');
+const gutil = require('gulp-util');
+const filter = require('gulp-filter');
+const merge = require('merge2');
+const browserify = require('browserify');
+const tsify = require('tsify');
+const source = require('vinyl-source-stream');
+const glob = require('glob').sync;
+const concat = require('gulp-concat');
-var tsProject = ts.createProject('./tsconfig.json', {
+const tsProject = ts.createProject('./tsconfig.json', {
typescript: typescript,
- noExternalResolve: true, // opt-in for faster compilation!
+ noExternalResolve: true, // opt-in for faster compilation!
});
/** List of components (and their external deps) that are using es6 modules. */
-var ES6_COMPONENTS = [{
+const ES6_COMPONENTS = [{
name: 'vz_projector',
deps: [
'd3/d3.min.js', 'weblas/dist/weblas.js', 'three.js/build/three.min.js',
@@ -44,8 +44,8 @@ module.exports = function(includeDeps) {
return function() {
// Compile all components that are using ES6 modules into a bundle.js
// using browserify.
- var entries = ['typings/index.d.ts'];
- var deps = {};
+ const entries = ['typings/index.d.ts'];
+ const deps = {};
ES6_COMPONENTS.forEach(function(component) {
// Collect all the typescript files across the components.
entries = entries.concat(glob(
@@ -79,7 +79,7 @@ module.exports = function(includeDeps) {
// Compile components that are using global namespaces producing 1 js file
// for each ts file.
- var isComponent = filter([
+ const isComponent = filter([
'components/tf_*/**/*.ts', 'components/vz_*/**/*.ts', 'typings/**/*.ts',
'components/plottable/plottable.d.ts'
// Ignore components that use es6 modules.
diff --git a/tensorflow/tensorboard/gulp_tasks/test.js b/tensorflow/tensorboard/gulp_tasks/test.js
index ffa8122c7b..0c8b14a4cd 100644
--- a/tensorflow/tensorboard/gulp_tasks/test.js
+++ b/tensorflow/tensorboard/gulp_tasks/test.js
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-var gulp = require('gulp');
-var tester = require('web-component-tester').test;
+const gulp = require('gulp');
+const tester = require('web-component-tester').test;
module.exports = function(done) {
tester({}, function(error) {
diff --git a/tensorflow/tensorboard/gulp_tasks/util.js b/tensorflow/tensorboard/gulp_tasks/util.js
index 7a1d2a58ab..0d73f69c73 100644
--- a/tensorflow/tensorboard/gulp_tasks/util.js
+++ b/tensorflow/tensorboard/gulp_tasks/util.js
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-var fs = require('fs');
-var path = require('path');
+const fs = require('fs');
+const path = require('path');
/**
* Returns a list of web components inside the components directory for which
@@ -34,6 +34,6 @@ exports.getComponents = function(namePredicate) {
* directory.
*/
exports.tbComponents = exports.getComponents(function(name) {
- var prefix = name.slice(0, 3);
+ const prefix = name.slice(0, 3);
return prefix == 'tf_' || prefix == 'vz_';
});
diff --git a/tensorflow/tensorboard/gulp_tasks/vulcanize.js b/tensorflow/tensorboard/gulp_tasks/vulcanize.js
index 89700e1d4c..d2286f1d6c 100644
--- a/tensorflow/tensorboard/gulp_tasks/vulcanize.js
+++ b/tensorflow/tensorboard/gulp_tasks/vulcanize.js
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-var gulp = require('gulp');
-var path = require('path');
-var util = require('./util');
-var vulcanize = require('gulp-vulcanize');
-var replace = require('gulp-replace');
-var rename = require('gulp-rename');
-var header = require('gulp-header');
+const gulp = require('gulp');
+const path = require('path');
+const util = require('./util');
+const vulcanize = require('gulp-vulcanize');
+const replace = require('gulp-replace');
+const rename = require('gulp-rename');
+const header = require('gulp-header');
-var HEADER_STR = '<!-- Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n\
+const HEADER_STR =
+ '<!-- Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n\
\n\
Licensed under the Apache License, Version 2.0 (the "License");\n\
you may not use this file except in compliance with the License.\n\
@@ -40,16 +41,16 @@ This file is generated by `gulp` & `vulcanize`. Do not directly change it.\n\
Instead, use `gulp regenerate` to create a new version with your changes.\n\
-->\n\n'
-var base = path.join(__dirname, '../components');
+const base = path.join(__dirname, '../components');
// List of redirects of the form path1|path2 for every tensorboard component
// in order to replace dashes with underscores.
// E.g. .../tf-tensorboard|.../tf_tensorboard
-var redirects = util.tbComponents.map(function(dir) {
+const redirects = util.tbComponents.map(function(dir) {
return path.join(base, dir.replace(/_/g, '-')) + '|' + path.join(base, dir);
});
-var nonTBComponents = util.getComponents(function(name) {
- var prefix = name.slice(0, 3);
+const nonTBComponents = util.getComponents(function(name) {
+ const prefix = name.slice(0, 3);
return prefix !== 'tf_' && prefix !== 'vz_';
});
@@ -65,7 +66,7 @@ nonTBComponents.push('/tf-imports/plottable.js');
module.exports = function(overwrite) {
return function() {
- var suffix = overwrite ? '' : '.OPENSOURCE';
+ const suffix = overwrite ? '' : '.OPENSOURCE';
// Vulcanize TensorBoard without external libraries.
gulp.src('components/tf_tensorboard/tf-tensorboard.html')
.pipe(vulcanize({
diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js
index 257ee0ab83..c03c4faebc 100644
--- a/tensorflow/tensorboard/gulpfile.js
+++ b/tensorflow/tensorboard/gulpfile.js
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-var gulp = require('gulp');
-var server = require('gulp-server-livereload');
-var minimist = require('minimist');
-var util = require('./gulp_tasks/util');
+const gulp = require('gulp');
+const server = require('gulp-server-livereload');
+const minimist = require('minimist');
+const util = require('./gulp_tasks/util');
-var options = minimist(process.argv.slice(2), {
+const options = minimist(process.argv.slice(2), {
default: {
- p: 8000, // port for gulp server
- h: '0.0.0.0', // host to serve on
+ p: 8000, // port for gulp server
+ h: '0.0.0.0', // host to serve on
}
});
@@ -43,8 +43,8 @@ gulp.task('watch', [], function() {
{ignoreInitial: true}, ['compile']);
});
-var httpPrefix = 'http://' + options.h + ':' + options.p + '/components';
-var proxies = util.tbComponents.map(function(component) {
+const httpPrefix = 'http://' + options.h + ':' + options.p + '/components';
+const proxies = util.tbComponents.map(function(component) {
return {
source: '/components' + component.replace(/_/g, '-'),
target: httpPrefix + component
@@ -84,7 +84,7 @@ gulp.task(
gulp.task('default', ['watch', 'server']);
// Clean all compiled JS files.
-var cleanCompiledTypeScript = require('gulp-clean-compiled-typescript');
+const cleanCompiledTypeScript = require('gulp-clean-compiled-typescript');
gulp.task('clean', function () {
return gulp.src(['./components/**/*.ts', '!./components/**/deps.d.ts'])
.pipe(cleanCompiledTypeScript());
diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json
index 5dcf2f21e9..69f08495a3 100644
--- a/tensorflow/tensorboard/package.json
+++ b/tensorflow/tensorboard/package.json
@@ -4,7 +4,7 @@
"description": "Visualizers for TensorFlow",
"scripts": {
"test": "gulp test",
- "prepare": "npm install && bower install && typings install",
+ "prep": "npm install && bower install && typings install",
"compile": "gulp compile"
},
"keywords": [
diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt
index e7e36e2bb3..5c77b3dd5c 100644
--- a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt
@@ -13,7 +13,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "num_records_produced"
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 7bf7fd5719..bfac54c601 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -71,8 +71,8 @@ ENV BAZEL_VERSION 0.4.5
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
- curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
- curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
chmod +x bazel-*.sh && \
./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
cd / && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 769731974a..7726cbdfbf 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -71,8 +71,8 @@ ENV BAZEL_VERSION 0.4.5
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
- curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
- curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
chmod +x bazel-*.sh && \
./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
cd / && \
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index e20d74fd4a..3ee99d5d31 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -29,7 +29,7 @@ from setuptools.dist import Distribution
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.1.0-rc1'
+_VERSION = '1.1.0-rc2'
REQUIRED_PACKAGES = [
'numpy >= 1.11.0',
diff --git a/third_party/jemalloc.BUILD b/third_party/jemalloc.BUILD
index aabff39d7b..8ed13c51a5 100644
--- a/third_party/jemalloc.BUILD
+++ b/third_party/jemalloc.BUILD
@@ -89,6 +89,14 @@ cc_library(
"-D_REENTRANT",
],
includes = ["include"],
+ # pthread_atfork() is called for PPC.
+ linkopts = select({
+ "@%ws%//tensorflow:linux_ppc64le": [
+ "-lpthread",
+ ],
+ "//conditions:default": [
+ ],
+ }),
visibility = ["//visibility:public"],
)
@@ -183,12 +191,17 @@ sh_binary(
srcs = ["include/jemalloc/internal/size_classes.sh"],
)
-# Size classes for Linux x86_64. Update if adding builds for other
+# Size classes for Linux x86_64 and ppc64le. Update if adding builds for other
# architectures. See size_classes.sh for details on the arguments.
+# For default case, kept the arguments same as that of x86_64 for now.
genrule(
name = "size_classes_h",
outs = ["include/jemalloc/internal/size_classes.h"],
- cmd = "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+ cmd = select({
+ "@%ws%//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@",
+ "@%ws%//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+ "//conditions:default": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+ }),
tools = [":size_classes_sh"],
)
@@ -210,7 +223,13 @@ template_rule(
"#undef JEMALLOC_PREFIX": "#define JEMALLOC_PREFIX \"jemalloc_\"",
"#undef JEMALLOC_CPREFIX": "#define JEMALLOC_CPREFIX \"JEMALLOC_\"",
"#undef JEMALLOC_PRIVATE_NAMESPACE": "#define JEMALLOC_PRIVATE_NAMESPACE je_",
- "#undef CPU_SPINWAIT": "#define CPU_SPINWAIT __asm__ volatile(\"pause\")",
+ "#undef CPU_SPINWAIT": "\n".join([
+ "#if defined(__powerpc64__) || defined(__powerpc__)",
+ "#define CPU_SPINWAIT __asm__ volatile(\"or 27,27,27\")",
+ "#else",
+ "#define CPU_SPINWAIT __asm__ volatile(\"pause\")",
+ "#endif",
+ ]),
"#undef JEMALLOC_HAVE_BUILTIN_CLZ": "#define JEMALLOC_HAVE_BUILTIN_CLZ",
"#undef JEMALLOC_USE_SYSCALL": "#define JEMALLOC_USE_SYSCALL",
"#undef JEMALLOC_HAVE_SECURE_GETENV": "#define JEMALLOC_HAVE_SECURE_GETENV",
@@ -226,7 +245,13 @@ template_rule(
"#undef JEMALLOC_DSS": "#define JEMALLOC_DSS",
"#undef JEMALLOC_FILL": "#define JEMALLOC_FILL",
"#undef LG_TINY_MIN": "#define LG_TINY_MIN 3",
- "#undef LG_PAGE": "#define LG_PAGE 12",
+ "#undef LG_PAGE": "\n".join([
+ "#if defined(__powerpc64__) || defined(__powerpc__)",
+ "#define LG_PAGE 16",
+ "#else",
+ "#define LG_PAGE 12",
+ "#endif",
+ ]),
"#undef JEMALLOC_MAPS_COALESCE": "#define JEMALLOC_MAPS_COALESCE",
"#undef JEMALLOC_TLS": "#define JEMALLOC_TLS",
"#undef JEMALLOC_INTERNAL_UNREACHABLE": "#define JEMALLOC_INTERNAL_UNREACHABLE __builtin_unreachable",
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 819807220b..d5ab326283 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -152,6 +152,11 @@ all_cmake_vars = select({
cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") +
darwin_cmake_vars,
),
+ "@%ws%//tensorflow:linux_ppc64le": cmake_var_string(
+ cmake_vars +
+ llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") +
+ linux_cmake_vars,
+ ),
"//conditions:default": cmake_var_string(
cmake_vars +
llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +