aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md10
-rw-r--r--tensorflow/BUILD2
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc21
-rw-r--r--tensorflow/compiler/aot/BUILD35
-rw-r--r--tensorflow/compiler/aot/codegen.cc55
-rw-r--r--tensorflow/compiler/aot/codegen.h8
-rw-r--r--tensorflow/compiler/aot/codegen_test.cc42
-rw-r--r--tensorflow/compiler/aot/compile.cc340
-rw-r--r--tensorflow/compiler/aot/compile.h46
-rw-r--r--tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/BUILD20
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt2
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl2
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc26
-rw-r--r--tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt2
-rw-r--r--tensorflow/compiler/tf2xla/BUILD79
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc370
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.h43
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.proto (renamed from tensorflow/compiler/aot/tfcompile.proto)26
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_test.cc99
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc (renamed from tensorflow/compiler/aot/tfcompile_util.cc)74
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h (renamed from tensorflow/compiler/aot/tfcompile_util.h)23
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util_test.cc (renamed from tensorflow/compiler/aot/tfcompile_util_test.cc)89
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc1
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h14
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD23
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc68
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc20
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h12
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc30
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h2
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc39
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java33
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py2
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake3
-rw-r--r--tensorflow/contrib/data/__init__.py2
-rw-r--r--tensorflow/contrib/data/python/framework/BUILD48
-rw-r--r--tensorflow/contrib/data/python/framework/function.py275
-rw-r--r--tensorflow/contrib/data/python/framework/function_test.py59
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py49
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py75
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py20
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD4
-rw-r--r--tensorflow/contrib/data/python/ops/dataset_ops.py166
-rw-r--r--tensorflow/contrib/data/python/ops/sloppy_ops.py2
-rw-r--r--tensorflow/contrib/eager/python/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/tfe_test.py10
-rw-r--r--tensorflow/contrib/estimator/BUILD66
-rw-r--r--tensorflow/contrib/estimator/__init__.py3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn.py134
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_test.py153
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py218
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py570
-rw-r--r--tensorflow/contrib/factorization/BUILD1
-rw-r--r--tensorflow/contrib/gan/BUILD113
-rw-r--r--tensorflow/contrib/gan/__init__.py2
-rw-r--r--tensorflow/contrib/gan/python/eval/__init__.py39
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics.py28
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py401
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py316
-rw-r--r--tensorflow/contrib/gan/python/eval/python/eval_utils.py28
-rw-r--r--tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py134
-rw-r--r--tensorflow/contrib/gan/python/eval/python/eval_utils_test.py48
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries.py28
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_impl.py157
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_test.py96
-rwxr-xr-xtensorflow/contrib/image/BUILD75
-rwxr-xr-xtensorflow/contrib/image/__init__.py9
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc172
-rw-r--r--tensorflow/contrib/image/ops/distort_image_ops.cc60
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py338
-rw-r--r--tensorflow/contrib/image/python/ops/distort_image_ops.py138
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py6
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc2
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc3
-rw-r--r--tensorflow/core/common_runtime/executor.cc44
-rw-r--r--tensorflow/core/kernels/BUILD6
-rw-r--r--tensorflow/core/kernels/cholesky_op.cc23
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.cc167
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.h53
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc80
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt35
-rw-r--r--tensorflow/core/ops/linalg_ops.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt5
-rw-r--r--tensorflow/core/profiler/g3doc/command_line.md2
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md18
-rw-r--r--tensorflow/docs_src/install/install_linux.md22
-rw-r--r--tensorflow/docs_src/install/install_mac.md10
-rw-r--r--tensorflow/docs_src/install/install_sources.md4
-rw-r--r--tensorflow/docs_src/performance/xla/tfcompile.md2
-rw-r--r--tensorflow/go/op/wrappers.go846
-rw-r--r--tensorflow/python/BUILD3
-rw-r--r--tensorflow/python/debug/BUILD1
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli.py2
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py196
-rw-r--r--tensorflow/python/eager/backprop.py151
-rw-r--r--tensorflow/python/eager/context.py14
-rw-r--r--tensorflow/python/eager/core.py4
-rw-r--r--tensorflow/python/eager/custom_gradient.py4
-rw-r--r--tensorflow/python/eager/execute.py5
-rw-r--r--tensorflow/python/eager/memory_trace.py11
-rw-r--r--tensorflow/python/eager/tape.py81
-rw-r--r--tensorflow/python/eager/tape_test.py30
-rw-r--r--tensorflow/python/framework/function.py44
-rw-r--r--tensorflow/python/framework/function_test.py31
-rw-r--r--tensorflow/python/framework/test_util.py27
-rw-r--r--tensorflow/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/matrix_band_part_op_test.py120
-rw-r--r--tensorflow/python/kernel_tests/svd_op_test.py4
-rw-r--r--tensorflow/python/ops/nn_impl.py6
-rw-r--r--tensorflow/python/saved_model/signature_def_utils.py14
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py80
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py68
-rw-r--r--tensorflow/python/util/tf_inspect.py4
-rw-r--r--tensorflow/tensorflow.bzl1
-rw-r--r--tensorflow/tools/api/generator/BUILD53
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py179
-rw-r--r--tensorflow/tools/api/generator/create_python_api_test.py86
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rw-r--r--tensorflow/tools/pip_package/setup.py6
141 files changed, 6458 insertions, 1866 deletions
diff --git a/README.md b/README.md
index f4c3f33da2..8e8a5301e8 100644
--- a/README.md
+++ b/README.md
@@ -45,11 +45,11 @@ GPU packages on all platforms will arrive soon!
**Individual whl files**
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.4.0dev-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.4.0dev-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.4.0dev-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.4.0dev-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.4.0dev-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.4.0dev-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 33609c26e7..1b859dbda1 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -283,7 +283,6 @@ filegroup(
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/data:all_files",
- "//tensorflow/contrib/data/python/framework:all_files",
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/data/python/util:all_files",
@@ -418,6 +417,7 @@ filegroup(
"//tensorflow/python/profiler/internal:all_files",
"//tensorflow/python/saved_model:all_files",
"//tensorflow/python/tools:all_files",
+ "//tensorflow/tools/api/generator:all_files",
"//tensorflow/tools/api/golden:all_files",
"//tensorflow/tools/api/lib:all_files",
"//tensorflow/tools/api/tests:all_files",
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index c12f43cda2..fcc3fc9dae 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -127,10 +127,11 @@ Status Conv2DGrad(const Scope& scope, const Operation& op,
std::vector<int32> strides;
bool use_cudnn_on_gpu;
auto attrs = op.output(0).node()->attrs();
- GetNodeAttr(attrs, "data_format", &data_format);
- GetNodeAttr(attrs, "padding", &padding);
- GetNodeAttr(attrs, "strides", &strides);
- GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu);
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu",
+ &use_cudnn_on_gpu));
Conv2DBackpropInput::Attrs input_attrs;
input_attrs.DataFormat(data_format);
input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
@@ -157,10 +158,10 @@ Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
std::vector<int32> strides;
std::vector<int32> ksize;
auto attrs = op.output(0).node()->attrs();
- GetNodeAttr(attrs, "data_format", &data_format);
- GetNodeAttr(attrs, "ksize", &ksize);
- GetNodeAttr(attrs, "padding", &padding);
- GetNodeAttr(attrs, "strides", &strides);
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
internal::MaxPoolGrad::Attrs grad_attrs;
grad_attrs.DataFormat(data_format);
auto dx = internal::MaxPoolGrad(scope, op.input(0),
@@ -179,8 +180,8 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
string data_format;
string padding;
auto attrs = op.output(0).node()->attrs();
- GetNodeAttr(attrs, "data_format", &data_format);
- GetNodeAttr(attrs, "padding", &padding);
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
MaxPoolGradV2::Attrs grad_attrs;
grad_attrs.DataFormat(data_format);
auto dx = MaxPoolGradV2(scope, op.input(0),
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 9909e88e64..29dbe4a08b 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -4,7 +4,6 @@ package(
default_visibility = ["//visibility:private"],
)
-load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# Optional runtime utilities for use by code generated by tfcompile.
@@ -39,32 +38,24 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
-xla_proto_library(
- name = "tfcompile_proto",
- srcs = ["tfcompile.proto"],
- deps = [
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
cc_library(
name = "tfcompile_lib",
srcs = [
"codegen.cc",
"compile.cc",
"flags.cc",
- "tfcompile_util.cc",
],
hdrs = [
"codegen.h",
"compile.h",
"flags.h",
- "tfcompile_util.h",
],
deps = [
":runtime", # needed by codegen to print aligned_buffer_bytes
- ":tfcompile_proto",
+ "//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:tf2xla_proto",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -82,7 +73,6 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -99,18 +89,6 @@ cc_test(
],
)
-cc_test(
- name = "tfcompile_util_test",
- srcs = ["tfcompile_util_test.cc"],
- deps = [
- ":tfcompile_lib",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
cc_binary(
name = "tfcompile",
visibility = ["//visibility:public"],
@@ -123,7 +101,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tfcompile_lib",
- ":tfcompile_proto",
+ "//tensorflow/compiler/tf2xla:tf2xla_proto",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
@@ -226,7 +205,11 @@ test_suite(
tags = ["manual"],
tests = [
":benchmark_test",
+ ":codegen_test",
+ ":runtime_test",
":test_graph_tfadd_test",
+ ":test_graph_tfunknownop2_test",
+ ":test_graph_tfunknownop3_test",
":test_graph_tfunknownop_test",
"//tensorflow/compiler/aot/tests:all_tests",
],
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index bbdb342a62..fc5c6ce58d 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/runtime.h"
-#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,6 +35,12 @@ namespace tfcompile {
namespace {
+bool IsAlpha(char c) {
+ return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
+}
+
+bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
+
// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
switch (type) {
@@ -156,7 +162,7 @@ string RewriteWithName(const string& name, string code,
}
// Generate methods for args (inputs).
-Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
const CompileResult& compile_result, string* methods) {
*methods += R"(
void** args() { return args_; }
@@ -204,8 +210,8 @@ Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
}
// Generate methods for results (outputs).
-Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
- string* methods) {
+Status GenResultMethods(const tf2xla::Config& config,
+ const xla::ProgramShape& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// Non-tuple (i.e. single-result) case.
if (config.fetch_size() != 1) {
@@ -285,11 +291,26 @@ Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
return Status::OK();
}
+Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
+ for (const tf2xla::Feed& feed : config.feed()) {
+ if (!feed.name().empty()) {
+ TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
+ }
+ }
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
+ if (!fetch.name().empty()) {
+ TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
+ }
+ }
+ return Status::OK();
+}
+
} // namespace
-Status GenerateHeader(const HeaderOpts& opts, const Config& config,
+Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const CompileResult& compile_result, string* header) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
+ TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
if (result_index < 0 || result_index > temp_sizes.size()) {
@@ -574,5 +595,29 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
return Status::OK();
}
+Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+ if (ident.empty()) {
+ return errors::InvalidArgument("empty identifier: ", msg);
+ }
+ // Require that the identifier starts with a nondigit, and is composed of
+ // nondigits and digits, as specified in section [2.11 Identifiers] of the
+ // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
+ // defined as [0-9].
+ //
+ // Technically the standard also allows for `universal-character-name`, with a
+ // table of allowed unicode ranges, as well as `other implementation-defined
+ // characters`. We disallow those here to give better error messages, at the
+ // expensive of being more restrictive than the standard.
+ if (ident[0] != '_' && !IsAlpha(ident[0])) {
+ return errors::InvalidArgument("illegal leading char: ", msg);
+ }
+ for (size_t pos = 1; pos < ident.size(); ++pos) {
+ if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
+ return errors::InvalidArgument("illegal char: ", msg);
+ }
+ }
+ return Status::OK();
+}
+
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 7217c57739..740edd1e83 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -20,6 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/compile.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
@@ -37,7 +39,7 @@ struct HeaderOpts {
// GenerateHeader uses the meta-information from compile_result to generate a
// C++ header giving access to the function in the generated object file. The
// header includes API usage documentation.
-Status GenerateHeader(const HeaderOpts& opts, const Config& config,
+Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const CompileResult& compile_result, string* header);
// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
@@ -47,6 +49,10 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status ParseCppClass(const string& cpp_class, string* class_name,
std::vector<string>* namespaces);
+// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
+// appended to error messages.
+Status ValidateCppIdent(StringPiece ident, StringPiece msg);
+
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index e3f76f3666..98cbd67e53 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -29,6 +30,41 @@ namespace tensorflow {
namespace tfcompile {
namespace {
+void ExpectErrorContains(const Status& status, StringPiece str) {
+ EXPECT_NE(Status::OK(), status);
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
+ << "expected error: " << status.error_message() << " to contain: " << str;
+}
+
+TEST(ValidateCppIdent, Simple) {
+ TF_EXPECT_OK(ValidateCppIdent("a", ""));
+ TF_EXPECT_OK(ValidateCppIdent("abc", ""));
+ TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
+ TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
+ // Make sure we didn't skip a valid letter or digit
+ string ident;
+ for (char c = 'a'; c <= 'z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = 'A'; c <= 'Z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = '0'; c <= '9'; c++) {
+ ident.append(1, c);
+ }
+ ident += "_";
+ TF_EXPECT_OK(ValidateCppIdent(ident, ""));
+
+ ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
+ ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
+ ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
+ ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
+}
+
class ParseCppClassTest : public ::testing::Test {
protected:
void ExpectOK(const string& cpp_class, const string& want_class_name,
@@ -91,13 +127,13 @@ TEST(GenerateHeader, Golden) {
HeaderOpts opts;
opts.class_name = "MyClass";
opts.namespaces = {"foo", "bar"};
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("feed0");
feed->set_name("myfeed");
feed = config.add_feed();
feed->mutable_id()->set_node_name("feed1");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index a485d2e555..eac8da0ab1 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -15,326 +15,32 @@ limitations under the License.
#include "tensorflow/compiler/aot/compile.h"
-#include <map>
#include <memory>
#include <string>
-#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/aot/tfcompile_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
-#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
-#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/versions.pb.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace tfcompile {
-const char* const kArgOp = "_Arg";
-const char* const kRetvalOp = "_Retval";
-const char* const kFeedIdAttr = "_feed_id";
-const char* const kFetchIdAttr = "_fetch_id";
-const char* const kShapeAttr = "_shape";
-const char* const kDebugNameAttr = "_debug_name";
-
namespace {
-Status DumpGraph(const MainFlags& flags, const string& name,
- const Graph& graph) {
- if (flags.debug_dir.empty()) {
- return Status::OK();
- }
- GraphDef graph_def;
- graph.ToGraphDef(&graph_def);
- string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
- return WriteTextProto(Env::Default(), file, graph_def);
-}
-
-typedef std::unordered_map<string, Node*> NodeMap;
-
-// Each feed id identifies the positional output of some node, which may consist
-// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
-// tensor with a placeholder. For each feed tensor, replaces all edges so they
-// point from a new _Arg node instead.
-Status AddArgNodes(Graph* graph, const NodeMap& node_map,
- const protobuf::RepeatedPtrField<Feed>& feeds,
- const std::unordered_map<string, string>& feed_remapping) {
- for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
- const Feed& feed = feeds[arg_index];
- // All feeds have been replaced by placeholders.
- const int output_index = 0;
-
- const string key = TensorIdToString(feed.id());
- const auto remap_it = feed_remapping.find(key);
- auto node_it = node_map.find(remap_it->second);
- if (node_it == node_map.end()) {
- // Strip off the aot_feed_#/ prefix.
- StringPiece name(remap_it->second);
- const auto index = name.find('/');
- if (index > 0) name.remove_prefix(index + 1);
- return errors::InvalidArgument(
- "Node is fed but not needed for fetching: ", name);
- }
- const Node* feed_node = node_it->second;
-
- // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
- // "_shape" attr if we can determine it. That way the graph will be
- // initialized with whatever shapes we can infer, while the user can still
- // explicitly specify or override them.
- Node* arg_node = nullptr;
- TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
- .Attr("T", BaseType(feed_node->output_type(output_index)))
- .Attr("index", arg_index)
- .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
- .Attr(kShapeAttr, TensorShape(feed.shape()))
- .Attr(kDebugNameAttr, feed.name())
- .Finalize(graph, &arg_node));
-
- // Collects out-edges from the feed node that have a matching edge index;
- // these will be replaced with edges from the arg node instead.
- //
- // We must collect the edges first and process them in a second pass, since
- // removing the edge from the graph invalidates feed_node->out_edges.
- std::vector<const Edge*> feed_edges;
- for (const Edge* edge : feed_node->out_edges()) {
- if (edge->src_output() == output_index) {
- feed_edges.push_back(edge);
- }
- }
- for (const Edge* edge : feed_edges) {
- graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
- graph->RemoveEdge(edge);
- }
- }
- return Status::OK();
-}
-
-// Each fetch id identifies the positional output of some node. For each fetch
-// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
-Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
- const protobuf::RepeatedPtrField<Fetch>& fetches,
- std::unordered_set<const Node*>* retval_nodes) {
- for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
- const TensorId& id = fetches[ret_index].id();
- auto it = node_map.find(id.node_name());
- if (it == node_map.end()) {
- return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
- }
- Node* fetch_node = it->second;
- if (id.output_index() >= fetch_node->num_outputs()) {
- return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
- ", output index should be < ",
- fetch_node->num_outputs());
- }
- // Connects fetch_node -> retval_node.
- Node* retval_node = nullptr;
- TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
- .Input(fetch_node, id.output_index())
- .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
- .Attr("index", ret_index)
- .Attr(kFetchIdAttr, TensorIdToString(id))
- .Finalize(graph, &retval_node));
- retval_nodes->insert(retval_node);
- }
- return Status::OK();
-}
-
-// RewriteAndPruneGraph identifies input and output edges (named by the feed and
-// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
-// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
-// execution to know the input and output args for the generated function.
-Status RewriteAndPruneGraph(
- Graph* graph, const Config& config,
- const std::unordered_map<string, string>& feed_remapping,
- const MainFlags& flags) {
- NodeMap node_map;
- for (Node* n : graph->nodes()) {
- node_map[n->name()] = n;
- }
- TF_RETURN_IF_ERROR(
- AddArgNodes(graph, node_map, config.feed(), feed_remapping));
- std::unordered_set<const Node*> retval_nodes;
- TF_RETURN_IF_ERROR(
- AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
- TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
- PruneForReverseReachability(graph, retval_nodes);
- FixupSourceAndSinkEdges(graph);
- TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
- // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
- std::set<string> missing_feeds, missing_fetches;
- for (const Feed& feed : config.feed()) {
- missing_feeds.insert(TensorIdToString(feed.id()));
- }
- for (const Fetch& fetch : config.fetch()) {
- missing_fetches.insert(TensorIdToString(fetch.id()));
- }
- for (const Node* n : graph->op_nodes()) {
- if (n->type_string() == kArgOp) {
- string feed_id;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
- if (missing_feeds.erase(feed_id) == 0) {
- return errors::Aborted(kArgOp,
- " node found with unknown feed id: ", feed_id);
- }
- } else if (n->type_string() == kRetvalOp) {
- string fetch_id;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
- if (missing_fetches.erase(fetch_id) == 0) {
- return errors::Aborted(kRetvalOp,
- " node found with unknown fetch id: ", fetch_id);
- }
- }
- }
- if (!missing_feeds.empty() || !missing_fetches.empty()) {
- return errors::Aborted(
- "Post graph-pruning",
- ", missing feeds: ", str_util::Join(missing_feeds, ", "),
- ", missing fetches: ", str_util::Join(missing_fetches, ", "));
- }
- return Status::OK();
-}
-
-// CollectArgNodes collects _Arg nodes from the graph, and performs basic
-// sanity-checking to ensure the index and type attributes of each node are
-// initialized correctly.
-Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
- std::map<int, Node*> indexed_arg_nodes;
- for (Node* n : graph.nodes()) {
- if (n->type_string() == kArgOp) {
- int index;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
- auto insert_result = indexed_arg_nodes.insert({index, n});
- if (!insert_result.second) {
- const Node* dup = insert_result.first->second;
- return errors::InvalidArgument(
- "Multiple ", kArgOp, " nodes with index ", index, ", ",
- n->DebugString(), " and ", dup->DebugString());
- }
- }
- }
- arg_nodes->clear();
- for (const auto& index_node : indexed_arg_nodes) {
- if (index_node.first != arg_nodes->size()) {
- return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
- arg_nodes->size(), ", but got index ",
- index_node.first);
- }
- arg_nodes->push_back(index_node.second);
- }
- return Status::OK();
-}
-
-// Fills in xla_args from the corresponding _Arg nodes in the graph.
-Status CreateXlaArgs(const Graph& graph,
- std::vector<XlaCompiler::Argument>* xla_args) {
- std::vector<Node*> arg_nodes;
- TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
- for (const Node* node : arg_nodes) {
- XlaCompiler::Argument arg;
- arg.kind = XlaCompiler::Argument::kParameter;
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
- TensorShape shape;
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
- TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
- xla_args->push_back(arg);
- }
- return Status::OK();
-}
-
-// Converts the TensorFlow graph into an XLA computation, by executing the
-// graph symbolically, with each op building up the XLA HLO.
-Status ConvertGraphToXla(xla::CompileOnlyClient* client,
- std::unique_ptr<Graph> graph,
- xla::Computation* computation, bool* has_context_arg) {
- // Create a device and context to convert the graph into an XLA computation.
- XlaOpRegistry::RegisterCompilationKernels();
- // Populate the context with args from the graph.
- for (Node* node : graph->nodes()) {
- node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
- }
- std::vector<XlaCompiler::Argument> xla_args;
- TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
-
- // Compile the graph into an XLA computation.
- XlaCompiler::Options compiler_options;
- compiler_options.client = client;
- DeviceType device_type(DEVICE_CPU_XLA_JIT);
- compiler_options.device_type = &device_type;
- compiler_options.flib_def = &graph->flib_def();
- compiler_options.graph_def_version = graph->versions().producer();
- compiler_options.allow_cpu_custom_calls = true;
- XlaCompiler compiler(compiler_options);
-
- XlaCompiler::CompilationResult result;
- TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
- "tfcompile", std::move(graph),
- xla_args, &result));
- *has_context_arg = result.requires_runtime_context;
- *computation = std::move(*result.computation);
-
- int num_const_results = 0;
- for (int i = 0; i < result.outputs.size(); ++i) {
- // Ending up with const results (i.e. output args) is an error, since it
- // means that one or more fetches that the user specified will be dropped
- // from the generated function. It's most likely a configuration error,
- // since the user shouldn't be asking for output args that end up as consts.
- //
- // TODO(toddw): Provide a way for the user to access const output args,
- // e.g. perhaps hard-coded into the header, or somehow copied into the
- // output buffers.
- if (result.outputs[i].is_constant) {
- ++num_const_results;
- LOG(ERROR) << "ConstRetVal index:" << i
- << " value:" << result.outputs[i].constant_value.DebugString();
- }
- }
- if (num_const_results > 0) {
- return errors::Unimplemented(
- "Conversion from TensorFlow graph to XLA resulted in ",
- num_const_results,
- " constant results. The configuration of "
- "the output args (i.e. fetch ids) is probably wrong.");
- }
- if (computation->IsNull()) {
- return errors::Aborted(
- "Conversion from TensorFlow graph to XLA resulted in an empty "
- "computation.");
- }
- return Status::OK();
-}
-
// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
const xla::Computation& computation,
@@ -376,41 +82,8 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace
-Status InitGraph(const GraphDef& graph_def, const Config& config,
- const MainFlags& flags, std::unique_ptr<Graph>* graph) {
- TF_RETURN_IF_ERROR(ValidateConfig(config));
-
- FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
- std::unique_ptr<Graph> g(new Graph(flib_def));
-
- // Replace references to fed tensors with references to newly added
- // placeholders.
- GraphDef first_copy_def = graph_def;
-
- // Maps from name:port of a feed to the name:port of the placeholder to use.
- std::unordered_map<string, string> feed_remapping;
- TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
- &feed_remapping, &first_copy_def));
-
- // Prune the GraphDef first so that unknown ops that we aren't compiling get
- // filtered out.
- GraphDef second_copy_def;
- TF_RETURN_IF_ERROR(
- PruneGraphDefInto(config, first_copy_def, &second_copy_def));
-
- TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
- &second_copy_def, *g->op_registry(), 0 /*node_offset*/));
-
- TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
- second_copy_def, g.get()));
- TF_RETURN_IF_ERROR(
- RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
- *graph = std::move(g);
- return Status::OK();
-}
-
-Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
- CompileResult* compile_result) {
+Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+ const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
@@ -421,8 +94,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
xla::Computation computation;
- TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
- &compile_result->has_context_arg));
+ TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
+ &computation,
+ &compile_result->has_context_arg));
if (!flags.debug_dir.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
computation.Snapshot());
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index e929272b2e..965c296081 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -18,46 +18,16 @@ limitations under the License.
#include <memory>
#include <string>
-#include <vector>
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
-#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
namespace tfcompile {
-// Constants for op types and attribute names.
-extern const char* const kArgOp;
-extern const char* const kRetvalOp;
-extern const char* const kFeedIdAttr;
-extern const char* const kFetchIdAttr;
-extern const char* const kShapeAttr;
-extern const char* const kDebugNameAttr;
-
-// InitGraph creates a graph based on the graph_def, that may then be compiled
-// by CompileGraph.
-//
-// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
-// and outputs of the function that will be compiled. Each feed id causes a new
-// _Arg node to be created, where we first collect all existing edges pointing
-// from the named node's output index, and then rewrite them to point from that
-// _Arg node instead. Each fetch id causes a new _Retval node to be created,
-// with a new edge pointing from the named node's output index to that _Retval
-// node. All _Retval nodes also point to a special CompileExpressions node,
-// used internally to finish the compilation.
-//
-// The rewritten graph is then pruned to only contain the portion necessary to
-// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
-// for debugging.
-Status InitGraph(const GraphDef& graph_def, const Config& config,
- const MainFlags& flags, std::unique_ptr<Graph>* graph);
-
// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
struct CompileResult {
@@ -69,20 +39,12 @@ struct CompileResult {
int pointer_size = 0; // Size of a pointer in bytes.
};
-// CompileGraph compiles the graph into an object file containing a function
+// CompileGraph compiles the graph_def into an object file containing a function
// that performs the graph operations.
//
-// The graph must have _Arg and _Retval nodes representing the function inputs
-// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
-// value=TensorShape) representing the static shape of that input, and every
-// _Retval node must point to a CompileExpressions node.
-//
-// Typically InitGraph is called to perform this initialization, followed by
-// full specification of the shape attributes.
-//
// The XLA compilation options are specified in the flags.
-Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
- CompileResult* result);
+Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+ const MainFlags& flags, CompileResult* compile_result);
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
index f2d9c34b2d..a4ad334352 100644
--- a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
index 5625c0ab03..d3f0e4990c 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
index 7370ed370d..e0b012adea 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
index b2d7d54574..662ba1c321 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 05d338e4c5..4d65a044bc 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -13,9 +13,11 @@ test_suite(
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
+ ":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
+ ":test_graph_tfsplits_test",
":tfcompile_test",
],
)
@@ -91,6 +93,15 @@ tf_library(
)
tf_library(
+ name = "test_graph_tffunction",
+ testonly = 1,
+ config = "test_graph_tffunction.config.pbtxt",
+ cpp_class = "FunctionComp",
+ graph = "test_graph_tffunction.pb",
+ tags = ["manual"],
+)
+
+tf_library(
name = "test_graph_tfgather",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
@@ -118,15 +129,6 @@ tf_library(
)
tf_library(
- name = "test_graph_tffunction",
- testonly = 1,
- config = "test_graph_tffunction.config.pbtxt",
- cpp_class = "FunctionComp",
- graph = "test_graph_tffunction.pb",
- tags = ["manual"],
-)
-
-tf_library(
name = "test_graph_tfsplits",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
index 5625c0ab03..d3f0e4990c 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
index 4d876a6e91..8adc9cdc14 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
index eb9c1cacb7..cbfe458908 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
index 648ee31fdb..89ed678a9c 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "params" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
index a3ce2029c1..2acd0289c2 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
index 4a4a237a4f..e5ca6115e9 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt
index 85fc7da442..5adc77336c 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index f9896988dc..fc1342d84e 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -41,7 +41,7 @@ def tf_library(name, graph, config,
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
is expected to be in the human-readable proto text format, otherwise it is
expected to be in the proto binary format.
- config: File containing tensorflow.tfcompile.Config proto. If the file ends
+ config: File containing tensorflow.tf2xla.Config proto. If the file ends
in '.pbtxt' it is expected to be in the human-readable proto text format,
otherwise it is expected to be in the proto binary format.
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index be2cfe4734..cc499c3284 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
-#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
@@ -54,8 +54,7 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
-Status ReadProtoFile(const string& kind, const string& fname,
- protobuf::Message* proto) {
+Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (StringPiece(fname).ends_with(".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
@@ -63,23 +62,17 @@ Status ReadProtoFile(const string& kind, const string& fname,
}
}
-void ParseTensorId(const string& name, TensorId* id) {
- const std::pair<StringPiece, int> name_index = ParseTensorName(name);
- id->set_node_name(name_index.first.ToString());
- id->set_output_index(name_index.second);
-}
-
Status Main(const MainFlags& flags) {
// Process config.
- Config config;
+ tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
- TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
+ TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
- for (const Fetch& fetch : config.fetch()) {
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << str_util::Join(nodes, ",");
@@ -91,12 +84,9 @@ Status Main(const MainFlags& flags) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
- TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
- std::unique_ptr<Graph> graph;
- TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph));
-
+ TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
- TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result));
+ TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
diff --git a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
index c46e65f71a..3025fc27b1 100644
--- a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
+++ b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 13fc233054..22f2441a68 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -21,6 +21,40 @@ package(
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+xla_proto_library(
+ name = "tf2xla_proto",
+ srcs = ["tf2xla.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "tf2xla",
+ srcs = ["tf2xla.cc"],
+ hdrs = ["tf2xla.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":common",
+ ":dump_graph",
+ ":tf2xla_proto",
+ ":tf2xla_util",
+ ":xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
cc_library(
name = "xla_compiler",
@@ -96,6 +130,51 @@ cc_library(
# Internal targets below this point.
+cc_library(
+ name = "tf2xla_util",
+ srcs = ["tf2xla_util.cc"],
+ hdrs = ["tf2xla_util.h"],
+ deps = [
+ ":tf2xla_proto",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "tf2xla_util_test",
+ srcs = ["tf2xla_util_test.cc"],
+ deps = [
+ ":tf2xla_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "tf2xla_test",
+ srcs = ["tf2xla_test.cc"],
+ deps = [
+ ":tf2xla",
+ ":tf2xla_proto",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
new file mode 100644
index 0000000000..b29c92190d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -0,0 +1,370 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+const char* const kArgOp = "_Arg";
+const char* const kRetvalOp = "_Retval";
+const char* const kFeedIdAttr = "_feed_id";
+const char* const kFetchIdAttr = "_fetch_id";
+const char* const kShapeAttr = "_shape";
+const char* const kDebugNameAttr = "_debug_name";
+
+namespace {
+
+typedef std::unordered_map<string, Node*> NodeMap;
+
+// Each feed id identifies the positional output of some node, which may consist
+// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
+// tensor with a placeholder. For each feed tensor, replaces all edges so they
+// point from a new _Arg node instead.
+Status AddArgNodes(Graph* graph, const NodeMap& node_map,
+ const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
+ const std::unordered_map<string, string>& feed_remapping) {
+ for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
+ const tf2xla::Feed& feed = feeds[arg_index];
+ // All feeds have been replaced by placeholders.
+ const int output_index = 0;
+
+ const string key = TensorIdToString(feed.id());
+ const auto remap_it = feed_remapping.find(key);
+ auto node_it = node_map.find(remap_it->second);
+ if (node_it == node_map.end()) {
+ // Strip off the aot_feed_#/ prefix.
+ StringPiece name(remap_it->second);
+ const auto index = name.find('/');
+ if (index > 0) name.remove_prefix(index + 1);
+ return errors::InvalidArgument(
+ "Node is fed but not needed for fetching: ", name);
+ }
+ const Node* feed_node = node_it->second;
+
+ // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
+ // "_shape" attr if we can determine it. That way the graph will be
+ // initialized with whatever shapes we can infer, while the user can still
+ // explicitly specify or override them.
+ Node* arg_node = nullptr;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
+ .Attr("T", BaseType(feed_node->output_type(output_index)))
+ .Attr("index", arg_index)
+ .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
+ .Attr(kShapeAttr, TensorShape(feed.shape()))
+ .Attr(kDebugNameAttr, feed.name())
+ .Finalize(graph, &arg_node));
+
+ // Collects out-edges from the feed node that have a matching edge index;
+ // these will be replaced with edges from the arg node instead.
+ //
+ // We must collect the edges first and process them in a second pass, since
+ // removing the edge from the graph invalidates feed_node->out_edges.
+ std::vector<const Edge*> feed_edges;
+ for (const Edge* edge : feed_node->out_edges()) {
+ if (edge->src_output() == output_index) {
+ feed_edges.push_back(edge);
+ }
+ }
+ for (const Edge* edge : feed_edges) {
+ graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
+ }
+ return Status::OK();
+}
+
+// Each fetch id identifies the positional output of some node. For each fetch
+// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
+Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
+ const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
+ std::unordered_set<const Node*>* retval_nodes) {
+ for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
+ const tf2xla::TensorId& id = fetches[ret_index].id();
+ auto it = node_map.find(id.node_name());
+ if (it == node_map.end()) {
+ return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
+ }
+ Node* fetch_node = it->second;
+ if (id.output_index() >= fetch_node->num_outputs()) {
+ return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
+ ", output index should be < ",
+ fetch_node->num_outputs());
+ }
+ // Connects fetch_node -> retval_node.
+ Node* retval_node = nullptr;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
+ .Input(fetch_node, id.output_index())
+ .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
+ .Attr("index", ret_index)
+ .Attr(kFetchIdAttr, TensorIdToString(id))
+ .Finalize(graph, &retval_node));
+ retval_nodes->insert(retval_node);
+ }
+ return Status::OK();
+}
+
+// RewriteAndPruneGraph identifies input and output edges (named by the feed and
+// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
+// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
+// execution to know the input and output args for the generated function.
+Status RewriteAndPruneGraph(
+ Graph* graph, const tf2xla::Config& config,
+ const std::unordered_map<string, string>& feed_remapping) {
+ NodeMap node_map;
+ for (Node* n : graph->nodes()) {
+ node_map[n->name()] = n;
+ }
+ TF_RETURN_IF_ERROR(
+ AddArgNodes(graph, node_map, config.feed(), feed_remapping));
+ std::unordered_set<const Node*> retval_nodes;
+ TF_RETURN_IF_ERROR(
+ AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
+ VLOG(2) << "Post rewrite: "
+ << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph);
+ PruneForReverseReachability(graph, retval_nodes);
+ FixupSourceAndSinkEdges(graph);
+ VLOG(2) << "Post prune: "
+ << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph);
+ // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
+ std::set<string> missing_feeds, missing_fetches;
+ for (const tf2xla::Feed& feed : config.feed()) {
+ missing_feeds.insert(TensorIdToString(feed.id()));
+ }
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
+ missing_fetches.insert(TensorIdToString(fetch.id()));
+ }
+ for (const Node* n : graph->op_nodes()) {
+ if (n->type_string() == kArgOp) {
+ string feed_id;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
+ if (missing_feeds.erase(feed_id) == 0) {
+ return errors::Aborted(kArgOp,
+ " node found with unknown feed id: ", feed_id);
+ }
+ } else if (n->type_string() == kRetvalOp) {
+ string fetch_id;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
+ if (missing_fetches.erase(fetch_id) == 0) {
+ return errors::Aborted(kRetvalOp,
+ " node found with unknown fetch id: ", fetch_id);
+ }
+ }
+ }
+ if (!missing_feeds.empty() || !missing_fetches.empty()) {
+ return errors::Aborted(
+ "Post graph-pruning",
+ ", missing feeds: ", str_util::Join(missing_feeds, ", "),
+ ", missing fetches: ", str_util::Join(missing_fetches, ", "));
+ }
+ return Status::OK();
+}
+
+// CollectArgNodes collects _Arg nodes from the graph, and performs basic
+// sanity-checking to ensure the index and type attributes of each node are
+// initialized correctly.
+Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
+ std::map<int, Node*> indexed_arg_nodes;
+ for (Node* n : graph.nodes()) {
+ if (n->type_string() == kArgOp) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ auto insert_result = indexed_arg_nodes.insert({index, n});
+ if (!insert_result.second) {
+ const Node* dup = insert_result.first->second;
+ return errors::InvalidArgument(
+ "Multiple ", kArgOp, " nodes with index ", index, ", ",
+ n->DebugString(), " and ", dup->DebugString());
+ }
+ }
+ }
+ arg_nodes->clear();
+ for (const auto& index_node : indexed_arg_nodes) {
+ if (index_node.first != arg_nodes->size()) {
+ return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
+ arg_nodes->size(), ", but got index ",
+ index_node.first);
+ }
+ arg_nodes->push_back(index_node.second);
+ }
+ return Status::OK();
+}
+
+// Fills in xla_args from the corresponding _Arg nodes in the graph.
+Status CreateXlaArgs(const Graph& graph,
+ std::vector<XlaCompiler::Argument>* xla_args) {
+ std::vector<Node*> arg_nodes;
+ TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
+ for (const Node* node : arg_nodes) {
+ XlaCompiler::Argument arg;
+ arg.kind = XlaCompiler::Argument::kParameter;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
+ xla_args->push_back(arg);
+ }
+ return Status::OK();
+}
+
+// Converts the TensorFlow graph into an XLA computation, by executing the
+// graph symbolically, with each op building up the XLA HLO.
+Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
+ xla::Computation* computation,
+ bool* requires_runtime_context) {
+ // Create a device and context to convert the graph into an XLA computation.
+ XlaOpRegistry::RegisterCompilationKernels();
+ // Populate the context with args from the graph.
+ for (Node* node : graph->nodes()) {
+ node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
+ }
+ std::vector<XlaCompiler::Argument> xla_args;
+ TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
+
+ // Compile the graph into an XLA computation.
+ XlaCompiler::Options compiler_options;
+ compiler_options.client = client;
+ DeviceType device_type(DEVICE_CPU_XLA_JIT);
+ compiler_options.device_type = &device_type;
+ compiler_options.flib_def = &graph->flib_def();
+ compiler_options.graph_def_version = graph->versions().producer();
+ compiler_options.allow_cpu_custom_calls = true;
+ XlaCompiler compiler(compiler_options);
+
+ XlaCompiler::CompilationResult result;
+ TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
+ "tfcompile", std::move(graph),
+ xla_args, &result));
+ *requires_runtime_context = result.requires_runtime_context;
+ *computation = std::move(*result.computation);
+
+ int num_const_results = 0;
+ for (int i = 0; i < result.outputs.size(); ++i) {
+ // Ending up with const results (i.e. output args) is an error, since it
+ // means that one or more fetches that the user specified will be dropped
+ // from the generated function. It's most likely a configuration error,
+ // since the user shouldn't be asking for output args that end up as consts.
+ //
+ // TODO(toddw): Provide a way for the user to access const output args,
+ // e.g. perhaps hard-coded into the header, or somehow copied into the
+ // output buffers.
+ if (result.outputs[i].is_constant) {
+ ++num_const_results;
+ LOG(ERROR) << "ConstRetVal index:" << i
+ << " value:" << result.outputs[i].constant_value.DebugString();
+ }
+ }
+ if (num_const_results > 0) {
+ return errors::Unimplemented(
+ "Conversion from TensorFlow graph to XLA resulted in ",
+ num_const_results,
+ " constant results. The configuration of "
+ "the output args (i.e. fetch ids) is probably wrong.");
+ }
+ if (computation->IsNull()) {
+ return errors::Aborted(
+ "Conversion from TensorFlow graph to XLA resulted in an empty "
+ "computation.");
+ }
+ return Status::OK();
+}
+
+// InitGraph creates a graph based on the graph_def, that may then be converted
+// to an xla::Computation via ConvertGraphToXla.
+//
+// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
+// and outputs of the function that will be compiled. Each feed id causes a new
+// _Arg node to be created, where we first collect all existing edges pointing
+// from the named node's output index, and then rewrite them to point from that
+// _Arg node instead. Each fetch id causes a new _Retval node to be created,
+// with a new edge pointing from the named node's output index to that _Retval
+// node.
+Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+ std::unique_ptr<Graph>* graph) {
+ TF_RETURN_IF_ERROR(ValidateConfig(config));
+
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
+ std::unique_ptr<Graph> g(new Graph(flib_def));
+
+ // Replace references to fed tensors with references to newly added
+ // placeholders.
+ GraphDef first_copy_def = graph_def;
+
+ // Maps from name:port of a feed to the name:port of the placeholder to use.
+ std::unordered_map<string, string> feed_remapping;
+ TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
+ &feed_remapping, &first_copy_def));
+
+ // Prune the GraphDef first so that unknown ops that we aren't compiling get
+ // filtered out.
+ GraphDef second_copy_def;
+ TF_RETURN_IF_ERROR(
+ PruneGraphDefInto(config, first_copy_def, &second_copy_def));
+
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
+ &second_copy_def, *g->op_registry(), /*node_offset=*/0));
+
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
+ second_copy_def, g.get()));
+ TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
+ *graph = std::move(g);
+ return Status::OK();
+}
+
+} // namespace
+
+Status ConvertGraphDefToXla(const GraphDef& graph_def,
+ const tf2xla::Config& config, xla::Client* client,
+ xla::Computation* computation,
+ bool* requires_runtime_context) {
+ std::unique_ptr<Graph> graph;
+ TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
+ TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation,
+ requires_runtime_context));
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h
new file mode 100644
index 0000000000..ab99beebf7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/tf2xla.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
+
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace tensorflow {
+
+// Converts a tensorflow::GraphDef into an xla::Computation. The given `config`
+// specifies the portion of the graph to convert, via feeds and fetches. Each
+// feed is a positional input argument for the generated computation, while each
+// fetch is a positional output argument.
+//
+// The computation is built in the context of the given `client`, which may
+// subsequently be used to compile or execute the computation.
+//
+// If `requires_runtime_context` is filled with true, this indicates the last
+// argument of the computation is XlaLocalRuntimeContext*.
+Status ConvertGraphDefToXla(const GraphDef& graph_def,
+ const tf2xla::Config& config, xla::Client* client,
+ xla::Computation* computation,
+ bool* requires_runtime_context);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/tf2xla/tf2xla.proto
index cd83840d89..18c9089f5f 100644
--- a/tensorflow/compiler/aot/tfcompile.proto
+++ b/tensorflow/compiler/tf2xla/tf2xla.proto
@@ -1,10 +1,10 @@
syntax = "proto3";
-package tensorflow.tfcompile;
+package tensorflow.tf2xla;
option cc_enable_arenas = true;
-option java_outer_classname = "CompileProtos";
+option java_outer_classname = "Tf2XlaProtos";
option java_multiple_files = true;
-option java_package = "org.tensorflow.tfcompile";
+option java_package = "org.tensorflow.tf2xla";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
@@ -19,32 +19,32 @@ message TensorId {
};
// Feed represents a single feed tensor in the graph, which corresponds to an
-// input argument for the generated function.
+// input argument for the generated computation.
message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
// Optional data type. This is not normally required, as the graph itself
- // contains this information. However, if the node being fed is an op that
- // is not linked into the tfcompile binary, then the type cannot be inferred
- // from the node; in this case, the type should be set here.
+ // contains this information. However, if the node being fed is an op that is
+ // not linked into the binary, then the type cannot be inferred from the node;
+ // in this case, the type should be set here.
DataType type = 4;
};
// Fetch represents a single fetch tensor in the graph, which corresponds to an
-// output argument for the generated function.
+// output argument for the generated computation.
message Fetch {
TensorId id = 1;
string name = 2; // Optional name for generated code.
};
-// Config represents configuration information for tfcompile.
+// Config represents configuration information for tf2xla conversion.
message Config {
- // Each feed is a positional input argument for the generated function. The
- // order of each entry matches the order of each input argument.
+ // Each feed is a positional input argument for the generated computation.
+ // The order of each entry matches the order of each input argument.
repeated Feed feed = 1;
- // Each fetch is a positional output argument for the generated function. The
- // order of each entry matches the order of each output argument.
+ // Each fetch is a positional output argument for the generated computation.
+ // The order of each entry matches the order of each output argument.
repeated Fetch fetch = 2;
};
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
new file mode 100644
index 0000000000..57b53cc660
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+AttrValue TypeAttrValue(DataType type) {
+ AttrValue attr_value;
+ SetAttrValue(type, &attr_value);
+ return attr_value;
+}
+
+GraphDef SumGraph() {
+ GraphDef graph_def;
+ NodeDef* x = graph_def.add_node();
+ x->set_name("x");
+ x->set_op("Placeholder");
+ (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
+ NodeDef* y = graph_def.add_node();
+ y->set_name("y");
+ y->set_op("Placeholder");
+ (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
+ NodeDef* sum = graph_def.add_node();
+ sum->set_name("sum");
+ sum->set_op("Add");
+ sum->add_input("x");
+ sum->add_input("y");
+ (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32);
+ return graph_def;
+}
+
+tf2xla::Config SumConfig() {
+ tf2xla::Config config;
+ config.add_feed()->mutable_id()->set_node_name("x");
+ config.add_feed()->mutable_id()->set_node_name("y");
+ config.add_fetch()->mutable_id()->set_node_name("sum");
+ return config;
+}
+
+TEST(ConvertGraphDefToXla, Sum) {
+ GraphDef graph_def = SumGraph();
+ tf2xla::Config config = SumConfig();
+
+ xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
+ xla::Computation computation;
+ bool requires_runtime_context;
+ TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation,
+ &requires_runtime_context));
+ ASSERT_FALSE(requires_runtime_context);
+
+ // Set up arguments.
+ auto x_literal = xla::Literal::CreateR0<int32>(10);
+ auto y_literal = xla::Literal::CreateR0<int32>(32);
+ auto x_global_or = client->TransferToServer(*x_literal);
+ auto y_global_or = client->TransferToServer(*y_literal);
+ TF_EXPECT_OK(x_global_or.status());
+ TF_EXPECT_OK(y_global_or.status());
+ std::unique_ptr<xla::GlobalData> x_global =
+ std::move(x_global_or.ValueOrDie());
+ std::unique_ptr<xla::GlobalData> y_global =
+ std::move(y_global_or.ValueOrDie());
+
+ // Execute and check result.
+ auto result_or =
+ client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
+ TF_EXPECT_OK(result_or.status());
+ std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("42", result->ToString());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 629187d621..14e0910cab 100644
--- a/tensorflow/compiler/aot/tfcompile_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <queue>
#include <set>
#include <unordered_map>
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -29,21 +29,13 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
-namespace tfcompile {
namespace {
-bool IsAlpha(char c) {
- return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
-}
-
-bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
-
-Status ValidateTensorId(const TensorId& id) {
+Status ValidateTensorId(const tf2xla::TensorId& id) {
if (id.node_name().empty()) {
return errors::InvalidArgument("TensorId node_name must be non-empty");
}
@@ -53,10 +45,9 @@ Status ValidateTensorId(const TensorId& id) {
return Status::OK();
}
-Status ValidateFeedFetchName(const string& kind, const string& name,
- std::set<string>* names) {
+Status CheckNameDuplicates(const string& kind, const string& name,
+ std::set<string>* names) {
if (!name.empty()) {
- TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
if (!names->insert(name).second) {
return errors::InvalidArgument("duplicate ", kind, " name: ", name);
}
@@ -80,42 +71,18 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
-Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
- if (ident.empty()) {
- return errors::InvalidArgument("empty identifier: ", msg);
- }
- // Require that the identifier starts with a nondigit, and is composed of
- // nondigits and digits, as specified in section [2.11 Identifiers] of the
- // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
- // defined as [0-9].
- //
- // Technically the standard also allows for `universal-character-name`, with a
- // table of allowed unicode ranges, as well as `other implementation-defined
- // characters`. We disallow those here to give better error messages, at the
- // expensive of being more restrictive than the standard.
- if (ident[0] != '_' && !IsAlpha(ident[0])) {
- return errors::InvalidArgument("illegal leading char: ", msg);
- }
- for (size_t pos = 1; pos < ident.size(); ++pos) {
- if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
- return errors::InvalidArgument("illegal char: ", msg);
- }
- }
- return Status::OK();
-}
-
-Status ValidateConfig(const Config& config) {
+Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
- for (const Feed& feed : config.feed()) {
+ for (const tf2xla::Feed& feed : config.feed()) {
TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
- TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
+ TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
names.clear();
- for (const Fetch& fetch : config.fetch()) {
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
- TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
+ TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
if (config.feed().empty() || config.fetch().empty()) {
@@ -125,10 +92,10 @@ Status ValidateConfig(const Config& config) {
}
Status AddPlaceholdersForFeeds(
- const Config& config, const OpRegistryInterface* op_registry,
+ const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
struct PlaceholderInfo {
- const Feed* feed = nullptr; // point to Feed in <config>.
+ const tf2xla::Feed* feed = nullptr; // point to Feed in <config>.
string placeholder_name;
DataType data_type = DT_INVALID;
};
@@ -137,9 +104,9 @@ Status AddPlaceholdersForFeeds(
// when creating placeholders (genrules want deterministic output).
std::map<string, PlaceholderInfo> placeholder_info;
for (int i = 0; i < config.feed_size(); ++i) {
- const Feed* feed = &config.feed(i);
+ const tf2xla::Feed* feed = &config.feed(i);
const string name_port = TensorIdToString(feed->id());
- auto& info = placeholder_info[name_port];
+ PlaceholderInfo& info = placeholder_info[name_port];
info.feed = feed;
info.placeholder_name = strings::StrCat(
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
@@ -153,7 +120,7 @@ Status AddPlaceholdersForFeeds(
}
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
PlaceholderInfo& info = it->second;
- const TensorId& feed_id = info.feed->id();
+ const tf2xla::TensorId& feed_id = info.feed->id();
// Find the existing node and determine data type.
auto node_it = name_to_node.find(feed_id.node_name());
@@ -214,16 +181,16 @@ Status AddPlaceholdersForFeeds(
return Status::OK();
}
-Status PruneGraphDefInto(const Config& config, const GraphDef& in,
+Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out) {
*out = in;
out->clear_node();
// Tensors needed for feeding.
std::set<std::pair<string, int>> feed_tensors;
- for (const auto& feed_config : config.feed()) {
- feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
- feed_config.id().output_index()));
+ for (const tf2xla::Feed& feed : config.feed()) {
+ feed_tensors.insert(
+ std::make_pair(feed.id().node_name(), feed.id().output_index()));
}
// Maps node name to reachability.
@@ -279,9 +246,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
return Status::OK();
}
-string TensorIdToString(const TensorId& id) {
+string TensorIdToString(const tf2xla::TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
}
-} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 365f7b0e7b..a29d0c16f9 100644
--- a/tensorflow/compiler/aot/tfcompile_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -13,26 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
-#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#include <unordered_map>
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
-namespace tfcompile {
-
-// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
-// appended to error messages.
-Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
-Status ValidateConfig(const Config& config);
+Status ValidateConfig(const tf2xla::Config& config);
// Modifies <graph_def> to include placeholders for each fed tensor, and
// update references to the fed tensors to refer to the placeholders.
@@ -40,18 +34,17 @@ Status ValidateConfig(const Config& config);
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
- const Config& config, const OpRegistryInterface* op_registry,
+ const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
-Status PruneGraphDefInto(const Config& config, const GraphDef& in,
+Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out);
// Returns node:port for the given <id>.
-string TensorIdToString(const TensorId& id);
+string TensorIdToString(const tf2xla::TensorId& id);
-} // namespace tfcompile
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index 5a92851ceb..b98c89f284 100644
--- a/tensorflow/compiler/aot/tfcompile_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
-namespace tfcompile {
namespace {
void ExpectErrorContains(const Status& status, StringPiece str) {
@@ -32,45 +31,16 @@ void ExpectErrorContains(const Status& status, StringPiece str) {
<< "expected error: " << status.error_message() << " to contain: " << str;
}
-TEST(ValidateCppIdent, Simple) {
- TF_EXPECT_OK(ValidateCppIdent("a", ""));
- TF_EXPECT_OK(ValidateCppIdent("abc", ""));
- TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
- TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
- // Make sure we didn't skip a valid letter or digit
- string ident;
- for (char c = 'a'; c <= 'z'; c++) {
- ident.append(1, c);
- }
- for (char c = 'A'; c <= 'Z'; c++) {
- ident.append(1, c);
- }
- for (char c = '0'; c <= '9'; c++) {
- ident.append(1, c);
- }
- ident += "_";
- TF_EXPECT_OK(ValidateCppIdent(ident, ""));
-
- ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
- ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
- ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
- ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
-}
-
TEST(ValidateConfig, Good) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(123);
feed->set_name("foo_debug");
feed = config.add_feed();
feed->mutable_id()->set_node_name("bar");
feed->mutable_id()->set_output_index(0);
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("baz");
fetch->mutable_id()->set_output_index(456);
fetch->set_name("baz_debug");
@@ -81,62 +51,62 @@ TEST(ValidateConfig, Good) {
}
TEST(ValidateConfig, BadEmpty) {
- Config config;
+ tf2xla::Config config;
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFeed) {
- Config config;
- Fetch* fetch = config.add_fetch();
+ tf2xla::Config config;
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFetch) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadFeedNodeName) {
- Config config;
+ tf2xla::Config config;
config.add_feed();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFeedOutputIndex) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(-1);
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
TEST(ValidateConfig, BadFetchNodeName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
config.add_fetch();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFetchOutputIndex) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->mutable_id()->set_output_index(-1);
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
TEST(ValidateConfig, DuplicateFeedName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("dup");
feed = config.add_feed();
@@ -146,10 +116,10 @@ TEST(ValidateConfig, DuplicateFeedName) {
}
TEST(ValidateConfig, DuplicateFetchName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("dup");
fetch = config.add_fetch();
@@ -159,8 +129,8 @@ TEST(ValidateConfig, DuplicateFetchName) {
}
TEST(ValidateConfig, ConflictingFeedName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("conflict");
feed = config.add_feed();
@@ -170,10 +140,10 @@ TEST(ValidateConfig, ConflictingFeedName) {
}
TEST(ValidateConfig, ConflictingFetchName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("conflict");
fetch = config.add_fetch();
@@ -182,8 +152,8 @@ TEST(ValidateConfig, ConflictingFetchName) {
ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}
-static Config FetchesConfig(std::vector<string> fetches) {
- Config config;
+static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
+ tf2xla::Config config;
for (const auto& fetch_node_name : fetches) {
auto* fetch = config.add_fetch();
fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
@@ -242,5 +212,4 @@ TEST(PruneGraphDefInto, Basic) {
}
} // namespace
-} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 30afaed732..e41a391ac5 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1703,7 +1703,6 @@ StatusOr<Computation> ComputationBuilder::Build() {
}
void ComputationBuilder::AddOpMetadata(OpRequest* request) const {
- tensorflow::mutex_lock lock(mutex_);
*request->mutable_metadata() = metadata_;
}
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index cf1f3b074e..96db56bc53 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
@@ -57,10 +56,10 @@ class ComputationBuilder {
~ComputationBuilder();
// Returns the client the builder was initialized with.
- Client* client() { return client_; }
+ Client* client() const { return client_; }
// Returns the computation name.
- const string& name() { return name_; }
+ const string& name() const { return name_; }
// Sets OpMetadata that will be added to all instructions until cleared.
//
@@ -69,13 +68,11 @@ class ComputationBuilder {
// instructions generated via this Computation Builder will have the same
// OpMetadata attached until a call to ClearOpMetdata.
void SetOpMetadata(const OpMetadata& metadata) {
- tensorflow::mutex_lock lock(mutex_);
metadata_ = metadata;
}
// Clears the HloMetdata state.
void ClearOpMetadata() {
- tensorflow::mutex_lock lock(mutex_);
metadata_.Clear();
}
@@ -826,15 +823,12 @@ class ComputationBuilder {
Client* client_;
// Mode bit that indicates whether to die when a first error is encountered.
- bool die_immediately_on_error_{false};
-
- // Mutex to guard against concurrent access to metadata_.
- mutable tensorflow::mutex mutex_;
+ bool die_immediately_on_error_ = false;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
// throughout the TensorFlow op kernel implementations).
- OpMetadata metadata_ GUARDED_BY(mutex_);
+ OpMetadata metadata_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 95f8165795..1a18b28cbb 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -180,15 +180,18 @@ cc_library(
cc_library(
name = "ir_emitter",
- srcs = ["ir_emitter.cc"],
+ srcs = [
+ "elemental_ir_emitter.cc",
+ "ir_emitter.cc",
+ ],
hdrs = [
+ "elemental_ir_emitter.h",
"ir_emitter.h",
],
deps = [
":cpu_options",
":cpu_runtime",
":dot_op_emitter",
- ":elemental_ir_emitter",
":ir_emission_utils",
":simple_orc_jit",
"//tensorflow/compiler/xla:shape_util",
@@ -526,22 +529,6 @@ cc_library(
)
cc_library(
- name = "elemental_ir_emitter",
- srcs = ["elemental_ir_emitter.cc"],
- hdrs = ["elemental_ir_emitter.h"],
- deps = [
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:elemental_ir_emitter",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
- "@llvm//:core",
- ],
-)
-
-cc_library(
name = "ir_emission_utils",
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 511f89144a..902309b338 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
- // Producer or consumer cannot be Map. Maps are technically elementwise but
- // of a slightly different form (call instead of a computation). These are not
- // yet supported in the CPU backend.
- if (producer->opcode() == HloOpcode::kMap ||
- consumer->opcode() == HloOpcode::kMap) {
- return false;
- }
-
// Cost condition: not fuse (simple, expensive producers) and (consumers who
// reuse operand elements).
if (producer->opcode() != HloOpcode::kFusion &&
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fc62281a0..b56466d5e4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest {
std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
expected_opcodes);
}
+
+ HloComputation* CreateAdderToOne(HloModule* module) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* arg0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "arg0"));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
+ return module->AddEmbeddedComputation(builder.Build());
+ }
+
+ HloComputation* CreateMax(HloModule* module) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* arg0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "arg0"));
+ HloInstruction* arg1 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {}), "arg1"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
+ return module->AddEmbeddedComputation(builder.Build());
+ }
};
TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
@@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
HloOpcode::kParameter});
}
+TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
+ auto module = CreateNewModule();
+
+ HloComputation::Builder builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+
+ HloInstruction* exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
+ builder.AddInstruction(HloInstruction::CreateMap(
+ shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));
+
+ module->AddEntryComputation(builder.Build());
+
+ RunFusionAndCheckOpcodesWereFused(
+ module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
+}
+
+TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
+ auto module = CreateNewModule();
+
+ HloComputation::Builder builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param"));
+
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
+ HloInstruction* exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
+
+ builder.AddInstruction(HloInstruction::CreateMap(
+ shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));
+
+ module->AddEntryComputation(builder.Build());
+
+ RunFusionAndCheckOpcodesWereFused(
+ module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
+ HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
+}
} // namespace
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index fe447adf89..73e039250b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
+llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const {
+ if (hlo->opcode() == HloOpcode::kMap) {
+ return [this, hlo, &operand_to_generator](
+ const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ std::vector<llvm::Value*> operands;
+ for (int i = 0; i < hlo->operand_count(); i++) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(i))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ operands.push_back(operand_value);
+ }
+ return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
+ hlo->to_apply(), operands,
+ llvm_ir::IrName(hlo));
+ };
+ }
+ return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
+}
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 6f9d6a24b4..7e9f27befb 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -29,12 +30,19 @@ namespace cpu {
class CpuElementalIrEmitter : public ElementalIrEmitter {
public:
CpuElementalIrEmitter(const HloModuleConfig& module_config,
- llvm::IRBuilder<>* ir_builder, llvm::Module* module)
- : ElementalIrEmitter(module_config, module, ir_builder) {}
+ IrEmitter* ir_emitter, llvm::Module* module)
+ : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()),
+ ir_emitter_(ir_emitter) {}
+
+ llvm_ir::ElementGenerator MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const override;
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
+
+ IrEmitter* ir_emitter_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index 94d4ce4a94..91b09f2472 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -136,6 +136,10 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
const int64 kReductionDimensionThresholdBytes = 8 * 1024;
const bool single_threaded_eigen =
!dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();
+
+ // This is the point at which it is better to call into Eigen and shard the
+ // dot across multiple worker threads. This is a rough estimate by running
+ // a matmult benchmark on my local machine, and it can be tuned further.
const int64 kMaxSingleThreadedFlops = 16 * 1024;
const int64 M = result_shape.dimensions(0);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index bc51ad2b36..8cd8740ee8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArrayForOp(operand));
}
- CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
- module_);
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
@@ -2737,14 +2736,10 @@ llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) {
}
prof_counter_idx = it->second;
- uintptr_t hlo_address = reinterpret_cast<uintptr_t>(hlo);
- counter_name = tensorflow::strings::StrCat(
- "prof_counter_0x",
- tensorflow::strings::Hex(
- hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address))));
+ counter_name = IrName("prof_counter", hlo->name());
} else {
prof_counter_idx = hlo_to_profile_idx_->size();
- counter_name = "prof_counter_computation";
+ counter_name = "prof_counter.computation";
}
return ir_builder_.CreateGEP(GetProfileCountersArgument(),
ir_builder_.getInt64(prof_counter_idx),
@@ -3180,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_);
};
}
- CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
- module_);
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
return EmitTargetElementLoop(
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
+StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
+ llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
+ std::vector<llvm::Value*> argument_addrs;
+ for (auto argument : arguments) {
+ llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ argument->getType(), "arg_addr", &ir_builder_);
+ ir_builder_.CreateStore(argument, argument_addr);
+ argument_addrs.push_back(argument_addr);
+ }
+ return EmitElementFunctionCall(llvm_function,
+ ShapeUtil::MakeShape(return_type, {}),
+ argument_addrs, name);
+}
+
unsigned TargetMachineFeatures::largest_register_size_in_bytes(
llvm::Function* function) {
auto itr = largest_register_size_in_bytes_.find(function);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index bcd33c3810..fa33a1eb7b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool is_top_level_computation,
std::vector<const HloInstruction*>* instruction_order);
+ llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
+
+ // Emits a call to `computation` with scalar arguments `arguments`.
+ StatusOr<llvm::Value*> EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
+
protected:
//
// The following methods implement the DfsHloVisitor interface.
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 5d650b872f..b24fe417ff 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -76,10 +76,11 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
// Since CUDA 9.0, all GPU versions are included in a single file
const char* unified_libdevice_filename = "libdevice.10.bc";
std::vector<string> unified_libdevice_files;
- tensorflow::Env::Default()->GetMatchingPaths(
+ const tensorflow::Status status =
+ tensorflow::Env::Default()->GetMatchingPaths(
tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
&unified_libdevice_files);
- if( unified_libdevice_files.size() == 1 ) {
+ if (status.ok() && unified_libdevice_files.size() == 1) {
return unified_libdevice_filename;
}
// There are only four libdevice files: compute_{20,30,35,50}. Each GPU
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index efb5fca188..e0c23a3a08 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -77,7 +77,7 @@ class HloOrdering {
// Precondition: 'a' and 'b' are in the same computation.
//
// Derived classes should implement this method for determining order of
- // instructions in the same comptuation. ExecutesBefore() analyzes the
+ // instructions in the same computation. ExecutesBefore() analyzes the
// callgraph and uses this method to determine ordering of instructions in
// different computations.
virtual bool ExecutesBeforeInSameComputation(
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1865004911..a0f9be3dd8 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -50,7 +50,7 @@ class WhileTest : public ClientLibraryTestBase {};
// while (result < 5) {
// result = result + 1;
// }
-TEST_F(WhileTest, WhileWithScalarResult) {
+TEST_F(WhileTest, WhileWithScalarS32Result) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
@@ -81,6 +81,43 @@ TEST_F(WhileTest, WhileWithScalarResult) {
ComputeAndCompareR0<int32>(&builder, 5, {});
}
+// Tests a while node when the result type T is S64.
+//
+// int32 result = 0;
+// while (result < 5) {
+// result = result + 1;
+// }
+TEST_F(WhileTest, WhileWithScalarS64Result) {
+ auto result_shape = ShapeUtil::MakeShape(S64, {});
+
+ // Create a computation for the condition: repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ builder.Gt(builder.ConstantR0<int64>(5), prev);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body: add 1 to the result variable.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR0<int64>(1);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR0<int64>(0);
+ auto result = builder.While(condition, body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<int64>(&builder, 5, {});
+}
+
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 11e4ea888c..5e7d513d17 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -33,6 +33,7 @@ py_library(
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
+ "//tensorflow/contrib/image:distort_image_py",
"//tensorflow/contrib/image:image_py",
"//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
"//tensorflow/contrib/imperative",
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 090621f29e..f60bd8282c 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -20,10 +20,10 @@ import android.os.Build.VERSION;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
+import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
-import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
@@ -79,24 +79,32 @@ public class TensorFlowInferenceInterface {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
-
+
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
-
+
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
byte[] graphDef = new byte[is.available()];
+ final int numBytesRead = is.read(graphDef);
+ if (numBytesRead != graphDef.length) {
+ throw new IOException(
+ "read error: read only "
+ + numBytesRead
+ + " of the graph, expected to read "
+ + graphDef.length);
+ }
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
-
+
loadGraph(graphDef, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
-
+
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
@@ -121,13 +129,13 @@ public class TensorFlowInferenceInterface {
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
-
+
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
-
+
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
int numBytesRead;
@@ -143,7 +151,7 @@ public class TensorFlowInferenceInterface {
loadGraph(graphDef, g);
Log.i(TAG, "Successfully loaded model from the input stream");
-
+
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
@@ -309,8 +317,8 @@ public class TensorFlowInferenceInterface {
/**
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
- * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of
- * bytes, not a Java {@code String} (which is a sequence of characters).
+ * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
+ * a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
addFeed(inputName, Tensor.create(src));
@@ -318,9 +326,8 @@ public class TensorFlowInferenceInterface {
/**
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
- * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string"
- * is an arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of
- * characters).
+ * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
+ * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
addFeed(inputName, Tensor.create(src));
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index a8b60460c8..7773125c16 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -151,7 +151,7 @@ def convert_to_universal_format(dtec, sorted_feature_names,
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
elif node_type == "sparse_float_binary_split_default_right":
- split = gtflow_node.sparse_float_binary_split_default_right
+ split = gtflow_node.sparse_float_binary_split_default_right.split
node.default_direction = (
generic_tree_model_pb2.BinaryNode.RIGHT)
feature_id = split.feature_column + num_dense
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 7d63f5c39e..e2031a6483 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -329,7 +329,6 @@ add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests")
add_python_module("tensorflow/contrib/cudnn_rnn/python/ops")
add_python_module("tensorflow/contrib/data")
add_python_module("tensorflow/contrib/data/python")
-add_python_module("tensorflow/contrib/data/python/framework")
add_python_module("tensorflow/contrib/data/python/kernel_tests")
add_python_module("tensorflow/contrib/data/python/ops")
add_python_module("tensorflow/contrib/data/python/util")
@@ -362,6 +361,8 @@ add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/gan/python")
+add_python_module("tensorflow/contrib/gan/python/eval")
+add_python_module("tensorflow/contrib/gan/python/eval/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/features/python")
add_python_module("tensorflow/contrib/gan/python/losses")
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index b9cf111224..be7664d087 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -20,6 +20,7 @@
@@FixedLengthRecordDataset
@@TextLineDataset
+@@batch_and_drop_remainder
@@read_batch_features
@@rejection_resample
@@group_by_window
@@ -32,6 +33,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops.dataset_ops import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window
diff --git a/tensorflow/contrib/data/python/framework/BUILD b/tensorflow/contrib/data/python/framework/BUILD
deleted file mode 100644
index c3c4911af4..0000000000
--- a/tensorflow/contrib/data/python/framework/BUILD
+++ /dev/null
@@ -1,48 +0,0 @@
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_library(
- name = "function",
- srcs = ["function.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:graph_to_function_def",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- ],
-)
-
-py_test(
- name = "function_test",
- size = "medium",
- srcs = ["function_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":function",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- ],
-)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/data/python/framework/function.py b/tensorflow/contrib/data/python/framework/function.py
deleted file mode 100644
index 171ed7c496..0000000000
--- a/tensorflow/contrib/data/python/framework/function.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""An experimental fork of the Python TensorFlow-function library.
-
-NOTE: functions are currently experimental and subject to change!
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import context
-from tensorflow.python.framework import function
-from tensorflow.python.framework import graph_to_function_def
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.util import tf_inspect
-
-# NOTE(mrry): This is an experimental extension of a core class that wasn't
-# designed to be extended, so we disable protected access checks for the
-# whole file.
-# pylint: disable=protected-access
-
-
-class _ExperimentalFuncGraph(function._FuncGraph):
- """A helper for construction a function (supporting capture-by-value).
-
- _ExperimentalFuncGraph overrides ops.Graph's create_op() so that we can keep
- track of every inputs into every op created inside the function. If
- any input is from other graphs, we keep track of it in self.capture
- and substitute the input with a place holder.
-
- Each captured input's corresponding place holder is converted into a
- function argument and the caller passes in the captured tensor.
- """
-
- def __init__(self, capture_by_value, *args, **kwargs):
- super(_ExperimentalFuncGraph, self).__init__(*args, **kwargs)
- self._capture_by_value = capture_by_value
- self._building_function = True
- self._outer_graph = ops.get_default_graph()
- self._vscope = vs.get_variable_scope()
- self._old_custom_getter = self._vscope.custom_getter
- self._captured = {}
- self.extra_inputs = []
- self.extra_args = []
- self.extra_vars = []
-
- def create_op(self, op_type, inputs, data_types, **kwargs):
- for i, x in enumerate(inputs):
- if x.graph is not self:
- # Referring to a tensor from other graph.
- if x in self._captured:
- # Captured already.
- inputs[i] = self._captured[x]
- elif self._capture_by_value:
- inputs[i] = self._add_tensor_and_parents(x)
- else:
- # Substitute with a placeholder.
- self.extra_inputs.append(x)
- ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
- # pylint: disable=protected-access
- ph._handle_data = x._handle_data
- # pylint: enable=protected-access
- inputs[i] = ph
- self._captured[x] = ph
- self.extra_args.append(ph)
- return super(_ExperimentalFuncGraph, self).create_op(op_type, inputs,
- data_types, **kwargs)
-
- def _add_tensor_and_parents(self, tensor):
- op = self._add_op_and_parents(tensor.op)
- return op.outputs[tensor.value_index]
-
- def _add_op_and_parents(self, op):
- op_def = graph_to_function_def._get_op_def(op)
- if op_def.is_stateful:
- raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
- "by value." % (op.name, op.type))
- elif op.type in ("Placeholder", "PlaceholderV2"):
- raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
- "by value." % (op.name, op.type))
-
- captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
-
- captured_op = self.create_op(op.type, captured_inputs,
- [o.dtype for o in op.outputs],
- name=op.name, attrs=op.node_def.attr,
- op_def=op_def)
-
- for t, captured_t in zip(op.outputs, captured_op.outputs):
- self._captured[t] = captured_t
-
- return captured_op
-
-
-class _ExperimentalDefinedFunction(function._DefinedFunction):
- """Overrides _DefinedFunction with support for capture-by-value."""
-
- def __init__(self,
- func,
- argnames,
- input_types,
- func_name=None,
- grad_func=None,
- python_grad_func=None,
- out_names=None,
- shape_func=None,
- capture_by_value=False,
- **kwargs):
- """Creates an _ExperimentalDefinedFunction.
-
- Args:
- func: A python callable which constructs a tf function body.
- argnames: A list of strings for function argument names.
- input_types: The function's argument types. Can be a tuple, list of
- tf data types.
- func_name: The function name. Defaults to None, in which derives from
- 'func'.
- grad_func: This function's gradient function, if not None. Defaults
- to None.
- python_grad_func: A python callable implementing the gradient of
- the function python-side.
- out_names: An optional list of strings for the function return value
- names.
- shape_func: An optional function mapping an op to a list of static
- output shapes.
- capture_by_value: Boolean (defaults to False). If True, captured values
- will be copied into the function body.
- **kwargs: The keyword arguments. **kwargs is passed to every call
- site of this function.
-
- Raises:
- ValueError: The function definition is invalid.
- """
- super(_ExperimentalDefinedFunction, self).__init__(
- func, argnames, input_types, func_name, grad_func, python_grad_func,
- out_names, shape_func, **kwargs)
- self._capture_by_value = capture_by_value
-
- def _create_definition_if_needed(self):
- """Creates the function definition if it's not created yet."""
- with context.graph_mode():
- self._create_definition_if_needed_impl()
-
- def _create_definition_if_needed_impl(self):
- """You're looking for _create_definition_if_needed(), not this."""
-
- if self._definition is not None:
- return
-
- # Create the func_def object.
- temp_graph = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
- with temp_graph.as_default():
- # List of placeholders for the function_def.
- inputs = []
- for (argname, argtype) in self._args:
- argholder = array_ops.placeholder(argtype, name=argname)
- inputs.append(argholder)
- # Call func and gather the output tensors.
- with vs.variable_scope("", custom_getter=temp_graph.getvar):
- outputs = self._func(*inputs)
- # If func only returned one value, make it a tuple.
- if not isinstance(outputs, (list, tuple)):
- outputs = (outputs,)
- if any([_ is None for _ in outputs]):
- raise ValueError("Function can not return None.")
- # Ensures each output is a Tensor.
- outputs = [ops.convert_to_tensor(_) for _ in outputs]
- self._extra_inputs = temp_graph.extra_inputs
- inputs.extend(temp_graph.extra_args)
- self._sub_functions = temp_graph._functions
-
- # Build the FunctionDef
- self._definition = graph_to_function_def.graph_to_function_def(
- temp_graph, temp_graph.get_operations(), inputs, outputs,
- out_names=self._out_names)
-
- # Extra kwargs are treated as attrs on the function def.
- sig_pre_func_name = self._func_name or function._get_func_name(self._func)
- kwargs_attr = function._parse_kwargs_as_attrs(
- sig_pre_func_name, **self._extra_kwargs)
- for k in kwargs_attr:
- self._definition.attr[k].CopyFrom(kwargs_attr[k])
-
- # Hash the definition and its dependencies.
- self._hash_str = self._create_hash_str(
- self._definition.signature.input_arg,
- self._definition.signature.output_arg,
- self._definition.node_def)
-
- # Finally, we decide the function name to use. If not specified,
- # make up something which is almost certainly unique (but deterministic).
- if not self._func_name:
- self._func_name = "_".join([function._get_func_name(self._func),
- self._hash_str])
- self._definition.signature.name = self._func_name
- if self._func.__doc__:
- self._definition.signature.description = self._func.__doc__
-
-
-class Defun(function.Defun):
- """Experimental version of Defun supporting capture-by-value."""
-
- def __init__(self, *input_types, **kwargs):
- """Create an experimental `Defun` decorator.
-
- Args:
- *input_types: A list of `tf.DType`
- **kwargs: Optional keyword arguments (see `function.Defun`) plus:
- capture_by_value - Boolean (defaults to False). If True, captured values
- will be copied into the function body.
- """
- super(Defun, self).__init__(*input_types, **kwargs)
-
- def __call__(self, func):
- # Various sanity checks on the callable func.
- if not callable(func):
- raise ValueError("func %s must be callable" % func)
-
- # Func should not use kwargs and defaults.
- argspec = tf_inspect.getargspec(func)
- if argspec.keywords or argspec.defaults:
- raise ValueError("Functions with argument defaults or keyword "
- "arguments are not supported.")
-
- # Computes how many arguments 'func' has.
- min_args = len(argspec.args)
- max_args = min_args
- if argspec.varargs:
- max_args = 1000000
- argnames = argspec.args
- if tf_inspect.ismethod(func):
- # 1st argument is the "class" type.
- min_args -= 1
- argnames = argnames[1:]
-
- if self._input_types:
- # If Defun is given a list of types for the inputs, the number
- # of input types should be compatible with 'func'.
- num = len(self._input_types)
- if num < min_args or num > max_args:
- raise ValueError(
- "The function has fewer arguments than the number of specified "
- "input types.")
- return _ExperimentalDefinedFunction(
- func, argnames, self._input_types, self._func_name, self._grad_func,
- self._python_grad_func, out_names=self._out_names,
- **self._extra_kwargs)
-
- # 'func' expects no arguments and input types is an empty list.
- if min_args == 0 and max_args == 0:
- return _ExperimentalDefinedFunction(
- func, [], [], self._func_name, self._grad_func,
- self._python_grad_func, out_names=self._out_names,
- **self._extra_kwargs)
-
- # Input types are unknown. It's an overloaded function and hence
- # its definition needs to be deferred until it's called.
- return function._OverloadedFunction(
- func, argnames, self._func_name, self._grad_func,
- self._python_grad_func, out_names=self._out_names, **self._extra_kwargs)
diff --git a/tensorflow/contrib/data/python/framework/function_test.py b/tensorflow/contrib/data/python/framework/function_test.py
deleted file mode 100644
index c493170d28..0000000000
--- a/tensorflow/contrib/data/python/framework/function_test.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental capture-by-value feature in TF functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.framework import function
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class FunctionTest(test.TestCase):
-
- def testCaptureByValue(self):
- g = ops.Graph()
- with g.as_default():
- w = constant_op.constant([[1.0]])
- b = constant_op.constant([2.0])
-
- # Foo() captures w and b.
- @function.Defun(dtypes.float32, capture_by_value=True)
- def Foo(x):
-
- # Plus() captures b.
- @function.Defun(dtypes.float32, capture_by_value=True)
- def Plus(y):
- return y + b
-
- self.assertEqual(0, len(Plus.captured_inputs))
-
- return Plus(math_ops.matmul(w, x))
-
- y = Foo(constant_op.constant([[10.]]))
-
- self.assertEqual(0, len(Foo.captured_inputs))
-
- with self.test_session(graph=g):
- self.assertAllEqual(y.eval(), [[12.0]])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 6891fd4231..4e5bb9086c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -264,6 +264,7 @@ py_test(
srcs = ["resample_test.py"],
shard_count = 2,
srcs_version = "PY2AND3",
+ tags = ["noasan"],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 496cdab4ba..c6afbd23ab 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -333,6 +333,55 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
+ def testBatchAndDropRemainder(self):
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ iterator = (dataset_ops.Dataset.from_tensor_slices(components)
+ .apply(dataset_ops.batch_and_drop_remainder(batch_size))
+ .make_initializable_iterator())
+
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for test_batch_size in [1, 3, 7, 10]:
+ sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
+ num_batches = 7 // test_batch_size
+ for i in range(num_batches):
+ result = sess.run(next_element)
+ for component, result_component in zip(components, result):
+ for j in range(test_batch_size):
+ self.assertAllEqual(component[(i * test_batch_size + j)],
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testBatchAndDropRemainderShapeInference(self):
+ components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(
+ dtypes.int32, shape=[None]), array_ops.placeholder(
+ dtypes.int32, shape=[20, 30])))
+
+ # Test with a statically known batch size.
+ dataset = (dataset_ops.Dataset.from_tensor_slices(components)
+ .apply(dataset_ops.batch_and_drop_remainder(128)))
+
+ self.assertIs(None, dataset.output_shapes[0].ndims)
+ self.assertEqual([128], dataset.output_shapes[1][0].as_list())
+ self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())
+
+ # Test with a dynamic batch size: the static shape will be unknown, because
+ # `batch_size` is a placeholder.
+ batch_size = array_ops.placeholder(dtypes.int64)
+ dataset = (dataset_ops.Dataset.from_tensor_slices(components)
+ .apply(dataset_ops.batch_and_drop_remainder(batch_size)))
+
+ self.assertIs(None, dataset.output_shapes[0].ndims)
+ self.assertEqual([None], dataset.output_shapes[1][0].as_list())
+ self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
index 1297a031d1..fb1305f735 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
@@ -22,13 +22,17 @@ import threading
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.data.python.util import nest
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.util import nest
class DatasetConstructorTest(test.TestCase):
@@ -475,6 +479,75 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testSplitPipelineFailsWithPlacementError(self):
+ with session.Session(
+ target="",
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
+
+ dataset = dataset_ops.Dataset.from_tensors(0)
+
+ # Define a pipeline that attempts to use variables on two
+ # different devices.
+ #
+ # Initialize the variables before creating to iterator, to avoid the
+ # placement algorithm overriding the DT_RESOURCE colocation constraints.
+ with ops.device("/cpu:0"):
+ var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
+ dataset = dataset.map(lambda x: x + var_0.read_value())
+ sess.run(var_0.initializer)
+
+ with ops.device("/cpu:1"):
+ var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
+ dataset = dataset.map(lambda x: x + var_1.read_value())
+ sess.run(var_1.initializer)
+
+ iterator = dataset.make_initializable_iterator()
+
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Trying to access resource located in device"):
+ sess.run(iterator.initializer)
+
+ def testRestructureDataset(self):
+ components = (array_ops.placeholder(dtypes.int32),
+ (array_ops.placeholder(dtypes.int32, shape=[None]),
+ array_ops.placeholder(dtypes.int32, shape=[20, 30])))
+ dataset = dataset_ops.Dataset.from_tensors(components)
+
+ i32 = dtypes.int32
+
+ test_cases = [((i32, i32, i32), None),
+ (((i32, i32), i32), None),
+ ((i32, i32, i32), (None, None, None)),
+ ((i32, i32, i32), ([17], [17], [20, 30]))]
+
+ for new_types, new_shape_lists in test_cases:
+ # pylint: disable=protected-access
+ new = dataset_ops._RestructuredDataset(
+ dataset, new_types, new_shape_lists)
+ # pylint: enable=protected-access
+ self.assertEqual(new_types, new.output_types)
+ if new_shape_lists is not None:
+ for expected_shape_list, shape in zip(
+ nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
+ if expected_shape_list is None:
+ self.assertIs(None, shape.ndims)
+ else:
+ self.assertEqual(expected_shape_list, shape.as_list())
+
+ fail_cases = [((i32, dtypes.int64, i32), None),
+ ((i32, i32, i32, i32), None),
+ ((i32, i32, i32), ((None, None), None)),
+ ((i32, i32, i32), (None, None, None, None)),
+ ((i32, i32, i32), (None, [None], [21, 30]))]
+
+ for new_types, new_shape_lists in fail_cases:
+ with self.assertRaises(ValueError):
+ # pylint: disable=protected-access
+ new = dataset_ops._RestructuredDataset(
+ dataset, new_types, new_shape_lists)
+ # pylint: enable=protected-access
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index fb66acdcac..8662c6520d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -21,9 +21,11 @@ import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
from tensorflow.python.util import compat
@@ -50,7 +52,7 @@ class ResampleTest(test.TestCase):
seed=27))
init_op = iterator.initializer
get_next = iterator.get_next()
- variable_init_op = variables.global_variables_initializer()
+ variable_init_op = variables.local_variables_initializer()
with self.test_session() as sess:
sess.run(variable_init_op)
@@ -74,6 +76,22 @@ class ResampleTest(test.TestCase):
returned_dist = class_counts / total_returned
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
+ def testVariableDevicePlacement(self):
+ classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
+ target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
+ with ops.device(
+ device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")):
+ dataset = (dataset_ops.Dataset.from_tensor_slices(classes)
+ .shuffle(200, seed=21)
+ .map(lambda c: (c, string_ops.as_string(c))))
+ dataset = dataset_ops.rejection_resample(
+ dataset, target_dist=target_dist, initial_dist=None,
+ class_func=lambda c, _: c, seed=27)
+
+ self.assertEqual(1, len(variables.local_variables()))
+ self.assertEqual(b"",
+ compat.as_bytes(variables.local_variables()[0].device))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 94969c1c70..1abbf38d1a 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -9,7 +9,6 @@ py_library(
srcs = ["dataset_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/framework:function",
"//tensorflow/contrib/data/python/util:nest",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
@@ -17,6 +16,7 @@ py_library(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
"//tensorflow/python:logging_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
@@ -38,11 +38,11 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
- "//tensorflow/contrib/data/python/framework:function",
"//tensorflow/contrib/data/python/util:nest",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
"//tensorflow/python:platform",
],
)
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index b67aa9012d..040987b071 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -23,10 +23,10 @@ import threading
import numpy as np
-from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -99,8 +99,9 @@ class Iterator(object):
shared_name=shared_name,
output_types=nest.flatten(dataset.output_types),
output_shapes=nest.flatten(dataset.output_shapes))
- initializer = gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
- iterator_resource)
+ with ops.colocate_with(iterator_resource):
+ initializer = gen_dataset_ops.make_iterator(
+ dataset.make_dataset_resource(), iterator_resource)
return Iterator(iterator_resource, initializer, dataset.output_types,
dataset.output_shapes)
@@ -291,6 +292,7 @@ class Iterator(object):
raise TypeError("Expected output shapes compatible with %r but got "
"dataset with output shapes %r." %
(self._output_shapes, dataset.output_shapes))
+ with ops.colocate_with(self._iterator_resource):
return gen_dataset_ops.make_iterator(
dataset.make_dataset_resource(), self._iterator_resource, name=name)
@@ -2404,12 +2406,16 @@ def rejection_resample(dataset,
num_classes = (target_dist.shape[0].value or
array_ops.shape(target_dist)[0])
smoothing_constant = 10
- num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
- initial_value=array_ops.fill([num_classes],
- np.int64(smoothing_constant)),
- trainable=False,
- name="class_count",
- dtype=dtypes.int64)
+ # Disable device functions and colocation constraints so that the variable
+ # will be placed with the eventual DT_VARIANT dataset tensor.
+ with ops.colocate_with(None, ignore_existing=True):
+ num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
+ initial_value=array_ops.fill([num_classes],
+ np.int64(smoothing_constant)),
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name="local_class_count",
+ dtype=dtypes.int64)
def update_estimate_and_tile(c):
return array_ops.tile(
@@ -2519,7 +2525,13 @@ def read_batch_features(file_pattern,
dataset = reader(filenames, *reader_args)
else:
dataset = reader(filenames)
- dataset = dataset.repeat(num_epochs)
+ if dataset.output_types == (dtypes.string, dtypes.string):
+ dataset = dataset.map(lambda unused_k, v: v)
+ elif dataset.output_types != dtypes.string:
+ raise TypeError("`reader` must be a dataset of `tf.string` values, "
+ "or `(tf.string, tf.string)` key-value pairs.")
+ if num_epochs != 1:
+ dataset = dataset.repeat(num_epochs)
if randomize_input:
dataset = dataset.shuffle(capacity)
dataset = dataset.batch(batch_size)
@@ -2729,3 +2741,137 @@ def group_by_window(dataset,
assert window_size_func is not None
return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func)
+
+
+class _RestructuredDataset(Dataset):
+ """An internal helper for changing the structure and shape of a dataset."""
+
+ def __init__(self, dataset, output_types, output_shapes=None):
+ """Creates a new dataset with the given output types and shapes.
+
+ The given `dataset` must have a structure that is convertible:
+ * `dataset.output_types` must be the same as `output_types` module nesting.
+ * Each shape in `dataset.output_shapes` must be compatible with each shape
+ in `output_shapes` (if given).
+
+ Note: This helper permits "unsafe casts" for shapes, equivalent to using
+ `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+ Args:
+ dataset: A `Dataset` object.
+ output_types: A nested structure of `tf.DType` objects.
+ output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+ If omitted, the shapes will be inherited from `dataset`.
+
+ Raises:
+ ValueError: If either `output_types` or `output_shapes` is not compatible
+ with the structure of `dataset`.
+ """
+ super(_RestructuredDataset, self).__init__()
+ self._dataset = dataset
+
+ # Validate that the types are compatible.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ flat_original_types = nest.flatten(dataset.output_types)
+ flat_new_types = nest.flatten(output_types)
+ if flat_original_types != flat_new_types:
+ raise ValueError(
+ "Dataset with output types %r cannot be restructured to have output "
+ "types %r" % (dataset.output_types, output_types))
+
+ self._output_types = output_types
+
+ if output_shapes is None:
+ # Inherit shapes from the original `dataset`.
+ self._output_shapes = nest.pack_sequence_as(
+ output_types, nest.flatten(dataset.output_shapes))
+ else:
+ # Validate that the shapes are compatible.
+ nest.assert_same_structure(output_types, output_shapes)
+ flat_original_shapes = nest.flatten(dataset.output_shapes)
+ flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes):
+ if not original_shape.is_compatible_with(new_shape):
+ raise ValueError(
+ "Dataset with output shapes %r cannot be restructured to have "
+ "incompatible output shapes %r"
+ % (dataset.output_shapes, output_shapes))
+ self._output_shapes = nest.map_structure_up_to(
+ output_types, tensor_shape.as_shape, output_shapes)
+
+ def make_dataset_resource(self):
+ return self._dataset.make_dataset_resource()
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+
+def batch_and_drop_remainder(batch_size):
+ """A batching transformation that omits the final small batch (if present).
+
+ Like @{tf.contrib.data.Dataset.batch}, this transformation combines
+ consecutive elements of this dataset into batches. However, if the batch
+ size does not evenly divide the input dataset size, this transformation will
+ drop the final smaller element.
+
+ The following example illustrates the difference between this
+ transformation and `Dataset.batch()`:
+
+ ```python
+ dataset = tf.contrib.data.Dataset.range(200)
+ batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))
+ print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known)
+ ```
+
+ By contrast, `dataset.batch(128)` would yield a two-element dataset with
+ shapes `(128,)` and `(72,)`, so the batch dimension would not be statically
+ known.
+
+ Args:
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of this dataset to combine in a single batch.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.contrib.data.Dataset.apply}
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ tensor_batch_size = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+
+ batched = dataset.batch(tensor_batch_size)
+ flattened = _RestructuredDataset(batched,
+ tuple(nest.flatten(batched.output_types)))
+
+ def _predicate(*xs):
+ """Return `True` if this element is a full batch."""
+ # Extract the dynamic batch size from the first component of the flattened
+ # batched element.
+ first_component = xs[0]
+ first_component_batch_size = array_ops.shape(
+ first_component, out_type=dtypes.int64)[0]
+
+ return math_ops.equal(first_component_batch_size, tensor_batch_size)
+
+ filtered = flattened.filter(_predicate)
+
+ maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
+
+ def _set_first_dimension(shape):
+ return shape.merge_with(
+ tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
+
+ known_shapes = nest.map_structure(_set_first_dimension,
+ batched.output_shapes)
+ return _RestructuredDataset(filtered, batched.output_types, known_shapes)
+
+ return _apply_fn
diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/sloppy_ops.py
index 010bd31161..9e4708f919 100644
--- a/tensorflow/contrib/data/python/ops/sloppy_ops.py
+++ b/tensorflow/contrib/data/python/ops/sloppy_ops.py
@@ -17,10 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 0f97464ebc..27eaeb43b0 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -30,6 +30,7 @@ cuda_py_test(
":tfe",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
+ "//tensorflow/python/eager:test",
],
)
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index 2a9d7589d3..6ed7e0d60d 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.eager.python import tfe
-from tensorflow.python.platform import test
+from tensorflow.python.eager import test
class TFETest(test.TestCase):
@@ -31,6 +31,14 @@ class TFETest(test.TestCase):
devices = tfe.list_devices()
self.assertEqual(len(devices) - 1, tfe.num_gpus())
+ def testCallingEnableEagerExecutionMoreThanOnce(self):
+ # Note that eager.test.main() has already invoked enable_eager_exceution().
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"Do not call tfe\.%s more than once in the same process" %
+ tfe.enable_eager_execution.__name__):
+ tfe.enable_eager_execution()
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 2d2794e350..4b2050c932 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -25,12 +25,47 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
+ ":dnn",
":extenders",
":head",
],
)
py_library(
+ name = "dnn",
+ srcs = ["python/estimator/dnn.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:nn",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:dnn",
+ ],
+)
+
+py_test(
+ name = "dnn_test",
+ size = "small",
+ srcs = ["python/estimator/dnn_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dnn",
+ ":head",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python/estimator:dnn_testing_utils",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/feature_column",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "extenders",
srcs = [
"python/estimator/extenders.py",
@@ -68,6 +103,37 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:summary",
+ "//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:metric_keys",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/ops/losses",
+ ],
+)
+
+py_test(
+ name = "head_test",
+ size = "small",
+ srcs = ["python/estimator/head_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":head",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:metric_keys",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 346653c47f..7bb53d7715 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
@@ -29,7 +30,9 @@ _allowed_symbols = [
'add_metrics',
'binary_classification_head',
'multi_class_head',
+ 'multi_label_head',
'regression_head',
+ 'DNNEstimator',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py
new file mode 100644
index 0000000000..cf6e3329d2
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn.py
@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deep Neural Network estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import dnn as dnn_lib
+from tensorflow.python.ops import nn
+
+
+class DNNEstimator(estimator.Estimator):
+ """An estimator for TensorFlow DNN models with user-specified head.
+
+ Example:
+
+ ```python
+ sparse_feature_a = sparse_column_with_hash_bucket(...)
+ sparse_feature_b = sparse_column_with_hash_bucket(...)
+
+ sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a,
+ ...)
+ sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
+ ...)
+
+ estimator = DNNEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ hidden_units=[1024, 512, 256])
+
+ # Or estimator using the ProximalAdagradOptimizer optimizer with
+ # regularization.
+ estimator = DNNEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ optimizer=tf.train.ProximalAdagradOptimizer(
+ learning_rate=0.1,
+ l1_regularization_strength=0.001
+ ))
+
+ # Input builders
+ def input_fn_train: # returns x, y
+ pass
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval: # returns x, y
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+ def input_fn_predict: # returns x, None
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+
+ Input of `train` and `evaluate` should have following features,
+ otherwise there will be a `KeyError`:
+
+ * if `weight_column` is not `None`, a feature with
+ `key=weight_column` whose value is a `Tensor`.
+ * for each `column` in `feature_columns`:
+ - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
+ whose `value` is a `SparseTensor`.
+ - if `column` is a `_WeightedCategoricalColumn`, two features: the first
+ with `key` the id column name, the second with `key` the weight column
+ name. Both features' `value` must be a `SparseTensor`.
+ - if `column` is a `_DenseColumn`, a feature with `key=column.name`
+ whose `value` is a `Tensor`.
+
+ Loss and predicted output are determined by the specified head.
+ """
+
+ def __init__(self,
+ head,
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None):
+ """Initializes a `DNNClassifier` instance.
+
+ Args:
+ head: A `_Head` instance constructed with a method such as
+ `tf.contrib.estimator.multi_label_head`.
+ hidden_units: Iterable of number hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and
+ second one has 32.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `_FeatureColumn`.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults
+ to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ """
+ def _model_fn(features, labels, mode, config):
+ return dnn_lib._dnn_model_fn( # pylint: disable=protected-access
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=tuple(feature_columns or []),
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config)
+ super(DNNEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
new file mode 100644
index 0000000000..71f810acec
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
@@ -0,0 +1,153 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dnn.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import dnn
+from tensorflow.contrib.estimator.python.estimator import head as head_lib
+from tensorflow.python.estimator.canned import dnn_testing_utils
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+
+
+def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
+ """Returns a DNNEstimator that uses regression_head."""
+ return dnn.DNNEstimator(
+ head=head_lib.regression_head(
+ weight_column=weight_column, label_dimension=label_dimension),
+ *args, **kwargs)
+
+
+class DNNEstimatorEvaluateTest(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_estimator_fn)
+
+
+class DNNEstimatorPredictTest(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_estimator_fn)
+
+
+class DNNEstimatorTrainTest(
+ dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_estimator_fn)
+
+
+class DNNEstimatorIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(
+ self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
+ label_dimension, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))]
+ est = dnn.DNNEstimator(
+ head=head_lib.regression_head(label_dimension=label_dimension),
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, label_dimension), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ batch_size=batch_size,
+ shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 005de115d4..164dfe6e82 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -18,7 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.summary import summary
def multi_class_head(n_classes,
@@ -33,7 +46,7 @@ def multi_class_head(n_classes,
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
- `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
+ `binary_classification_head`).
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
@@ -123,3 +136,206 @@ def regression_head(weight_column=None,
weight_column=weight_column,
label_dimension=label_dimension,
head_name=head_name)
+
+
+# TODO(roumposg): Support label_vocabulary.
+def multi_label_head(n_classes,
+ weight_column=None,
+ thresholds=None,
+ head_name=None):
+ """Creates a `_Head` for multi-label classification.
+
+ Multi-label classification handles the case where each example may have zero
+ or more associated labels, from a discrete set. This is distinct from
+ `multi_class_head` which has exactly one label per example.
+
+ Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a
+ multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer
+ `SparseTensor` of class indices.
+
+ Args:
+ n_classes: Number of classes, must be greater than 1 (for 1 class, use
+ `binary_classification_head`).
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example.
+ thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
+ and recall metrics are evaluated for each threshold value. The threshold
+ is applied to the predicted probabilities, i.e. above the threshold is
+ `true`, below is `false`.
+ head_name: name of the head. If provided, summary and metrics keys will be
+ suffixed by `"/" + head_name`.
+
+ Returns:
+ An instance of `_Head` for multi-label classification.
+
+ Raises:
+ ValueError: if `n_classes` or `thresholds` is invalid.
+ """
+ thresholds = tuple(thresholds) if thresholds else tuple()
+ if n_classes is None or n_classes < 2:
+ raise ValueError(
+ 'n_classes must be > 1 for multi-class classification. '
+ 'Given: {}'.format(n_classes))
+ for threshold in thresholds:
+ if (threshold <= 0.0) or (threshold >= 1.0):
+ raise ValueError(
+ 'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
+ return _MultiLabelHead(
+ n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
+ head_name=head_name)
+
+
+class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
+ """`_Head` for multi-label classification."""
+
+ def __init__(self,
+ n_classes,
+ weight_column=None,
+ thresholds=None,
+ head_name=None):
+ self._n_classes = n_classes
+ self._weight_column = weight_column
+ self._thresholds = thresholds
+ self._head_name = head_name
+
+ @property
+ def logits_dimension(self):
+ return self._n_classes
+
+ def _process_labels(self, labels):
+ if isinstance(labels, sparse_tensor.SparseTensor):
+ return math_ops.to_int64(
+ sparse_ops.sparse_to_indicator(labels, self._n_classes))
+ msg = ('labels shape must be [batch_size, {}]. '
+ 'Given: ').format(self._n_classes)
+ labels_shape = array_ops.shape(labels)
+ check_rank_op = control_flow_ops.Assert(
+ math_ops.equal(array_ops.rank(labels), 2),
+ data=[msg, labels_shape])
+ check_label_dim = control_flow_ops.Assert(
+ math_ops.equal(labels_shape[-1], self._n_classes),
+ data=[msg, labels_shape])
+ with ops.control_dependencies([check_rank_op, check_label_dim]):
+ return array_ops.identity(labels)
+
+ def create_loss(self, features, mode, logits, labels):
+ """See `Head`."""
+ del mode, features # Unused for this head.
+ processed_labels = self._process_labels(labels)
+ unweighted_loss = losses.sigmoid_cross_entropy(
+ multi_class_labels=processed_labels, logits=logits,
+ reduction=losses.Reduction.NONE)
+ return head_lib.LossAndLabels(
+ unweighted_loss=unweighted_loss,
+ processed_labels=processed_labels)
+
+ def create_estimator_spec(
+ self, features, mode, logits, labels=None, train_op_fn=None):
+ """See `Head`."""
+ with ops.name_scope('head'):
+ logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access
+
+ # Predict.
+ pred_keys = prediction_keys.PredictionKeys
+ with ops.name_scope(None, 'predictions', (logits,)):
+ probabilities = math_ops.sigmoid(logits, name=pred_keys.PROBABILITIES)
+ predictions = {
+ pred_keys.LOGITS: logits,
+ pred_keys.PROBABILITIES: probabilities,
+ }
+ if mode == model_fn.ModeKeys.PREDICT:
+ return model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.PREDICT,
+ predictions=predictions,
+ export_outputs={
+ '': export_output.ClassificationOutput(scores=probabilities)
+ })
+
+ # Eval.
+ unweighted_loss, _ = self.create_loss(
+ features=features, mode=mode, logits=logits, labels=labels)
+ # Averages loss over classes.
+ per_example_loss = math_ops.reduce_mean(
+ unweighted_loss, axis=-1, keep_dims=True)
+ weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access
+ training_loss = losses.compute_weighted_loss(
+ per_example_loss, weights=weights, reduction=losses.Reduction.SUM)
+ if mode == model_fn.ModeKeys.EVAL:
+ return model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions=predictions,
+ loss=training_loss,
+ eval_metric_ops=self._eval_metric_ops(
+ labels=labels,
+ probabilities=probabilities,
+ weights=weights,
+ per_example_loss=per_example_loss))
+
+ # Train.
+ if train_op_fn is None:
+ raise ValueError('train_op_fn can not be None.')
+ with ops.name_scope(''):
+ summary.scalar(
+ head_lib._summary_key(self._head_name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access
+ training_loss)
+ summary.scalar(
+ head_lib._summary_key( # pylint:disable=protected-access
+ self._head_name, metric_keys.MetricKeys.LOSS_MEAN),
+ losses.compute_weighted_loss(
+ unweighted_loss, weights=weights,
+ reduction=losses.Reduction.MEAN))
+ return model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.TRAIN,
+ predictions=predictions,
+ loss=training_loss,
+ train_op=train_op_fn(training_loss))
+
+ def _eval_metric_ops(self, labels, probabilities, weights, per_example_loss):
+ """Returns a dict of metrics for eval_metric_ops."""
+ with ops.name_scope(
+ None, 'metrics', [labels, probabilities, weights, per_example_loss]):
+ keys = metric_keys.MetricKeys
+ metric_ops = {
+ # Estimator already adds a metric for loss.
+ head_lib._summary_key(self._head_name, keys.LOSS_MEAN): # pylint:disable=protected-access
+ metrics_lib.mean(
+ per_example_loss, weights=weights, name=keys.LOSS_MEAN),
+ head_lib._summary_key(self._head_name, keys.AUC): # pylint:disable=protected-access
+ metrics_lib.auc(
+ labels=labels, predictions=probabilities, weights=weights,
+ name=keys.AUC),
+ head_lib._summary_key(self._head_name, keys.AUC_PR): # pylint:disable=protected-access
+ metrics_lib.auc(
+ labels=labels, predictions=probabilities, weights=weights,
+ curve='PR', name=keys.AUC_PR),
+ }
+ for threshold in self._thresholds:
+ accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
+ metric_ops[head_lib._summary_key(self._head_name, accuracy_key)] = ( # pylint:disable=protected-access
+ head_lib._accuracy_at_threshold( # pylint:disable=protected-access
+ labels=labels,
+ predictions=probabilities,
+ weights=weights,
+ threshold=threshold,
+ name=accuracy_key))
+ # Precision for positive examples.
+ precision_key = keys.PRECISION_AT_THRESHOLD % threshold
+ metric_ops[head_lib._summary_key(self._head_name, precision_key)] = ( # pylint:disable=protected-access
+ head_lib._precision_at_threshold( # pylint:disable=protected-access
+ labels=labels,
+ predictions=probabilities,
+ weights=weights,
+ threshold=threshold,
+ name=precision_key))
+ # Recall for positive examples.
+ recall_key = keys.RECALL_AT_THRESHOLD % threshold
+ metric_ops[head_lib._summary_key(self._head_name, recall_key)] = ( # pylint:disable=protected-access
+ head_lib._recall_at_threshold( # pylint:disable=protected-access
+ labels=labels,
+ predictions=probabilities,
+ weights=weights,
+ threshold=threshold,
+ name=recall_key))
+ return metric_ops
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
new file mode 100644
index 0000000000..17753b4c9b
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -0,0 +1,570 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for head."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import head as head_lib
+from tensorflow.core.framework import summary_pb2
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import monitored_session
+
+
+_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+
+def _initialize_variables(test_case, scaffold):
+ scaffold.finalize()
+ test_case.assertIsNone(scaffold.init_feed_dict)
+ test_case.assertIsNone(scaffold.init_fn)
+ scaffold.init_op.run()
+ scaffold.ready_for_local_init_op.eval()
+ scaffold.local_init_op.run()
+ scaffold.ready_op.eval()
+ test_case.assertIsNotNone(scaffold.saver)
+
+
+def _assert_simple_summaries(test_case, expected_summaries, summary_str,
+ tol=1e-6):
+ """Assert summary the specified simple values.
+
+ Args:
+ test_case: test case.
+ expected_summaries: Dict of expected tags and simple values.
+ summary_str: Serialized `summary_pb2.Summary`.
+ tol: Tolerance for relative and absolute.
+ """
+ summary = summary_pb2.Summary()
+ summary.ParseFromString(summary_str)
+ test_case.assertAllClose(expected_summaries, {
+ v.tag: v.simple_value for v in summary.value
+ }, rtol=tol, atol=tol)
+
+
+def _assert_no_hooks(test_case, spec):
+ test_case.assertAllEqual([], spec.training_chief_hooks)
+ test_case.assertAllEqual([], spec.training_hooks)
+
+
+def _sigmoid(logits):
+ return 1 / (1 + np.exp(-logits))
+
+
+def _sigmoid_cross_entropy(labels, logits):
+ sigmoid_logits = _sigmoid(logits)
+ return (-labels * np.log(sigmoid_logits)
+ -(1 - labels) * np.log(1 - sigmoid_logits))
+
+
+class MultiLabelHead(test.TestCase):
+
+ def setUp(self):
+ ops.reset_default_graph()
+
+ def test_n_classes_is_none(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'n_classes must be > 1 for multi-class classification\. Given: None'):
+ head_lib.multi_label_head(n_classes=None)
+
+ def test_n_classes_is_1(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'n_classes must be > 1 for multi-class classification\. Given: 1'):
+ head_lib.multi_label_head(n_classes=1)
+
+ def test_threshold_too_small(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'thresholds must be in \(0, 1\) range\. Given: 0\.0'):
+ head_lib.multi_label_head(n_classes=2, thresholds=[0., 0.5])
+
+ def test_threshold_too_large(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'thresholds must be in \(0, 1\) range\. Given: 1\.0'):
+ head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0])
+
+ def test_predict(self):
+ n_classes = 4
+ head = head_lib.multi_label_head(n_classes)
+ self.assertEqual(n_classes, head.logits_dimension)
+
+ logits = np.array(
+ [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
+ expected_probabilities = _sigmoid(logits)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
+
+ self.assertItemsEqual(
+ ('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys())
+
+ # Assert predictions and export_outputs.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ predictions = sess.run(spec.predictions)
+ self.assertAllClose(logits,
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose(
+ expected_probabilities,
+ predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+
+ self.assertAllClose(
+ expected_probabilities,
+ sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
+
+ def test_weight_should_not_impact_prediction(self):
+ n_classes = 4
+ head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
+ self.assertEqual(n_classes, head.logits_dimension)
+
+ logits = np.array(
+ [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
+ expected_probabilities = _sigmoid(logits)
+
+ weights_2x1 = [[1.], [2.]]
+ spec = head.create_estimator_spec(
+ features={
+ 'x': np.array(((42,),), dtype=np.int32),
+ 'label_weights': weights_2x1,
+ },
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
+
+ # Assert predictions and export_outputs.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ predictions = sess.run(spec.predictions)
+ self.assertAllClose(logits,
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose(
+ expected_probabilities,
+ predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+
+ def test_eval_create_loss(self):
+ """Tests head.create_loss for eval mode."""
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ expected_unweighted_loss = _sigmoid_cross_entropy(
+ labels=labels, logits=logits)
+ actual_unweighted_loss, _ = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(
+ expected_unweighted_loss, actual_unweighted_loss.eval())
+
+ def test_eval_create_loss_large_logits(self):
+ """Tests head.create_loss for eval mode and large logits."""
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # For large logits, this is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits
+ expected_unweighted_loss = np.array(
+ [[10., 10.], [15., 0.]], dtype=np.float32)
+ actual_unweighted_loss, _ = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(
+ expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
+
+ def test_eval_create_loss_sparse_labels(self):
+ """Tests head.create_loss for eval mode and sparse labels."""
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = sparse_tensor.SparseTensor(
+ values=[0, 0, 1],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ expected_labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # For large logits, this is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits
+ expected_unweighted_loss = np.array(
+ [[10., 10.], [15., 0.]], dtype=np.float32)
+ actual_unweighted_loss, actual_labels = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllEqual(expected_labels, actual_labels.eval())
+ self.assertAllClose(
+ expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
+
+ def test_eval_create_loss_labels_wrong_shape(self):
+ """Tests head.create_loss for eval mode when labels has the wrong shape."""
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
+ labels_placeholder = array_ops.placeholder(dtype=dtypes.int64)
+ actual_unweighted_loss, _ = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels_placeholder)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'):
+ actual_unweighted_loss.eval(
+ {labels_placeholder: np.array([[1], [1]], dtype=np.int64)})
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'):
+ actual_unweighted_loss.eval(
+ {labels_placeholder: np.array([1, 1], dtype=np.int64)})
+
+ def test_eval(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Average over classes, and sum over examples.
+ expected_loss = (
+ np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
+ )
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss / 2,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ }
+
+ # Assert spec contains expected tensors.
+ self.assertIsNotNone(spec.loss)
+ self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
+ self.assertIsNone(spec.train_op)
+ self.assertIsNone(spec.export_outputs)
+ _assert_no_hooks(self, spec)
+
+ # Assert predictions, loss, and metrics.
+ tol = 1e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+ update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+ loss, metrics = sess.run((spec.loss, update_ops))
+ self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+ # Check results of both update (in `metrics`) and value ops.
+ self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
+ self.assertAllClose(
+ expected_metrics, {k: value_ops[k].eval() for k in value_ops},
+ rtol=tol,
+ atol=tol)
+
+ def test_eval_with_thresholds(self):
+ n_classes = 2
+ thresholds = [0.25, 0.5, 0.75]
+ head = head_lib.multi_label_head(n_classes, thresholds=thresholds)
+
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Average over classes, and sum over examples.
+ expected_loss = (
+ np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
+ )
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss / 2,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
+ keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
+ keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
+ keys.ACCURACY_AT_THRESHOLD % thresholds[1]: 1. / 4.,
+ keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1. / 2.,
+ keys.RECALL_AT_THRESHOLD % thresholds[1]: 1. / 3.,
+ keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 2. / 4.,
+ keys.PRECISION_AT_THRESHOLD % thresholds[2]: 1. / 1.,
+ keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,
+ }
+
+ # Assert spec contains expected tensors.
+ self.assertIsNotNone(spec.loss)
+ self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
+ self.assertIsNone(spec.train_op)
+ self.assertIsNone(spec.export_outputs)
+ _assert_no_hooks(self, spec)
+
+ # Assert predictions, loss, and metrics.
+ tol = 1e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+ update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+ loss, metrics = sess.run((spec.loss, update_ops))
+ self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+ # Check results of both update (in `metrics`) and value ops.
+ self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
+ self.assertAllClose(
+ expected_metrics, {k: value_ops[k].eval() for k in value_ops},
+ rtol=tol,
+ atol=tol)
+
+ def test_eval_with_weights(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
+
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # For large logits, sigmoid cross entropy loss is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits =>
+ # expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # Average over classes, weighted sum over examples.
+ expected_loss = 25.
+
+ spec = head.create_estimator_spec(
+ features={
+ 'x': np.array([[41], [42]], dtype=np.int32),
+ 'label_weights': np.array([[1.], [2.]], dtype=np.float32),
+ },
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over weighted examples.
+ keys.LOSS_MEAN: expected_loss / 3,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.2000,
+ keys.AUC_PR: 0.7833,
+ }
+
+ # Assert spec contains expected tensors.
+ self.assertIsNotNone(spec.loss)
+ self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
+ self.assertIsNone(spec.train_op)
+ self.assertIsNone(spec.export_outputs)
+ _assert_no_hooks(self, spec)
+
+ # Assert predictions, loss, and metrics.
+ tol = 1e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+ update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+ loss, metrics = sess.run((spec.loss, update_ops))
+ self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+ # Check results of both update (in `metrics`) and value ops.
+ self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
+ self.assertAllClose(
+ expected_metrics, {k: value_ops[k].eval() for k in value_ops},
+ rtol=tol,
+ atol=tol)
+
+ def test_train_create_loss_large_logits(self):
+ """Tests head.create_loss for train mode and large logits."""
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # For large logits, this is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits
+ expected_unweighted_loss = np.array(
+ [[10., 10.], [15., 0.]], dtype=np.float32)
+ actual_unweighted_loss, _ = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels)
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(
+ expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
+
+ def test_train(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes)
+
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # For large logits, sigmoid cross entropy loss is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits =>
+ # expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # Average over classes, sum over weights.
+ expected_loss = 17.5
+ expected_train_result = 'my_train_op'
+ def _train_op_fn(loss):
+ return string_ops.string_join(
+ [constant_op.constant(expected_train_result),
+ string_ops.as_string(loss, precision=3)])
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+
+ self.assertIsNotNone(spec.loss)
+ self.assertEqual({}, spec.eval_metric_ops)
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNone(spec.export_outputs)
+ _assert_no_hooks(self, spec)
+
+ # Assert predictions, loss, train_op, and summaries.
+ tol = 1e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNotNone(spec.scaffold.summary_op)
+ loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+ spec.scaffold.summary_op))
+ self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+ self.assertEqual(
+ six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+ train_result)
+ _assert_simple_summaries(self, {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ # Average loss over examples.
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+ }, summary_str, tol)
+
+ def test_train_with_weights(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
+
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # For large logits, sigmoid cross entropy loss is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits =>
+ # expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # Average over classes, weighted sum over examples.
+ expected_loss = 25.
+ expected_train_result = 'my_train_op'
+ def _train_op_fn(loss):
+ return string_ops.string_join(
+ [constant_op.constant(expected_train_result),
+ string_ops.as_string(loss, precision=3)])
+
+ spec = head.create_estimator_spec(
+ features={
+ 'x': np.array([[41], [42]], dtype=np.int32),
+ 'label_weights': np.array([[1.], [2.]], dtype=np.float32),
+ },
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+
+ self.assertIsNotNone(spec.loss)
+ self.assertEqual({}, spec.eval_metric_ops)
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNone(spec.export_outputs)
+ _assert_no_hooks(self, spec)
+
+ # Assert predictions, loss, train_op, and summaries.
+ tol = 1e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNotNone(spec.scaffold.summary_op)
+ loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+ spec.scaffold.summary_op))
+ self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+ self.assertEqual(
+ six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+ train_result)
+ _assert_simple_summaries(self, {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ # Average loss over weighted examples.
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
+ }, summary_str, tol)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 638a4be446..c468c544d3 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -221,6 +221,7 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
+ "notsan",
],
)
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index c3ae738acf..39acfcc187 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -14,10 +14,12 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":eval",
":features",
":losses",
":namedtuples",
":train",
+ "//tensorflow/python:util",
],
)
@@ -74,6 +76,18 @@ py_test(
)
py_library(
+ name = "eval",
+ srcs = ["python/eval/__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":classifier_metrics",
+ ":eval_utils",
+ ":summaries",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
name = "losses",
srcs = ["python/losses/__init__.py"],
srcs_version = "PY2AND3",
@@ -257,6 +271,105 @@ py_test(
],
)
+py_library(
+ name = "classifier_metrics",
+ srcs = [
+ "python/eval/python/classifier_metrics.py",
+ "python/eval/python/classifier_metrics_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:image_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "classifier_metrics_test",
+ srcs = ["python/eval/python/classifier_metrics_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":classifier_metrics",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "eval_utils",
+ srcs = [
+ "python/eval/python/eval_utils.py",
+ "python/eval/python/eval_utils_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "eval_utils_test",
+ srcs = ["python/eval/python/eval_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":eval_utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "summaries",
+ srcs = [
+ "python/eval/python/summaries.py",
+ "python/eval/python/summaries_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":eval_utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:util",
+ "//tensorflow/python/ops/losses",
+ ],
+)
+
+py_test(
+ name = "summaries_test",
+ srcs = ["python/eval/python/summaries_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":namedtuples",
+ ":summaries",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py
index 3c423e72d0..67eee771d0 100644
--- a/tensorflow/contrib/gan/__init__.py
+++ b/tensorflow/contrib/gan/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
+from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
from tensorflow.contrib.gan.python import namedtuples
@@ -32,6 +33,7 @@ from tensorflow.contrib.gan.python.train import *
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
+ 'eval',
'features',
'losses',
]
diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py
new file mode 100644
index 0000000000..bb80461878
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFGAN grouped API. Please see README.md for details and usage."""
+# pylint: disable=,wildcard-import,unused-import
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Collapse eval into a single namespace.
+from tensorflow.contrib.gan.python.eval.python import classifier_metrics
+from tensorflow.contrib.gan.python.eval.python import eval_utils
+from tensorflow.contrib.gan.python.eval.python import summaries
+
+from tensorflow.contrib.gan.python.eval.python.classifier_metrics import *
+from tensorflow.contrib.gan.python.eval.python.eval_utils import *
+from tensorflow.contrib.gan.python.eval.python.summaries import *
+# pylint: enable=wildcard-import,unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'classifier_metrics',
+ 'summaries',
+ 'eval_utils',
+] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py
new file mode 100644
index 0000000000..1c872626a9
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model evaluation tools for TFGAN."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.eval.python.classifier_metrics_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+__all__ = classifier_metrics_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
new file mode 100644
index 0000000000..727dc81be2
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -0,0 +1,401 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model evaluation tools for TFGAN.
+
+These methods come from https://arxiv.org/abs/1606.03498 and
+https://arxiv.org/abs/1706.08500.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import sys
+import tarfile
+
+from six.moves import urllib
+
+from tensorflow.contrib.layers.python.layers import layers
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import gfile
+
+
+__all__ = [
+ 'get_graph_def_from_disk',
+ 'preprocess_image',
+ 'run_image_classifier',
+ 'run_inception',
+ 'inception_score',
+ 'classifier_score',
+ 'frechet_inception_distance',
+ 'frechet_classifier_distance',
+]
+
+
+INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v3_2017_09_13.tar.gz'
+INCEPTION_FROZEN_GRAPH = 'frozen_inception_v3.pb'
+INCEPTION_V3_INPUT = 'inputs'
+INCEPTION_V3_OUTPUT = 'InceptionV3/Logits/SpatialSqueeze:0'
+INCEPTION_V3_FINAL_POOL = 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0'
+_INCEPTION_V3_NUM_CLASSES = 1001
+_INCEPTION_V3_FINAL_POOL_SIZE = 2048
+INCEPTION_V3_DEFAULT_IMG_SIZE = 299
+
+
+def _validate_images(images, image_size):
+ images = ops.convert_to_tensor(images)
+ images.shape.with_rank(4)
+ images.shape.assert_is_compatible_with(
+ [None, image_size, image_size, None])
+ return images
+
+
+def _matrix_square_root(mat, eps=1e-10):
+ """Compute symmetric square root of matrix.
+
+ Equivalent to matrix square root when matrix is invertible; note that this is
+ different from an elementwise square root. We want to compute M' where M' =
+ sqrt(mat) such that M' * M' = mat.
+
+ Args:
+ mat: Matrix to take the square root of.
+ eps: Small epsilon such that any element less than eps will not be square
+ rooted to guard against numerical instability.
+
+ Returns:
+ Matrix square root of mat.
+ """
+ s, u, v = linalg_ops.svd(mat)
+ # sqrt is unstable around 0, just use 0 in such case
+ si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
+ return math_ops.matmul(
+ math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)
+
+
+# Convenience preprocessing function, with fixed defaults.
+# NOTE: Floating-point inputs are expected to be in [0, 1].
+# Copied from /tensorflow_models/slim/preprocessing/inception_preprocessing.py.
+def preprocess_image(
+ image, height=INCEPTION_V3_DEFAULT_IMG_SIZE,
+ width=INCEPTION_V3_DEFAULT_IMG_SIZE, central_fraction=0.875, scope=None):
+ """Prepare one image for evaluation.
+
+ If height and width are specified it would output an image with that size by
+ applying resize_bilinear.
+
+ If central_fraction is specified it would crop the central fraction of the
+ input image.
+
+ Args:
+ image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
+ [0, 1], otherwise it would converted to tf.float32 assuming that the range
+ is [0, MAX], where MAX is largest positive representable number for
+ int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
+ height: integer
+ width: integer
+ central_fraction: Optional Float, fraction of the image to crop.
+ scope: Optional scope for name_scope.
+ Returns:
+ 3-D float Tensor of prepared image.
+ """
+ with ops.name_scope(scope, 'eval_image', [image, height, width]):
+ if image.dtype != dtypes.float32:
+ image = image_ops.convert_image_dtype(image, dtype=dtypes.float32)
+ # Crop the central region of the image with an area containing 87.5% of
+ # the original image.
+ image = image_ops.central_crop(image, central_fraction=central_fraction)
+
+ # Resize the image to the specified height and width.
+ image = array_ops.expand_dims(image, 0)
+ image = image_ops.resize_bilinear(image, [height, width],
+ align_corners=False)
+ image = array_ops.squeeze(image, [0])
+ image = (image - 0.5) * 2.0
+ return image
+
+
+def _kl_divergence(p, p_logits, q):
+ """Computes the Kullback-Liebler divergence between p and q.
+
+ This function uses p's logits in some places to improve numerical stability.
+
+ Specifically:
+
+ KL(p || q) = sum[ p * log(p / q) ]
+ = sum[ p * ( log(p) - log(q) ) ]
+ = sum[ p * ( log_softmax(p_logits) - log(q) ) ]
+
+ Args:
+ p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch
+ example and `j` corresponds to the probability of being in class `j`.
+ p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
+ q: A 1-D floating-point Tensor, where q_j corresponds to the probability
+ of class `j`.
+
+ Returns:
+ KL divergence between two distributions. Output dimension is 1D, one entry
+ per distribution in `p`.
+
+ Raises:
+ ValueError: If any of the inputs aren't floating-point.
+ ValueError: If p or p_logits aren't 2D.
+ ValueError: If q isn't 1D.
+ """
+ for tensor in [p, p_logits, q]:
+ if not tensor.dtype.is_floating:
+ raise ValueError('Input %s must be floating type.', tensor.name)
+ p.shape.assert_has_rank(2)
+ p_logits.shape.assert_has_rank(2)
+ q.shape.assert_has_rank(1)
+ return math_ops.reduce_sum(
+ p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)
+
+
+def get_graph_def_from_disk(filename):
+ """Get a GraphDef proto from a disk location."""
+ with gfile.FastGFile(filename, 'rb') as f:
+ return graph_pb2.GraphDef.FromString(f.read())
+
+
+def get_graph_def_from_url_tarball(url, filename):
+ """Get a GraphDef proto from a tarball on the web."""
+ def _progress(count, block_size, total_size):
+ sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+ url, float(count * block_size) / float(total_size) * 100.0))
+ sys.stdout.flush()
+ tar_filename, _ = urllib.request.urlretrieve(url, reporthook=_progress)
+ with tarfile.open(tar_filename, 'r:gz') as tar:
+ proto_str = tar.extractfile(filename).read()
+ return graph_pb2.GraphDef.FromString(proto_str)
+
+
+def _default_graph_def_fn():
+ return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH)
+
+
+def run_inception(images,
+ graph_def=None,
+ default_graph_def_fn=_default_graph_def_fn,
+ image_size=INCEPTION_V3_DEFAULT_IMG_SIZE,
+ input_tensor=INCEPTION_V3_INPUT,
+ output_tensor=INCEPTION_V3_OUTPUT):
+ """Run images through a pretrained Inception classifier.
+
+ Args:
+ images: Input tensors. Must be [batch, height, width, channels]. Input shape
+ and values must be in [-1, 1], which can be achieved using
+ `preprocess_image`.
+ graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
+ call `default_graph_def_fn` to get GraphDef.
+ default_graph_def_fn: A function that returns a GraphDef. Used if
+ `graph_def` is `None. By default, returns a pretrained InceptionV3 graph.
+ image_size: Required image width and height. See unit tests for the default
+ values.
+ input_tensor: Name of input Tensor.
+ output_tensor: Name of output Tensor. This function will compute activations
+ at the specified layer. Examples include INCEPTION_V3_OUTPUT and
+ INCEPTION_V3_FINAL_POOL which would result in this function computing
+ the final logits or the penultimate pooling layer.
+
+ Returns:
+ Logits.
+
+ Raises:
+ ValueError: If images are not the correct size.
+ ValueError: If neither `graph_def` nor `default_graph_def_fn` are provided.
+ """
+ images = _validate_images(images, image_size)
+
+ if graph_def is None:
+ if default_graph_def_fn is None:
+ raise ValueError('If `graph_def` is `None`, must provide '
+ '`default_graph_def_fn`.')
+ graph_def = default_graph_def_fn()
+
+ activations = run_image_classifier(images, graph_def, input_tensor,
+ output_tensor)
+ if array_ops.rank(activations) != 2:
+ activations = layers.flatten(activations)
+ return activations
+
+
+def run_image_classifier(tensor, graph_def, input_tensor,
+ output_tensor, scope='RunClassifier'):
+ """Runs a network from a frozen graph.
+
+ Args:
+ tensor: An Input tensor.
+ graph_def: A GraphDef proto.
+ input_tensor: Name of input tensor in graph def.
+ output_tensor: Name of output tensor in graph def.
+ scope: Name scope for classifier.
+
+ Returns:
+ Classifier output. Shape depends on the classifier used, but is often
+ [batch, classes].
+
+ Raises:
+ ValueError: If `image_size` is not `None`, and `tensor` are not the correct
+ size.
+ """
+ input_map = {input_tensor: tensor}
+ return_elements = [output_tensor]
+ classifier_output = importer.import_graph_def(
+ graph_def, input_map, return_elements, name=scope)[0]
+
+ return classifier_output
+
+
+def classifier_score(images, classifier_fn, num_batches=1):
+ """Classifier score for evaluating a conditional generative model.
+
+ This is based on the Inception Score, but for an arbitrary classifier.
+
+ This technique is described in detail in https://arxiv.org/abs/1606.03498. In
+ summary, this function calculates
+
+ exp( E[ KL(p(y|x) || p(y)) ] )
+
+ which captures how different the network's classification prediction is from
+ the prior distribution over classes.
+
+ Args:
+ images: Images to calculate the classifier score for.
+ classifier_fn: A function that takes images and produces logits based on a
+ classifier.
+ num_batches: Number of batches to split `generated_images` in to in order to
+ efficiently run them through the classifier network.
+
+ Returns:
+ The classifier score. A floating-point scalar.
+ """
+ generated_images_list = array_ops.split(
+ images, num_or_size_splits=num_batches)
+
+ # Compute the classifier splits using the memory-efficient `map_fn`.
+ logits = functional_ops.map_fn(
+ fn=classifier_fn,
+ elems=array_ops.stack(generated_images_list),
+ parallel_iterations=1,
+ back_prop=False,
+ swap_memory=True,
+ name='RunClassifier')
+ logits = array_ops.concat(array_ops.unstack(logits), 0)
+ logits.shape.assert_has_rank(2)
+ p = nn_ops.softmax(logits)
+ q = math_ops.reduce_mean(p, axis=0)
+ kl = _kl_divergence(p, logits, q)
+ kl.shape.assert_has_rank(1)
+ log_score = math_ops.reduce_mean(kl)
+
+ return math_ops.exp(log_score)
+
+
+inception_score = functools.partial(
+ classifier_score,
+ classifier_fn=functools.partial(
+ run_inception, output_tensor=INCEPTION_V3_OUTPUT))
+
+
+def frechet_classifier_distance(real_images,
+ generated_images,
+ classifier_fn,
+ num_batches=1):
+ """Classifier distance for evaluating a conditional generative model.
+
+ This is based on the Frechet Inception distance, but for an arbitrary
+ classifier.
+
+ This technique is described in detail in https://arxiv.org/abs/1706.08500.
+ Given two Gaussian distribution with means m and m_w and covariance matrices
+ C and C_w, this function calcuates
+
+ |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
+
+ which captures how different the distributions of real images and generated
+ images (or more accurately, their visual features) are. Note that unlike the
+ Inception score, this is a true distance and utilizes information about real
+ world images.
+
+ Args:
+ real_images: Real images to use to compute Frechet Inception distance.
+ generated_images: Generated images to use to compute Frechet Inception
+ distance.
+ classifier_fn: A function that takes images and produces activations
+ based on a classifier.
+ num_batches: Number of batches to split images in to in order to
+ efficiently run them through the classifier network.
+
+ Returns:
+ The Frechet Inception distance. A floating-point scalar.
+ """
+
+ real_images_list = array_ops.split(
+ real_images, num_or_size_splits=num_batches)
+ generated_images_list = array_ops.split(
+ generated_images, num_or_size_splits=num_batches)
+
+ imgs = array_ops.stack(real_images_list + generated_images_list)
+
+ # Compute the activations using the memory-efficient `map_fn`.
+ activations = functional_ops.map_fn(
+ fn=classifier_fn,
+ elems=imgs,
+ parallel_iterations=1,
+ back_prop=False,
+ swap_memory=True,
+ name='RunClassifier')
+
+ # Split the activations by the real and generated images.
+ real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
+
+ # Ensure the activations have the right shapes.
+ real_a = array_ops.concat(array_ops.unstack(real_a), 0)
+ gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
+ real_a.shape.assert_has_rank(2)
+ gen_a.shape.assert_has_rank(2)
+
+ # Compute mean and covariance matrices of activations.
+ m = math_ops.reduce_mean(real_a, 0)
+ m_v = math_ops.reduce_mean(gen_a, 0)
+ dim = math_ops.to_float(array_ops.shape(m)[0])
+ sigma = math_ops.matmul(real_a - m, real_a - m, transpose_b=True) / dim
+ sigma_v = math_ops.matmul(gen_a - m, gen_a - m, transpose_b=True) / dim
+
+ # Take matrix square root of the product of covariance matrices.
+ sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))
+
+ # Compute the two components of FID.
+ trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
+ mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm.
+ fid = trace + mean
+
+ return fid
+
+
+frechet_inception_distance = functools.partial(
+ frechet_classifier_distance,
+ classifier_fn=functools.partial(
+ run_inception, output_tensor=INCEPTION_V3_FINAL_POOL))
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
new file mode 100644
index 0000000000..d7bfa1ae28
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
@@ -0,0 +1,316 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TFGAN classifier_metrics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tarfile
+import tempfile
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl as classifier_metrics
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+mock = test.mock
+
+
+def _numpy_softmax(x):
+ e_x = np.exp(x - np.max(x, axis=1)[:, None])
+ return e_x / np.sum(e_x, axis=1)[:, None]
+
+
+def _expected_inception_score(logits):
+ p = _numpy_softmax(logits)
+ q = np.expand_dims(np.mean(p, 0), 0)
+ per_example_logincscore = np.sum(p * (np.log(p) - np.log(q)), 1)
+ return np.exp(np.mean(per_example_logincscore))
+
+
+def _approximate_matrix_sqrt(mat, eps=1e-8):
+ s, u, v = np.linalg.svd(mat)
+ si = np.where(s < eps, s, np.sqrt(s))
+ return np.dot(np.dot(u, np.diag(si)), v.T)
+
+
+def _expected_fid(real_imgs, gen_imgs):
+ real_imgs = np.asarray(real_imgs)
+ gen_imgs = np.asarray(gen_imgs)
+ m = np.mean(real_imgs, axis=0)
+ m_v = np.mean(gen_imgs, axis=0)
+ dim = float(m.shape[0])
+ sigma = np.dot((real_imgs - m), (real_imgs - m).T) / dim
+ sigma_v = np.dot((gen_imgs - m), (gen_imgs - m).T) / dim
+ sqcc = _approximate_matrix_sqrt(np.dot(sigma, sigma_v))
+ mean = np.square(np.linalg.norm(m - m_v))
+ trace = np.trace(sigma + sigma_v - 2 * sqcc)
+ fid = mean + trace
+ return fid
+
+
+# A dummy GraphDef string with the minimum number of Ops.
+graphdef_string = """
+node {
+ name: "inputs"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 299
+ }
+ dim {
+ size: 299
+ }
+ dim {
+ size: 3
+ }
+ }
+ }
+ }
+}
+node {
+ name: "InceptionV3/Logits/SpatialSqueeze"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1001
+ }
+ }
+ }
+ }
+}
+node {
+ name: "InceptionV3/Logits/AvgPool_1a_8x8/AvgPool"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 2048
+ }
+ }
+ }
+ }
+}
+versions {
+ producer: 24
+}
+"""
+
+
+def _get_dummy_graphdef():
+ dummy_graphdef = graph_pb2.GraphDef()
+ text_format.Merge(graphdef_string, dummy_graphdef)
+ return dummy_graphdef
+
+
+def _run_with_mock(function, *args, **kwargs):
+ with mock.patch.object(
+ classifier_metrics,
+ 'get_graph_def_from_url_tarball') as mock_tarball_getter:
+ mock_tarball_getter.return_value = _get_dummy_graphdef()
+ return function(*args, **kwargs)
+
+
+class ClassifierMetricsTest(test.TestCase):
+
+ def test_run_inception_graph(self):
+ """Test `run_inception` graph construction."""
+ batch_size = 7
+ img = array_ops.ones([batch_size, 299, 299, 3])
+ logits = _run_with_mock(classifier_metrics.run_inception, img)
+
+ self.assertTrue(isinstance(logits, ops.Tensor))
+ logits.shape.assert_is_compatible_with([batch_size, 1001])
+
+ # Check that none of the model variables are trainable.
+ self.assertListEqual([], variables.trainable_variables())
+
+ def test_run_inception_graph_pool_output(self):
+ """Test `run_inception` graph construction with pool output."""
+ batch_size = 3
+ img = array_ops.ones([batch_size, 299, 299, 3])
+ pool = _run_with_mock(
+ classifier_metrics.run_inception, img,
+ output_tensor=classifier_metrics.INCEPTION_V3_FINAL_POOL)
+
+ self.assertTrue(isinstance(pool, ops.Tensor))
+ pool.shape.assert_is_compatible_with([batch_size, 2048])
+
+ # Check that none of the model variables are trainable.
+ self.assertListEqual([], variables.trainable_variables())
+
+ def test_inception_score_graph(self):
+ """Test `inception_score` graph construction."""
+ score = _run_with_mock(classifier_metrics.inception_score,
+ array_ops.zeros([6, 299, 299, 3]), num_batches=3)
+ self.assertTrue(isinstance(score, ops.Tensor))
+ score.shape.assert_has_rank(0)
+
+ # Check that none of the model variables are trainable.
+ self.assertListEqual([], variables.trainable_variables())
+
+ def test_frechet_inception_distance_graph(self):
+ """Test `frechet_inception_distance` graph construction."""
+ img = array_ops.ones([7, 299, 299, 3])
+ distance = _run_with_mock(
+ classifier_metrics.frechet_inception_distance, img, img)
+
+ self.assertTrue(isinstance(distance, ops.Tensor))
+ distance.shape.assert_has_rank(0)
+
+ # Check that none of the model variables are trainable.
+ self.assertListEqual([], variables.trainable_variables())
+
+ def test_run_inception_multicall(self):
+ """Test that `run_inception` can be called multiple times."""
+ for batch_size in (7, 3, 2):
+ img = array_ops.ones([batch_size, 299, 299, 3])
+ _run_with_mock(classifier_metrics.run_inception, img)
+
+ def test_invalid_input(self):
+ """Test that functions properly fail on invalid input."""
+ with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
+ classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))
+
+ p = array_ops.zeros([8, 10])
+ p_logits = array_ops.zeros([8, 10])
+ q = array_ops.zeros([10])
+ with self.assertRaisesRegexp(ValueError, 'must be floating type'):
+ classifier_metrics._kl_divergence(
+ array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)
+
+ with self.assertRaisesRegexp(ValueError, 'must be floating type'):
+ classifier_metrics._kl_divergence(
+ p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)
+
+ with self.assertRaisesRegexp(ValueError, 'must be floating type'):
+ classifier_metrics._kl_divergence(
+ p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))
+
+ with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
+ classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)
+
+ with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
+ classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)
+
+ with self.assertRaisesRegexp(ValueError, 'must have rank 1'):
+ classifier_metrics._kl_divergence(p, p_logits, array_ops.zeros([10, 8]))
+
+ def test_inception_score_value(self):
+ """Test that `inception_score` gives the correct value."""
+ logits = np.array([np.array([1, 2] * 500 + [4]),
+ np.array([4, 5] * 500 + [6])])
+ unused_image = array_ops.zeros([2, 299, 299, 3])
+ incscore = _run_with_mock(classifier_metrics.inception_score, unused_image)
+
+ with self.test_session(use_gpu=True) as sess:
+ incscore_np = sess.run(incscore, {'concat:0': logits})
+
+ self.assertAllClose(_expected_inception_score(logits), incscore_np)
+
+ def test_frechet_inception_distance_value(self):
+ """Test that `frechet_inception_distance` gives the correct value."""
+ np.random.seed(0)
+ test_pool_real_a = np.random.randn(5, 2048)
+ test_pool_gen_a = np.random.randn(5, 2048)
+ unused_image = array_ops.zeros([5, 299, 299, 3])
+
+ pool_a = np.stack((test_pool_real_a, test_pool_gen_a))
+ fid_op = _run_with_mock(classifier_metrics.frechet_inception_distance,
+ unused_image, unused_image)
+ activations_tensor = 'RunClassifier/TensorArrayStack/TensorArrayGatherV3:0'
+
+ with self.test_session() as sess:
+ actual_fid = sess.run(fid_op, {activations_tensor: pool_a})
+
+ expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
+ self.assertAllClose(expected_fid, actual_fid, 0.01)
+
+ def test_preprocess_image_graph(self):
+ """Test `preprocess_image` graph construction."""
+ incorrectly_sized_image = array_ops.zeros([520, 240, 3])
+ correct_image = classifier_metrics.preprocess_image(
+ image=incorrectly_sized_image)
+ _run_with_mock(classifier_metrics.run_inception,
+ array_ops.expand_dims(correct_image, 0))
+
+ def test_get_graph_def_from_url_tarball(self):
+ """Test `get_graph_def_from_url_tarball`."""
+ # Write dummy binary GraphDef to tempfile.
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+ tmp_file.write(_get_dummy_graphdef().SerializeToString())
+ relative_path = os.path.relpath(tmp_file.name)
+
+ # Create gzip tarball.
+ tar_dir = tempfile.mkdtemp()
+ tar_filename = os.path.join(tar_dir, 'tmp.tar.gz')
+ with tarfile.open(tar_filename, 'w:gz') as tar:
+ tar.add(relative_path)
+
+ with mock.patch.object(classifier_metrics, 'urllib') as mock_urllib:
+ mock_urllib.request.urlretrieve.return_value = tar_filename, None
+ graph_def = classifier_metrics.get_graph_def_from_url_tarball(
+ 'unused_url', relative_path)
+
+ self.assertIsInstance(graph_def, graph_pb2.GraphDef)
+ self.assertEqual(_get_dummy_graphdef(), graph_def)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/gan/python/eval/python/eval_utils.py b/tensorflow/contrib/gan/python/eval/python/eval_utils.py
new file mode 100644
index 0000000000..bb7327040c
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/eval_utils.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility file for visualizing generated images."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.eval.python import eval_utils_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.eval.python.eval_utils_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+__all__ = eval_utils_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py b/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py
new file mode 100644
index 0000000000..6623b56c70
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py
@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility file for visualizing generated images."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+
+
+__all__ = [
+ "image_grid",
+ "image_reshaper",
+]
+
+
+# TODO(joelshor): Make this a special case of `image_reshaper`.
+def image_grid(input_tensor, grid_shape, image_shape=(32, 32), num_channels=3):
+ """Arrange a minibatch of images into a grid to form a single image.
+
+ Args:
+ input_tensor: Tensor. Minibatch of images to format, either 4D
+ ([batch size, height, width, num_channels]) or flattened
+ ([batch size, height * width * num_channels]).
+ grid_shape: Sequence of int. The shape of the image grid,
+ formatted as [grid_height, grid_width].
+ image_shape: Sequence of int. The shape of a single image,
+ formatted as [image_height, image_width].
+ num_channels: int. The number of channels in an image.
+
+ Returns:
+ Tensor representing a single image in which the input images have been
+ arranged into a grid.
+
+ Raises:
+ ValueError: The grid shape and minibatch size don't match, or the image
+ shape and number of channels are incompatible with the input tensor.
+ """
+ if grid_shape[0] * grid_shape[1] != int(input_tensor.shape[0]):
+ raise ValueError("Grid shape %s incompatible with minibatch size %i." %
+ (grid_shape, int(input_tensor.shape[0])))
+ if len(input_tensor.shape) == 2:
+ num_features = image_shape[0] * image_shape[1] * num_channels
+ if int(input_tensor.shape[1]) != num_features:
+ raise ValueError("Image shape and number of channels incompatible with "
+ "input tensor.")
+ elif len(input_tensor.shape) == 4:
+ if (int(input_tensor.shape[1]) != image_shape[0] or
+ int(input_tensor.shape[2]) != image_shape[1] or
+ int(input_tensor.shape[3]) != num_channels):
+ raise ValueError("Image shape and number of channels incompatible with "
+ "input tensor.")
+ else:
+ raise ValueError("Unrecognized input tensor format.")
+ height, width = grid_shape[0] * image_shape[0], grid_shape[1] * image_shape[1]
+ input_tensor = array_ops.reshape(
+ input_tensor, tuple(grid_shape) + tuple(image_shape) + (num_channels,))
+ input_tensor = array_ops.transpose(input_tensor, [0, 1, 3, 2, 4])
+ input_tensor = array_ops.reshape(
+ input_tensor, [grid_shape[0], width, image_shape[0], num_channels])
+ input_tensor = array_ops.transpose(input_tensor, [0, 2, 1, 3])
+ input_tensor = array_ops.reshape(
+ input_tensor, [1, height, width, num_channels])
+ return input_tensor
+
+
+def _validate_images(images):
+ for img in images:
+ img.shape.assert_has_rank(3)
+ img.shape.assert_is_fully_defined()
+ if img.shape[-1] not in (1, 3):
+ raise ValueError("image_reshaper only supports 1 or 3 channel images.")
+
+
+# TODO(joelshor): Move the dimension logic from Python to Tensorflow.
+def image_reshaper(images, num_cols=None):
+ """A reshaped summary image.
+
+ Returns an image that will contain all elements in the list and will be
+ laid out in a nearly-square tiling pattern (e.g. 11 images will lead to a
+ 3x4 tiled image).
+
+ Args:
+ images: Image data to summarize. Can be an RGB or grayscale image, a list of
+ such images, or a set of RGB images concatenated along the depth
+ dimension. The shape of each image is assumed to be [batch_size,
+ height, width, depth].
+ num_cols: (Optional) If provided, this is the number of columns in the final
+ output image grid. Otherwise, the number of columns is determined by
+ the number of images.
+
+ Returns:
+ A summary image matching the input with automatic tiling if needed.
+ Output shape is [1, height, width, channels].
+ """
+ if isinstance(images, ops.Tensor):
+ images = array_ops.unstack(images)
+ _validate_images(images)
+
+ num_images = len(images)
+ num_columns = (num_cols if num_cols else
+ int(math.ceil(math.sqrt(num_images))))
+ num_rows = int(math.ceil(float(num_images) / num_columns))
+ rows = [images[x:x+num_columns] for x in range(0, num_images, num_columns)]
+
+ # Add empty image tiles if the last row is incomplete.
+ num_short = num_rows * num_columns - num_images
+ assert num_short >= 0 and num_short < num_columns
+ if num_short > 0:
+ rows[-1].extend([array_ops.zeros_like(images[-1])] * num_short)
+
+ # Convert each row from a list of tensors to a single tensor.
+ rows = [array_ops.concat(row, 1) for row in rows]
+
+ # Stack rows vertically.
+ img = array_ops.concat(rows, 0)
+
+ return array_ops.expand_dims(img, 0)
diff --git a/tensorflow/contrib/gan/python/eval/python/eval_utils_test.py b/tensorflow/contrib/gan/python/eval/python/eval_utils_test.py
new file mode 100644
index 0000000000..cfed4dc513
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/eval_utils_test.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for eval_utils_test."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.eval.python import eval_utils_impl as eval_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class UtilsTest(test.TestCase):
+
+ def test_image_grid(self):
+ eval_utils.image_grid(
+ input_tensor=array_ops.zeros([25, 32, 32, 3]),
+ grid_shape=(5, 5))
+
+ # TODO(joelshor): Add more `image_reshaper` tests.
+ def test_image_reshaper_image_list(self):
+ images = eval_utils.image_reshaper(
+ images=array_ops.unstack(array_ops.zeros([25, 32, 32, 3])),
+ num_cols=2)
+ images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3])
+
+ def test_image_reshaper_image(self):
+ images = eval_utils.image_reshaper(
+ images=array_ops.zeros([25, 32, 32, 3]),
+ num_cols=2)
+ images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries.py b/tensorflow/contrib/gan/python/eval/python/summaries.py
new file mode 100644
index 0000000000..ecfdb39499
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/summaries.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common TFGAN summaries."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.eval.python import summaries_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.eval.python.summaries_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+__all__ = summaries_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
new file mode 100644
index 0000000000..940b523627
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common TFGAN summaries."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.eval.python import eval_utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.losses import util as loss_util
+from tensorflow.python.summary import summary
+
+__all__ = [
+ 'add_gan_model_image_summaries',
+ 'add_image_comparison_summaries',
+ 'add_gan_model_summaries',
+ 'add_regularization_loss_summaries',
+]
+
+
+def _assert_is_image(data):
+ data.shape.assert_has_rank(4)
+ data.shape[1:].assert_is_fully_defined()
+
+
+def add_gan_model_image_summaries(gan_model, grid_size=10):
+ """Adds image summaries for real and fake images.
+
+ Args:
+ gan_model: A GANModel tuple.
+ grid_size: The size of an image grid.
+
+ Raises:
+ ValueError: If real and generated data aren't images.
+ """
+ _assert_is_image(gan_model.real_data)
+ _assert_is_image(gan_model.generated_data)
+
+ num_images = grid_size ** 2
+ real_image_shape = gan_model.real_data.shape.as_list()[1:3]
+ generated_image_shape = gan_model.generated_data.shape.as_list()[1:3]
+ real_channels = gan_model.real_data.shape.as_list()[3]
+ generated_channels = gan_model.generated_data.shape.as_list()[3]
+
+ summary.image(
+ 'real_data',
+ eval_utils.image_grid(
+ gan_model.real_data[:num_images],
+ grid_shape=(grid_size, grid_size),
+ image_shape=real_image_shape,
+ num_channels=real_channels),
+ max_outputs=1)
+ summary.image(
+ 'generated_data',
+ eval_utils.image_grid(
+ gan_model.generated_data[:num_images],
+ grid_shape=(grid_size, grid_size),
+ image_shape=generated_image_shape,
+ num_channels=generated_channels),
+ max_outputs=1)
+ add_gan_model_summaries(gan_model)
+
+
+def add_image_comparison_summaries(gan_model, num_comparisons=2,
+ display_diffs=False):
+ """Adds image summaries to compare triplets of images.
+
+ The first image is the generator input, the second is the generator output,
+ and the third is the real data. This style of comparison is useful for
+ image translation problems, where the generator input is a corrupted image,
+ the generator output is the reconstruction, and the real data is the target.
+
+ Args:
+ gan_model: A GANModel tuple.
+ num_comparisons: The number of image triplets to display.
+ display_diffs: Also display the difference between generated and target.
+
+ Raises:
+ ValueError: If real data, generated data, and generator inputs aren't
+ images.
+ ValueError: If the generator input, real, and generated data aren't all the
+ same size.
+ """
+ _assert_is_image(gan_model.generator_inputs)
+ _assert_is_image(gan_model.generated_data)
+ _assert_is_image(gan_model.real_data)
+
+ gan_model.generated_data.shape.assert_is_compatible_with(
+ gan_model.generator_inputs.shape)
+ gan_model.real_data.shape.assert_is_compatible_with(
+ gan_model.generated_data.shape)
+
+ image_list = []
+ image_list.extend(
+ array_ops.unstack(gan_model.generator_inputs[:num_comparisons]))
+ image_list.extend(
+ array_ops.unstack(gan_model.generated_data[:num_comparisons]))
+ image_list.extend(array_ops.unstack(gan_model.real_data[:num_comparisons]))
+ if display_diffs:
+ generated_list = array_ops.unstack(
+ gan_model.generated_data[:num_comparisons])
+ real_list = array_ops.unstack(gan_model.real_data[:num_comparisons])
+ diffs = [
+ math_ops.abs(math_ops.to_float(generated) - math_ops.to_float(real)) for
+ generated, real in zip(generated_list, real_list)]
+ image_list.extend(diffs)
+
+ # Reshape image and display.
+ summary.image(
+ 'image_comparison',
+ eval_utils.image_reshaper(image_list, num_cols=num_comparisons),
+ max_outputs=1)
+
+
+def add_gan_model_summaries(gan_model):
+ """Adds typical GANModel summaries.
+
+ Args:
+ gan_model: A GANModel tuple.
+ """
+ with ops.name_scope('generator_variables'):
+ for var in gan_model.generator_variables:
+ summary.histogram(var.name, var)
+ with ops.name_scope('discriminator_variables'):
+ for var in gan_model.discriminator_variables:
+ summary.histogram(var.name, var)
+
+
+def add_regularization_loss_summaries(gan_model):
+ """Adds summaries for a regularization losses..
+
+ Args:
+ gan_model: A GANModel tuple.
+ """
+ if gan_model.generator_scope:
+ summary.scalar(
+ 'generator_regularization_loss',
+ loss_util.get_regularization_loss(gan_model.generator_scope.name))
+ if gan_model.discriminator_scope:
+ summary.scalar(
+ 'discriminator_regularization_loss',
+ loss_util.get_regularization_loss(gan_model.discriminator_scope.name))
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
new file mode 100644
index 0000000000..a3b02bcefc
--- /dev/null
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -0,0 +1,96 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TFGAN summaries."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.summary import summary
+
+
+def generator_model(inputs):
+ return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
+
+
+def discriminator_model(inputs, _):
+ return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
+
+
+def get_gan_model():
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ pass
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+ return namedtuples.GANModel(
+ generator_inputs=array_ops.zeros([4, 32, 32, 3]),
+ generated_data=array_ops.zeros([4, 32, 32, 3]),
+ generator_variables=[variables.Variable(0), variables.Variable(1)],
+ generator_scope=gen_scope,
+ generator_fn=generator_model,
+ real_data=array_ops.ones([4, 32, 32, 3]),
+ discriminator_real_outputs=array_ops.ones([1, 2, 3]),
+ discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
+ discriminator_variables=[variables.Variable(0)],
+ discriminator_scope=dis_scope,
+ discriminator_fn=discriminator_model)
+
+
+class SummariesTest(test.TestCase):
+
+ def testAddGanModelImageSummaries(self):
+ summaries.add_gan_model_image_summaries(get_gan_model(), grid_size=2)
+
+ self.assertEquals(5, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ with self.test_session(use_gpu=True):
+ variables.global_variables_initializer().run()
+ summary.merge_all().eval()
+
+ def testAddGanModelSummaries(self):
+ summaries.add_gan_model_summaries(get_gan_model())
+
+ self.assertEquals(3, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ with self.test_session(use_gpu=True):
+ variables.global_variables_initializer().run()
+ summary.merge_all().eval()
+
+ def testAddRegularizationLossSummaries(self):
+ summaries.add_regularization_loss_summaries(get_gan_model())
+
+ self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ with self.test_session(use_gpu=True):
+ summary.merge_all().eval()
+
+ # TODO(joelshor): Add correctness test.
+ def testAddImageComparisonSummaries(self):
+ summaries.add_image_comparison_summaries(
+ get_gan_model(), display_diffs=True)
+
+ self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ with self.test_session(use_gpu=True):
+ summary.merge_all().eval()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index a27bec4801..a18f14112e 100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -88,6 +88,7 @@ cuda_py_test(
size = "medium",
srcs = ["python/kernel_tests/image_ops_test.py"],
additional_deps = [
+ ":distort_image_py",
":image_py",
":single_image_random_dot_stereograms_py",
"//third_party/py/numpy",
@@ -100,6 +101,80 @@ cuda_py_test(
)
tf_custom_op_library(
+ name = "python/ops/_distort_image_ops.so",
+ srcs = [
+ "kernels/adjust_hsv_in_yiq_op.cc",
+ "ops/distort_image_ops.cc",
+ ],
+ deps = [
+ "@protobuf_archive//:protobuf",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["distort_image_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "distort_image_ops",
+ deps = [":distort_image_ops_op_lib"],
+)
+
+cc_library(
+ name = "distort_image_ops_cc",
+ srcs = [
+ "kernels/adjust_hsv_in_yiq_op.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "distort_image_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/distort_image_ops.py",
+ ],
+ data = [":python/ops/_distort_image_ops.so"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":distort_image_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:image_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ ],
+)
+
+cuda_py_test(
+ name = "distort_image_ops_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/distort_image_ops_test.py"],
+ additional_deps = [
+ ":distort_image_py",
+ ":image_py",
+ ":single_image_random_dot_stereograms_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+tf_custom_op_library(
name = "python/ops/_single_image_random_dot_stereograms.so",
srcs = [
"kernels/single_image_random_dot_stereograms_ops.cc",
diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py
index 1ed19265b3..59a322d3ca 100755
--- a/tensorflow/contrib/image/__init__.py
+++ b/tensorflow/contrib/image/__init__.py
@@ -16,11 +16,14 @@
### API
-This module provides functions for image manipulation; currently, only
+This module provides functions for image manipulation; currently, chrominance
+transformas (including changing saturation and hue) in YIQ space and
projective transforms (including rotation) are supported.
@@angles_to_projective_transforms
@@compose_transforms
+@@adjust_yiq_hsv
+@@random_yiq_hsv
@@rotate
@@transform
@@bipartite_match
@@ -31,6 +34,9 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
+from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq
+from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq
+
from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
from tensorflow.contrib.image.python.ops.image_ops import rotate
@@ -39,5 +45,6 @@ from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms imp
from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=line-too-long
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
new file mode 100644
index 0000000000..f4962ed69d
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cmath>
+#include <memory>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+class AdjustHsvInYiqOpBase : public OpKernel {
+ protected:
+ explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ struct ComputeOptions {
+ const Tensor* input = nullptr;
+ const Tensor* delta_h = nullptr;
+ const Tensor* scale_s = nullptr;
+ const Tensor* scale_v = nullptr;
+ Tensor* output = nullptr;
+ int64 channel_count = 0;
+ };
+
+ virtual void DoCompute(OpKernelContext* context,
+ const ComputeOptions& options) = 0;
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& delta_h = context->input(1);
+ const Tensor& scale_s = context->input(2);
+ const Tensor& scale_v = context->input(3);
+ OP_REQUIRES(context, input.dims() >= 3,
+ errors::InvalidArgument("input must be at least 3-D, got shape",
+ input.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()),
+ errors::InvalidArgument("delta_h must be scalar: ",
+ delta_h.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()),
+ errors::InvalidArgument("scale_s must be scalar: ",
+ scale_s.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()),
+ errors::InvalidArgument("scale_v must be scalar: ",
+ scale_v.shape().DebugString()));
+ auto channels = input.dim_size(input.dims() - 1);
+ OP_REQUIRES(
+ context, channels == 3,
+ errors::InvalidArgument("input must have 3 channels but instead has ",
+ channels, " channels."));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+ if (input.NumElements() > 0) {
+ const int64 channel_count = input.NumElements() / channels;
+ ComputeOptions options;
+ options.input = &input;
+ options.delta_h = &delta_h;
+ options.scale_s = &scale_s;
+ options.scale_v = &scale_v;
+ options.output = output;
+ options.channel_count = channel_count;
+ DoCompute(context, options);
+ }
+ }
+};
+
+template <class Device>
+class AdjustHsvInYiqOp;
+
+template <>
+class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
+ public:
+ explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
+ : AdjustHsvInYiqOpBase(context) {}
+
+ void DoCompute(OpKernelContext* context,
+ const ComputeOptions& options) override {
+ const Tensor* input = options.input;
+ Tensor* output = options.output;
+ const int64 channel_count = options.channel_count;
+ static const int kChannelSize = 3;
+ auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
+ const float delta_h = options.delta_h->scalar<float>()();
+ const float scale_s = options.scale_s->scalar<float>()();
+ const float scale_v = options.scale_v->scalar<float>()();
+ auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
+ const int kCostPerChannel = 10;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
+ kCostPerChannel,
+ [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v](
+ int64 start_channel, int64 end_channel) {
+ // Using approximate linear transfomation described in:
+ // https://beesbuzz.biz/code/hsv_color_transforms.php
+ /** Get the constants from sympy
+ from sympy import Matrix
+ from sympy.abc import u, w
+ # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ
+ tyiq = Matrix([[0.299, 0.587, 0.114],
+ [0.596, -0.274, -0.322],
+ [0.211, -0.523, 0.312]])
+ # Hue rotation matrix in YIQ space.
+ hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu])
+ m = tyiq.inv() * hue_proj * tyiq
+ **/
+ // TODO(huangyp): directly compute the projection matrix from tyiq.
+ static const float t[kChannelSize][kChannelSize][kChannelSize] = {
+ {{.299, .701, .16862179492229},
+ {.587, -.587, .329804745287403},
+ {.114, -.114, -0.498426540209694}},
+ {{.299, -.299, -.327963394172371},
+ {.587, .413, .0346106879248821},
+ {.114, -.114, .293352706247489}},
+ {{.299, -.299, 1.24646136576682},
+ {.587, -.587, -1.04322888291964},
+ {.114, .886, -.203232482847173}}};
+ float m[kChannelSize][kChannelSize] = {{0.}};
+ float su = scale_s * std::cos(delta_h);
+ float sw = scale_s * std::sin(delta_h);
+ for (int q_index = 0; q_index < kChannelSize; q_index++) {
+ for (int p_index = 0; p_index < kChannelSize; p_index++) {
+ m[q_index][p_index] = scale_v * (t[q_index][p_index][0] +
+ t[q_index][p_index][1] * su +
+ t[q_index][p_index][2] * sw);
+ }
+ }
+ // Applying projection matrix to input RGB vectors.
+ const float* p = input_data.data() + start_channel * kChannelSize;
+ float* q = output_data.data() + start_channel * kChannelSize;
+ for (int i = start_channel; i < end_channel; i++) {
+ for (int q_index = 0; q_index < kChannelSize; q_index++) {
+ q[q_index] = 0;
+ for (int p_index = 0; p_index < kChannelSize; p_index++) {
+ q[q_index] += m[q_index][p_index] * p[p_index];
+ }
+ }
+ p += kChannelSize;
+ q += kChannelSize;
+ }
+ });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU),
+ AdjustHsvInYiqOp<CPUDevice>);
+
+// TODO(huangyp): add the GPU kernel
+} // namespace tensorflow
diff --git a/tensorflow/contrib/image/ops/distort_image_ops.cc b/tensorflow/contrib/image/ops/distort_image_ops.cc
new file mode 100644
index 0000000000..b169b0b2b2
--- /dev/null
+++ b/tensorflow/contrib/image/ops/distort_image_ops.cc
@@ -0,0 +1,60 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+// --------------------------------------------------------------------------
+REGISTER_OP("AdjustHsvInYiq")
+ .Input("images: T")
+ .Input("delta_h: float")
+ .Input("scale_s: float")
+ .Input("scale_v: float")
+ .Output("output: T")
+ .Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
+ })
+ .Doc(R"Doc(
+Adjust the YIQ hue of one or more images.
+
+`images` is a tensor of at least 3 dimensions. The last dimension is
+interpretted as channels, and must be three.
+
+We used linear transfomation described in:
+ beesbuzz.biz/code/hsv_color_transforms.php
+The input image is considered in the RGB colorspace. Conceptually, the RGB
+colors are first mapped into YIQ space, rotated around the Y channel by
+delta_h in radians, multiplying the chrominance channels (I, Q) by scale_s,
+multiplying all channels (Y, I, Q) by scale_v, and then remapped back to RGB
+colorspace. Each operation described above is a linear transformation.
+
+images: Images to adjust. At least 3-D.
+delta_h: A float scale that represents the hue rotation amount, in radians.
+ Although delta_h can be any float value.
+scale_s: A float scale that represents the factor to multiply the saturation by.
+ scale_s needs to be non-negative.
+scale_v: A float scale that represents the factor to multiply the value by.
+ scale_v needs to be non-negative.
+output: The hsv-adjusted image or images. No clipping will be done in this op.
+ The client can clip them using additional ops in their graph.
+)Doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
new file mode 100644
index 0000000000..b85f19d29b
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
@@ -0,0 +1,338 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for python distort_image_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.image.python.ops import distort_image_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+# TODO(huangyp): also measure the differences between AdjustHsvInYiq and
+# AdjustHsv in core.
+class AdjustHueInYiqTest(test_util.TensorFlowTestCase):
+
+ def _adjust_hue_in_yiq_np(self, x_np, delta_h):
+ """Rotate hue in YIQ space.
+
+ Mathematically we first convert rgb color to yiq space, rotate the hue
+ degrees, and then convert back to rgb.
+
+ Args:
+ x_np: input x with last dimension = 3.
+ delta_h: degree of hue rotation, in radians.
+
+ Returns:
+ Adjusted y with the same shape as x_np.
+ """
+ self.assertEqual(x_np.shape[-1], 3)
+ x_v = x_np.reshape([-1, 3])
+ y_v = np.ndarray(x_v.shape, dtype=x_v.dtype)
+ u = np.cos(delta_h)
+ w = np.sin(delta_h)
+ # Projection matrix from RGB to YIQ. Numbers from wikipedia
+ # https://en.wikipedia.org/wiki/YIQ
+ tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.322],
+ [0.211, -0.523, 0.312]])
+ y_v = np.dot(x_v, tyiq.T)
+ # Hue rotation matrix in YIQ space.
+ hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+ y_v = np.dot(y_v, hue_rotation.T)
+ # Projecting back to RGB space.
+ y_v = np.dot(y_v, np.linalg.inv(tyiq).T)
+ return y_v.reshape(x_np.shape)
+
+ def _adjust_hue_in_yiq_tf(self, x_np, delta_h):
+ with self.test_session(use_gpu=True):
+ x = constant_op.constant(x_np)
+ y = distort_image_ops.adjust_hsv_in_yiq(x, delta_h, 1, 1)
+ y_tf = y.eval()
+ return y_tf
+
+ def test_adjust_random_hue_in_yiq(self):
+ x_shapes = [
+ [2, 2, 3],
+ [4, 2, 3],
+ [2, 4, 3],
+ [2, 5, 3],
+ [1000, 1, 3],
+ ]
+ test_styles = [
+ 'all_random',
+ 'rg_same',
+ 'rb_same',
+ 'gb_same',
+ 'rgb_same',
+ ]
+ for x_shape in x_shapes:
+ for test_style in test_styles:
+ x_np = np.random.rand(*x_shape) * 255.
+ delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi
+ if test_style == 'all_random':
+ pass
+ elif test_style == 'rg_same':
+ x_np[..., 1] = x_np[..., 0]
+ elif test_style == 'rb_same':
+ x_np[..., 2] = x_np[..., 0]
+ elif test_style == 'gb_same':
+ x_np[..., 2] = x_np[..., 1]
+ elif test_style == 'rgb_same':
+ x_np[..., 1] = x_np[..., 0]
+ x_np[..., 2] = x_np[..., 0]
+ else:
+ raise AssertionError('Invalid test style: %s' % (test_style))
+ y_np = self._adjust_hue_in_yiq_np(x_np, delta_h)
+ y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
+ self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
+
+ def test_invalid_shapes(self):
+ x_np = np.random.rand(2, 3) * 255.
+ delta_h = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+ self._adjust_hue_in_yiq_tf(x_np, delta_h)
+ x_np = np.random.rand(4, 2, 4) * 255.
+ delta_h = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesOpError('input must have 3 channels but instead has '
+ '4 channels'):
+ self._adjust_hue_in_yiq_tf(x_np, delta_h)
+
+
+class AdjustValueInYiqTest(test_util.TensorFlowTestCase):
+
+ def _adjust_value_in_yiq_np(self, x_np, scale):
+ return x_np * scale
+
+ def _adjust_value_in_yiq_tf(self, x_np, scale):
+ with self.test_session(use_gpu=True):
+ x = constant_op.constant(x_np)
+ y = distort_image_ops.adjust_hsv_in_yiq(x, 0, 1, scale)
+ y_tf = y.eval()
+ return y_tf
+
+ def test_adjust_random_value_in_yiq(self):
+ x_shapes = [
+ [2, 2, 3],
+ [4, 2, 3],
+ [2, 4, 3],
+ [2, 5, 3],
+ [1000, 1, 3],
+ ]
+ test_styles = [
+ 'all_random',
+ 'rg_same',
+ 'rb_same',
+ 'gb_same',
+ 'rgb_same',
+ ]
+ for x_shape in x_shapes:
+ for test_style in test_styles:
+ x_np = np.random.rand(*x_shape) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ if test_style == 'all_random':
+ pass
+ elif test_style == 'rg_same':
+ x_np[..., 1] = x_np[..., 0]
+ elif test_style == 'rb_same':
+ x_np[..., 2] = x_np[..., 0]
+ elif test_style == 'gb_same':
+ x_np[..., 2] = x_np[..., 1]
+ elif test_style == 'rgb_same':
+ x_np[..., 1] = x_np[..., 0]
+ x_np[..., 2] = x_np[..., 0]
+ else:
+ raise AssertionError('Invalid test style: %s' % (test_style))
+ y_np = self._adjust_value_in_yiq_np(x_np, scale)
+ y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
+ self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5)
+
+ def test_invalid_shapes(self):
+ x_np = np.random.rand(2, 3) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+ self._adjust_value_in_yiq_tf(x_np, scale)
+ x_np = np.random.rand(4, 2, 4) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesOpError('input must have 3 channels but instead has '
+ '4 channels'):
+ self._adjust_value_in_yiq_tf(x_np, scale)
+
+
+class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase):
+
+ def _adjust_saturation_in_yiq_tf(self, x_np, scale):
+ with self.test_session(use_gpu=True):
+ x = constant_op.constant(x_np)
+ y = distort_image_ops.adjust_hsv_in_yiq(x, 0, scale, 1)
+ y_tf = y.eval()
+ return y_tf
+
+ def _adjust_saturation_in_yiq_np(self, x_np, scale):
+ """Adjust saturation using linear interpolation."""
+ rgb_weights = np.array([0.299, 0.587, 0.114])
+ gray = np.sum(x_np * rgb_weights, axis=-1, keepdims=True)
+ y_v = x_np * scale + gray * (1 - scale)
+ return y_v
+
+ def test_adjust_random_saturation_in_yiq(self):
+ x_shapes = [
+ [2, 2, 3],
+ [4, 2, 3],
+ [2, 4, 3],
+ [2, 5, 3],
+ [1000, 1, 3],
+ ]
+ test_styles = [
+ 'all_random',
+ 'rg_same',
+ 'rb_same',
+ 'gb_same',
+ 'rgb_same',
+ ]
+ with self.test_session():
+ for x_shape in x_shapes:
+ for test_style in test_styles:
+ x_np = np.random.rand(*x_shape) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ if test_style == 'all_random':
+ pass
+ elif test_style == 'rg_same':
+ x_np[..., 1] = x_np[..., 0]
+ elif test_style == 'rb_same':
+ x_np[..., 2] = x_np[..., 0]
+ elif test_style == 'gb_same':
+ x_np[..., 2] = x_np[..., 1]
+ elif test_style == 'rgb_same':
+ x_np[..., 1] = x_np[..., 0]
+ x_np[..., 2] = x_np[..., 0]
+ else:
+ raise AssertionError('Invalid test style: %s' % (test_style))
+ y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale)
+ y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
+ self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5)
+
+ def test_invalid_shapes(self):
+ x_np = np.random.rand(2, 3) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+ self._adjust_saturation_in_yiq_tf(x_np, scale)
+ x_np = np.random.rand(4, 2, 4) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesOpError('input must have 3 channels but instead has '
+ '4 channels'):
+ self._adjust_saturation_in_yiq_tf(x_np, scale)
+
+
+class AdjustHueInYiqBenchmark(test.Benchmark):
+
+ def _benchmark_adjust_hue_in_yiq(self, device, cpu_count):
+ image_shape = [299, 299, 3]
+ warmup_rounds = 100
+ benchmark_rounds = 1000
+ config = config_pb2.ConfigProto()
+ if cpu_count is not None:
+ config.inter_op_parallelism_threads = 1
+ config.intra_op_parallelism_threads = cpu_count
+ with session.Session('', graph=ops.Graph(), config=config) as sess:
+ with ops.device(device):
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, delta, 1, 1)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for i in xrange(warmup_rounds + benchmark_rounds):
+ if i == warmup_rounds:
+ start = time.time()
+ sess.run(run_op)
+ end = time.time()
+ step_time = (end - start) / benchmark_rounds
+ tag = device + '_%s' % (cpu_count if cpu_count is not None else 'all')
+ print('benchmarkadjust_hue_in_yiq_299_299_3_%s step_time: %.2f us' %
+ (tag, step_time * 1e6))
+ self.report_benchmark(
+ name='benchmarkadjust_hue_in_yiq_299_299_3_%s' % (tag),
+ iters=benchmark_rounds,
+ wall_time=step_time)
+
+ def benchmark_adjust_hue_in_yiqCpu1(self):
+ self._benchmark_adjust_hue_in_yiq('/cpu:0', 1)
+
+ def benchmark_adjust_hue_in_yiqCpuAll(self):
+ self._benchmark_adjust_hue_in_yiq('/cpu:0', None)
+
+
+class AdjustSaturationInYiqBenchmark(test.Benchmark):
+
+ def _benchmark_adjust_saturation_in_yiq(self, device, cpu_count):
+ image_shape = [299, 299, 3]
+ warmup_rounds = 100
+ benchmark_rounds = 1000
+ config = config_pb2.ConfigProto()
+ if cpu_count is not None:
+ config.inter_op_parallelism_threads = 1
+ config.intra_op_parallelism_threads = cpu_count
+ with session.Session('', graph=ops.Graph(), config=config) as sess:
+ with ops.device(device):
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ scale = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, 0, scale, 1)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for _ in xrange(warmup_rounds):
+ sess.run(run_op)
+ start = time.time()
+ for _ in xrange(benchmark_rounds):
+ sess.run(run_op)
+ end = time.time()
+ step_time = (end - start) / benchmark_rounds
+ tag = '%s' % (cpu_count) if cpu_count is not None else '_all'
+ print('benchmarkAdjustSaturationInYiq_299_299_3_cpu%s step_time: %.2f us' %
+ (tag, step_time * 1e6))
+ self.report_benchmark(
+ name='benchmarkAdjustSaturationInYiq_299_299_3_cpu%s' % (tag),
+ iters=benchmark_rounds,
+ wall_time=step_time)
+
+ def benchmark_adjust_saturation_in_yiq_cpu1(self):
+ self._benchmark_adjust_saturation_in_yiq('/cpu:0', 1)
+
+ def benchmark_adjust_saturation_in_yiq_cpu_all(self):
+ self._benchmark_adjust_saturation_in_yiq('/cpu:0', None)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/ops/distort_image_ops.py b/tensorflow/contrib/image/python/ops/distort_image_ops.py
new file mode 100644
index 0000000000..39f023a2b4
--- /dev/null
+++ b/tensorflow/contrib/image/python/ops/distort_image_ops.py
@@ -0,0 +1,138 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python layer for distort_image_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import resource_loader
+
+_distort_image_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile('_distort_image_ops.so'))
+
+
+# pylint: disable=invalid-name
+def random_hsv_in_yiq(image,
+ max_delta_hue=0,
+ lower_saturation=1,
+ upper_saturation=1,
+ lower_value=1,
+ upper_value=1,
+ seed=None):
+ """Adjust hue, saturation, value of an RGB image randomly in YIQ color space.
+
+ Equivalent to `adjust_yiq_hsv()` but uses a `delta_h` randomly
+ picked in the interval `[-max_delta_hue, max_delta_hue]`, a `scale_saturation`
+ randomly picked in the interval `[lower_saturation, upper_saturation]`, and
+ a `scale_value` randomly picked in the interval
+ `[lower_saturation, upper_saturation]`.
+
+ Args:
+ image: RGB image or images. Size of the last dimension must be 3.
+ max_delta_hue: float. Maximum value for the random delta_hue. Passing 0
+ disables adjusting hue.
+ lower_saturation: float. Lower bound for the random scale_saturation.
+ upper_saturation: float. Upper bound for the random scale_saturation.
+ lower_value: float. Lower bound for the random scale_value.
+ upper_value: float. Upper bound for the random scale_value.
+ seed: An operation-specific seed. It will be used in conjunction
+ with the graph-level seed to determine the real seeds that will be
+ used in this operation. Please see the documentation of
+ set_random_seed for its interaction with the graph-level random seed.
+
+ Returns:
+ 3-D float tensor of shape `[height, width, channels]`.
+
+ Raises:
+ ValueError: if `max_delta`, `lower_saturation`, `upper_saturation`,
+ `lower_value`, or `upper_Value` is invalid.
+ """
+ if max_delta_hue < 0:
+ raise ValueError('max_delta must be non-negative.')
+
+ if lower_saturation < 0:
+ raise ValueError('lower_saturation must be non-negative.')
+
+ if lower_value < 0:
+ raise ValueError('lower_value must be non-negative.')
+
+ if lower_saturation > upper_saturation:
+ raise ValueError('lower_saturation must be < upper_saturation.')
+
+ if lower_value > upper_value:
+ raise ValueError('lower_value must be < upper_value.')
+
+ if max_delta_hue == 0:
+ delta_hue = 0
+ else:
+ delta_hue = random_ops.random_uniform(
+ [], -max_delta_hue, max_delta_hue, seed=seed)
+ if lower_saturation == upper_saturation:
+ scale_saturation = lower_saturation
+ else:
+ scale_saturation = random_ops.random_uniform(
+ [], lower_saturation, upper_saturation, seed=seed)
+ if lower_value == upper_value:
+ scale_value = lower_value
+ else:
+ scale_value = random_ops.random_uniform(
+ [], lower_value, upper_value, seed=seed)
+ return adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value)
+
+
+def adjust_hsv_in_yiq(image,
+ delta_hue=0,
+ scale_saturation=1,
+ scale_value=1,
+ name=None):
+ """Adjust hue, saturation, value of an RGB image in YIQ color space.
+
+ This is a convenience method that converts an RGB image to float
+ representation, converts it to YIQ, rotates the color around the Y channel by
+ delta_hue in radians, scales the chrominance channels (I, Q) by
+ scale_saturation, scales all channels (Y, I, Q) by scale_value,
+ converts back to RGB, and then back to the original data type.
+
+ `image` is an RGB image. The image hue is adjusted by converting the
+ image to YIQ, rotating around the luminance channel (Y) by
+ `delta_hue` in radians, multiplying the chrominance channels (I, Q) by
+ `scale_saturation`, and multiplying all channels (Y, I, Q) by
+ `scale_value`. The image is then converted back to RGB.
+
+ Args:
+ image: RGB image or images. Size of the last dimension must be 3.
+ delta_hue: float, the hue rotation amount, in radians.
+ scale_saturation: float, factor to multiply the saturation by.
+ scale_value: float, factor to multiply the value by.
+ name: A name for this operation (optional).
+
+ Returns:
+ Adjusted image(s), same shape and DType as `image`.
+ """
+ with ops.name_scope(name, 'adjust_hsv_in_yiq', [image]) as name:
+ image = ops.convert_to_tensor(image, name='image')
+ # Remember original dtype to so we can convert back if needed
+ orig_dtype = image.dtype
+ flt_image = image_ops.convert_image_dtype(image, dtypes.float32)
+
+ rgb_altered = _distort_image_ops.adjust_hsv_in_yiq(
+ flt_image, delta_hue, scale_saturation, scale_value)
+
+ return image_ops.convert_image_dtype(rgb_altered, orig_dtype)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index d6dc2a01ea..719e5da21d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -1678,9 +1678,14 @@ class _MultiHead(Head):
ModelFnOps that merges all heads for TRAIN.
"""
losses = []
+ metrics = {}
additional_train_ops = []
for m in all_model_fn_ops:
losses.append(m.loss)
+ if m.eval_metric_ops is not None:
+ for k, v in six.iteritems(m.eval_metric_ops):
+ # metrics["%s/%s" % (k, head_name)] = v
+ metrics[k] = v
additional_train_ops.append(m.train_op)
loss = self._loss_merger(losses)
@@ -1689,7 +1694,8 @@ class _MultiHead(Head):
return model_fn.ModelFnOps(
mode=model_fn.ModeKeys.TRAIN,
loss=loss,
- train_op=train_op)
+ train_op=train_op,
+ eval_metric_ops=metrics)
def _merge_infer(self, all_model_fn_ops):
"""Merges list of ModelFnOps for inference.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 25a6674858..3881bf533d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -1703,7 +1703,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.predictions)
self.assertIsNotNone(model_fn_ops.loss)
self.assertIsNotNone(model_fn_ops.train_op)
- self.assertFalse(model_fn_ops.eval_metric_ops)
+ self.assertTrue(model_fn_ops.eval_metric_ops)
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
@@ -1728,7 +1728,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.predictions)
self.assertIsNotNone(model_fn_ops.loss)
self.assertIsNotNone(model_fn_ops.train_op)
- self.assertFalse(model_fn_ops.eval_metric_ops)
+ self.assertTrue(model_fn_ops.eval_metric_ops)
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
@@ -1755,7 +1755,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.predictions)
self.assertIsNotNone(model_fn_ops.loss)
self.assertIsNotNone(model_fn_ops.train_op)
- self.assertFalse(model_fn_ops.eval_metric_ops)
+ self.assertTrue(model_fn_ops.eval_metric_ops)
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
index ec493b8463..aab0f3f494 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
@@ -113,7 +113,7 @@ struct GatherTree<CPUDevice, int32> {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
int32 seq_len_b = sequence_length(batch, beam);
- if (seq_len_b == 0) {
+ if (seq_len_b <= 0) {
continue;
}
beams(seq_len_b - 1, batch, beam) =
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
index e3c0d0bfa9..ee68b55d20 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
@@ -33,7 +33,10 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
+
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
+ if (seq_len_b <= 0) continue;
+
#define GET_IX(time_ix, beam_ix) \
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index f8b92f1832..b1537eab01 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -155,7 +155,7 @@ void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
// retrieving the sizes from the wrapped allocator removes the
// executor's reference to it, so allocator_pair.second must not
// be dereferenced again after this statement
- auto sizes = allocator_pair.second->GetSizesAndUnRef();
+ const auto sizes = allocator_pair.second->GetSizesAndUnRef();
memory->set_allocator_name(allocator_pair.first->Name());
memory->set_total_bytes(std::get<0>(sizes));
memory->set_peak_bytes(std::get<1>(sizes));
@@ -1373,7 +1373,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
- int out_id = out->id();
+ const int out_id = out->id();
// Add to ready queue if not visited.
bool is_visited = visited[out_id];
@@ -1417,7 +1417,8 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
// Ask the device to fill in the device context map.
Device* device = impl_->params_.device;
- Status fill_status = device->FillContextMap(graph, &device_context_map_);
+ const Status fill_status =
+ device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
done(fill_status);
return;
@@ -1525,7 +1526,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
inline_ready.pop_front();
const Node* node = tagged_node.node;
FrameState* input_frame = tagged_node.input_frame;
- int64 input_iter = tagged_node.input_iter;
+ const int64 input_iter = tagged_node.input_iter;
const int id = node->id();
const NodeItem& item = *gview.node(id);
@@ -1637,7 +1638,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
accessed);
}
- bool completed =
+ const bool completed =
NodeDone(s, state->item->node, ready, stats, nullptr);
delete state;
if (completed) Finish();
@@ -1803,7 +1804,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
}
for (int i = 0; i < item.num_outputs; ++i) {
- TensorValue val = ctx->release_output(i);
+ const TensorValue val = ctx->release_output(i);
if (*ctx->is_output_dead() || val.tensor == nullptr) {
// Unless it's a Switch or a Recv, the node must produce a
// tensor value at i-th output.
@@ -1893,7 +1894,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
TaggedNodeSeq* ready) {
const Node* node = tagged_node.node;
FrameState* input_frame = tagged_node.input_frame;
- int64 input_iter = tagged_node.input_iter;
+ const int64 input_iter = tagged_node.input_iter;
const bool is_dead = tagged_node.is_dead;
// Propagates outputs along out edges, and puts newly ready nodes
@@ -1913,7 +1914,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
bool is_constant;
- Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
+ const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
@@ -1983,7 +1984,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
// completion of this node makes its frame completed.
if (is_frame_done) {
FrameState* parent_frame = input_frame->parent_frame;
- int64 parent_iter = input_frame->parent_iter;
+ const int64 parent_iter = input_frame->parent_iter;
DeleteFrame(input_frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
@@ -2026,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
}
bool completed = false;
- size_t ready_size = ready.size();
+ const size_t ready_size = ready.size();
if (ready_size == 0 || !s.ok()) {
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
} else if (ready_size > 1) {
@@ -2166,7 +2167,7 @@ void ExecutorState::DumpIterationState(const FrameState* frame,
const std::vector<const Node*>* nodes = frame->nodes;
// Dump any waiting nodes that are holding on to tensors.
for (const Node* node : *nodes) {
- int node_id = node->id();
+ const int node_id = node->id();
PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
@@ -2175,14 +2176,14 @@ void ExecutorState::DumpIterationState(const FrameState* frame,
}
// Then the active nodes.
for (const Node* node : *nodes) {
- int node_id = node->id();
+ const int node_id = node->id();
PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
DumpActiveNodeState(node_id, iteration->input_tensors);
}
}
// Show all input tensors in use.
- int total_input_tensors = frame->total_input_tensors;
+ const int total_input_tensors = frame->total_input_tensors;
size_t total_bytes = 0;
for (int i = 0; i < total_input_tensors; ++i) {
const Entry& input = iteration->input_tensors[i];
@@ -2291,7 +2292,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
// First, propagate dead_exits (if any) to the parent frame.
FrameState* parent_frame = frame->parent_frame;
- int64 parent_iter = frame->parent_iter;
+ const int64 parent_iter = frame->parent_iter;
if (parent_frame != nullptr) {
mutex_lock paranet_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame.
@@ -2300,7 +2301,8 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
for (const Edge* e : node->out_edges()) {
const Node* dst_node = e->dst();
- auto dst_pending_id = impl_->gview_.node(dst_node->id())->pending_id;
+ const auto dst_pending_id =
+ impl_->gview_.node(dst_node->id())->pending_id;
// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
@@ -2358,7 +2360,7 @@ void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
}
if (is_frame_done) {
FrameState* parent_frame = frame->parent_frame;
- int64 parent_iter = frame->parent_iter;
+ const int64 parent_iter = frame->parent_iter;
DeleteFrame(frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
@@ -2433,7 +2435,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
}
}
} else {
- bool increment_dead =
+ const bool increment_dead =
(is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
int pending, dead;
iter_state->adjust_for_activation(dst_pending_id, increment_dead,
@@ -2497,7 +2499,7 @@ void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
inv_values.push_back({item->node, entry});
// Make this value available to all iterations.
- bool is_dead = !entry.has_value;
+ const bool is_dead = !entry.has_value;
for (int i = 0; i <= iteration_count; ++i) {
EntryVector outputs{entry};
ActivateNodes(item, is_dead, i, &outputs, ready);
@@ -2522,7 +2524,7 @@ bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
TaggedNodeSeq* ready) {
iteration_count++;
- int64 next_iter = iteration_count;
+ const int64 next_iter = iteration_count;
// Initialize the next iteration.
IterationState* iter_state =
@@ -2567,7 +2569,7 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
Executor** executor) {
ExecutorImpl* impl = new ExecutorImpl(params, graph);
- Status s = impl->Initialize();
+ const Status s = impl->Initialize();
if (s.ok()) {
*executor = impl;
} else {
@@ -2579,7 +2581,7 @@ Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
const NodeDef& ndef, int graph_def_version,
OpKernel** kernel) {
- auto device_type = DeviceType(device->attributes().device_type());
+ const auto device_type = DeviceType(device->attributes().device_type());
auto allocator = device->GetAllocator(AllocatorAttributes());
return CreateOpKernel(device_type, device, allocator, flib, ndef,
graph_def_version, kernel);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0ceef2687b..cff6e30c04 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -665,7 +665,9 @@ tf_kernel_library(
tf_kernel_library(
name = "matrix_band_part_op",
prefix = "matrix_band_part_op",
- deps = ARRAY_DEPS,
+ deps = if_cuda([
+ ":cuda_solvers",
+ ]) + ARRAY_DEPS,
)
tf_kernel_library(
@@ -1332,7 +1334,7 @@ tf_kernel_library(
"transpose_functor_gpu.cu.cc",
"transpose_functor.h",
],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
":ops_util",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 755ce7c43b..6668b0d654 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -76,18 +76,19 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixBandPart<GPUDevice, T>::Compute( \
- const GPUDevice& d, Eigen::DenseIndex num_lower, \
- Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+ int num_upper_diags, int num_lower_diags, bool transpose, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
+ extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
-
} // namespace functor
template <class Scalar>
@@ -131,9 +132,9 @@ class CholeskyOpGpu : public AsyncOpKernel {
// before we launch each of the Cholesky factorization kernels in paralle.
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
- functor::MatrixBandPart<GPUDevice, Scalar>::Compute(
- context->eigen_device<GPUDevice>(), n, 0, input_reshaped,
- output_reshaped);
+ functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
+ fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
+ input_reshaped, output_reshaped);
// Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0);
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 1de7d6a2c0..ecfe51d599 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -1024,7 +1024,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
assert(__popc(kWidth) == 1);
int sub_warp = cub::LaneId() / kWidth;
int zeros = sub_warp * kWidth;
- unsigned mask = ((1U << kWidth) - 1) << zeros;
+ unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
val += CudaShuffleXor(mask, val, delta);
}
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index 894b0113c2..8b8accc0b3 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include <algorithm>
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -48,31 +50,50 @@ class MatrixBandPartOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input.shape().DebugString()));
+ auto input_reshaped = input.flat_inner_dims<T, 3>();
+
const Tensor& num_lower_in = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
const int64 num_lower = num_lower_in.scalar<int64>()();
+ OP_REQUIRES(
+ context, num_lower <= input_reshaped.dimension(1),
+ errors::InvalidArgument(
+ "num_lower must be negative or less or equal to number of rows (",
+ input_reshaped.dimension(1), ") got: ", num_lower));
const Tensor& num_upper_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
const int64 num_upper = num_upper_in.scalar<int64>()();
+ OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
+ errors::InvalidArgument("num_upper must be negative or less or "
+ "equal to number of columns (",
+ input_reshaped.dimension(2),
+ ") got: ", num_upper));
+
+ if ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
+ (num_upper < 0 || num_upper == input_reshaped.dimension(2))) {
+ // This is a no-op.
+ context->set_output(0, input);
+ return;
+ }
- const TensorShape& input_shape = input.shape();
- // Preliminary validation of sizes.
- OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
- errors::InvalidArgument(
- "input must be at least 2-dim, received shape: ",
- input.shape().DebugString()));
- auto input_reshaped = input.flat_inner_dims<T, 3>();
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {0}, 0, input_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 3>();
- functor::MatrixBandPart<Device, T>::Compute(
- context->eigen_device<Device>(), num_lower, num_upper, input_reshaped,
- output_reshaped);
+ functor::MatrixBandPartFunctor<Device, T> fn;
+ fn(context, context->eigen_device<Device>(), num_lower, num_upper,
+ false /* transpose */, input_reshaped, output_reshaped);
}
private:
@@ -98,54 +119,118 @@ TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);
// Implementation of the functor specialization for CPU.
namespace functor {
-template <typename T>
-struct MatrixBandPart<CPUDevice, T> {
- static void Compute(const CPUDevice& d, int64 num_lower, int64 num_upper,
- typename TTypes<T, 3>::ConstTensor input,
- typename TTypes<T, 3>::Tensor output) {
- if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
- (num_upper < 0 || num_upper >= input.dimension(2))) {
- output.device(d) = input;
+
+// CPU implementation of BandPartFunctor.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<CPUDevice, Scalar> {
+ void operator()(OpKernelContext* context, const CPUDevice& device,
+ int num_lower_diags, int num_upper_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output) {
+ const int64 b = input.dimension(0);
+ const int64 m = input.dimension(1);
+ const int64 n = input.dimension(2);
+ auto thread_pool =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ const int64 total_rows = b * m;
+ const int64 row_cost = 10 * n;
+ const bool in_place = input.data() == output.data();
+ CHECK(!(transpose && in_place));
+ if (!transpose) {
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ if (!in_place) {
+ std::fill(output.data() + begin * n, output.data() + end * n,
+ Scalar());
+ }
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0
+ ? 0
+ : std::min(n, std::max(0ll, row - num_lower_diags));
+ const int64 band_end = num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n),
+ row + num_upper_diags + 1);
+ if (in_place) {
+ if (band_start > 0) {
+ std::fill(&output(batch, row, 0),
+ &output(batch, row, band_start), Scalar());
+ }
+ if (band_end < n) {
+ std::fill(&output(batch, row, band_end), &output(batch, row, n),
+ Scalar());
+ }
+ } else {
+ if (band_start < band_end) {
+ const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
+ band_start);
+ const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
+ 1, 1, band_end - band_start);
+ output.slice(indices, sizes) = input.slice(indices, sizes);
+ }
+ }
+ }
+ }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
} else {
- output.device(d) = output.constant(T());
- for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 i = 0; i < output.dimension(1); ++i) {
- const int64 band_start =
- num_lower < 0 ? 0 : std::max(0ll, i - num_lower);
- const int64 band_end =
- num_upper < 0 ? output.dimension(2)
- : std::min(static_cast<int64>(output.dimension(2)),
- i + num_upper + 1);
- if (band_start < band_end) {
- const Eigen::DSizes<Eigen::DenseIndex, 3> indices(r, i, band_start);
- const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
- 1, 1, band_end - band_start);
- output.slice(indices, sizes) = input.slice(indices, sizes);
+ output.device(device) = output.constant(Scalar());
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
+ const int64 band_end = num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n),
+ row + num_upper_diags + 1);
+ for (int64 col = band_start; col < band_end; ++col) {
+ output(batch, col, row) = input(batch, row, col);
+ }
}
}
- }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
}
};
+#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
+TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
+#undef DEFINE_CPU_SPEC
+
} // namespace functor
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void MatrixBandPart<GPUDevice, T>::Compute( \
- const GPUDevice& d, Eigen::DenseIndex num_lower, \
- Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- extern template struct MatrixBandPart<GPUDevice, T>;
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+ int num_upper_diags, int num_lower_diags, bool transpose, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
+ extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.
diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h
index b601255b25..43b6724dae 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.h
+++ b/tensorflow/core/kernels/matrix_band_part_op.h
@@ -16,61 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-// Generator definition for MatrixBandPartOp, must be compilable by nvcc.
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-
-namespace generator {
-
-template <typename T>
-class MatrixBandPartGenerator {
- public:
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE MatrixBandPartGenerator(
- Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
- typename TTypes<T, 3>::ConstTensor input)
- : num_lower_(num_lower), num_upper_(num_upper), input_(input) {}
-
- EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
- operator()(const Eigen::array<Eigen::DenseIndex, 3>& coords) const {
- return (((num_lower_ < 0 || coords[1] - coords[2] <= num_lower_) &&
- (num_upper_ < 0 || coords[2] - coords[1] <= num_upper_))
- ? input_(coords)
- : T());
- }
-
- private:
- const Eigen::DenseIndex num_lower_;
- const Eigen::DenseIndex num_upper_;
- typename TTypes<T, 3>::ConstTensor input_;
-};
-
-} // namespace generator
-
namespace functor {
-template <typename Device, typename T>
-struct MatrixBandPart {
- EIGEN_ALWAYS_INLINE static void Compute(
- const Device& d, Eigen::DenseIndex num_lower, Eigen::DenseIndex num_upper,
- typename TTypes<T, 3>::ConstTensor input,
- typename TTypes<T, 3>::Tensor output) {
- if ((num_lower < 0 || num_lower >= input.dimension(1)) &&
- (num_upper < 0 || num_upper >= input.dimension(2))) {
- output.device(d) = input;
- } else {
- generator::MatrixBandPartGenerator<T> generator(num_lower, num_upper,
- input);
- output.device(d) = output.generate(generator);
- }
- }
+template <typename Device, typename Scalar>
+struct MatrixBandPartFunctor {
+ void operator()(OpKernelContext* context, const Device& device,
+ int num_upper_diags, int num_lower_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output);
};
} // namespace functor
-
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
index ccc10ebada..afebdacdca 100644
--- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
@@ -17,22 +17,92 @@ limitations under the License.
#define EIGEN_USE_GPU
+#include <complex>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
-
+namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPEC(T) \
- template class generator::MatrixBandPartGenerator<T>; \
- template struct functor::MatrixBandPart<GPUDevice, T>;
+template <bool transpose, typename Scalar>
+__global__ void MatrixBandPartKernel(const int num_threads,
+ const int batch_size, const int m,
+ const int n, const int num_lower_diags,
+ const int num_upper_diags,
+ const Scalar* input_ptr,
+ Scalar* output_ptr) {
+ if (!transpose) {
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int col = index % n;
+ const int row = (index / n) % m;
+ const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+ const int band_end =
+ (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+ if (col < band_start || col >= band_end) {
+ output_ptr[index] = Scalar();
+ } else {
+ output_ptr[index] = input_ptr[index];
+ }
+ }
+ } else {
+ const int matrix_size = m * n;
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int col = index % n;
+ const int row = (index / n) % m;
+ const int batch = index / matrix_size;
+ const int transpose_index = batch * matrix_size + n * col + row;
+ const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+ const int band_end =
+ (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+ if (col < band_start || col >= band_end) {
+ output_ptr[transpose_index] = Scalar();
+ } else {
+ output_ptr[transpose_index] = input_ptr[index];
+ }
+ }
+ }
+}
+
+template <typename Scalar>
+struct MatrixBandPartFunctor<GPUDevice, Scalar> {
+ void operator()(OpKernelContext* context, const GPUDevice& device,
+ int num_lower_diags, int num_upper_diags, bool transpose,
+ typename TTypes<Scalar, 3>::ConstTensor input,
+ typename TTypes<Scalar, 3>::Tensor output) {
+ using CudaType = typename CUDAComplexT<Scalar>::type;
+ const int batch_size = input.dimension(0);
+ const int m = input.dimension(1);
+ const int n = input.dimension(2);
+ const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
+ CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
+ CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
+ if (transpose) {
+ MatrixBandPartKernel<true>
+ <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+ config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+ num_upper_diags, input_ptr, output_ptr);
+ } else {
+ MatrixBandPartKernel<false>
+ <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
+ config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+ num_upper_diags, input_ptr, output_ptr);
+ }
+ }
+};
+
+#define DEFINE_GPU_SPEC(T) template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC);
-} // end namespace tensorflow
+#undef DEFINE_GPU_SPEC
+} // namespace functor
+} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index d4b7aa96f2..099d67fdf9 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -6899,6 +6899,41 @@ op {
}
}
op {
+ name: "DecodeRaw"
+ input_arg {
+ name: "bytes"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT16
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "little_endian"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "DecodeWav"
input_arg {
name: "contents"
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 48b2362342..322cf9dcb9 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -399,7 +399,7 @@ matrix is assumed to be zero and not accessed.
`rhs` is a tensor of shape `[..., M, K]`.
The output is a tensor of shape `[..., M, K]`. If `adjoint` is
-`True` then the innermost matrices in output` satisfy matrix equations
+`True` then the innermost matrices in `output` satisfy matrix equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `False` then the strictly then the innermost matrices in
`output` satisfy matrix equations
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c0e073757b..691f320141 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6308,6 +6308,7 @@ op {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
+ type: DT_UINT16
type: DT_UINT8
type: DT_INT16
type: DT_INT8
@@ -12593,7 +12594,7 @@ op {
}
}
summary: "Solves systems of linear equations with upper or lower triangular matrices by"
- description: "backsubstitution.\n\n`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form\nsquare matrices. If `lower` is `True` then the strictly upper triangular part\nof each inner-most matrix is assumed to be zero and not accessed.\nIf `lower` is False then the strictly lower triangular part of each inner-most\nmatrix is assumed to be zero and not accessed.\n`rhs` is a tensor of shape `[..., M, K]`.\n\nThe output is a tensor of shape `[..., M, K]`. If `adjoint` is\n`True` then the innermost matrices in output` satisfy matrix equations\n`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.\nIf `adjoint` is `False` then the strictly then the innermost matrices in\n`output` satisfy matrix equations\n`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`."
+ description: "backsubstitution.\n\n`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form\nsquare matrices. If `lower` is `True` then the strictly upper triangular part\nof each inner-most matrix is assumed to be zero and not accessed.\nIf `lower` is False then the strictly lower triangular part of each inner-most\nmatrix is assumed to be zero and not accessed.\n`rhs` is a tensor of shape `[..., M, K]`.\n\nThe output is a tensor of shape `[..., M, K]`. If `adjoint` is\n`True` then the innermost matrices in `output` satisfy matrix equations\n`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.\nIf `adjoint` is `False` then the strictly then the innermost matrices in\n`output` satisfy matrix equations\n`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`."
}
op {
name: "Max"
@@ -29115,7 +29116,7 @@ op {
}
}
summary: "Return substrings from `Tensor` of strings."
- description: "For each string in the input `Tensor`, creates a substring starting at index\n`pos` with a total length of `len`.\n\nIf `len` defines a substring that would extend beyond the length of the input\nstring, then as many characters as possible are used.\n\nIf `pos` is negative or specifies a character index larger than any of the input\nstrings, then an `InvalidArgumentError` is thrown.\n\n`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on\nOp creation.\n\n*NOTE*: `Substr` supports broadcasting up to two dimensions. More about\nbroadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)\n\n---\n\nExamples\n\nUsing scalar `pos` and `len`:\n\n```python\ninput = [b\'Hello\', b\'World\']\nposition = 1\nlength = 3\n\noutput = [b\'ell\', b\'orl\']\n```\n\nUsing `pos` and `len` with same shape as `input`:\n\n```python\ninput = [[b\'ten\', b\'eleven\', b\'twelve\'],\n [b\'thirteen\', b\'fourteen\', b\'fifteen\'],\n [b\'sixteen\', b\'seventeen\', b\'eighteen\']]\nposition = [[1, 2, 3],\n [1, 2, 3],\n [1, 2, 3]]\nlength = [[2, 3, 4],\n [4, 3, 2],\n [5, 5, 5]]\n\noutput = [[b\'en\', b\'eve\', b\'lve\'],\n [b\'hirt\', b\'urt\', b\'te\'],\n [b\'ixtee\', b\'vente\', b\'hteen\']]\n```\n\nBroadcasting `pos` and `len` onto `input`:\n\n```\ninput = [[b\'ten\', b\'eleven\', b\'twelve\'],\n [b\'thirteen\', b\'fourteen\', b\'fifteen\'],\n [b\'sixteen\', b\'seventeen\', b\'eighteen\'],\n [b\'nineteen\', b\'twenty\', b\'twentyone\']]\nposition = [1, 2, 3]\nlength = [1, 2, 3]\n\noutput = [[b\'e\', b\'ev\', b\'lve\'],\n [b\'h\', b\'ur\', b\'tee\'],\n [b\'i\', b\'ve\', b\'hte\'],\n [b\'i\', b\'en\', b\'nty\']]\n```\n\nBroadcasting `input` onto `pos` and `len`:\n\n```\ninput = b\'thirteen\'\nposition = [1, 5, 7]\nlength = [3, 2, 1]\n\noutput = [b\'hir\', b\'ee\', b\'n\"]\n```"
+ description: "For each string in the input `Tensor`, creates a substring starting at index\n`pos` with a total length of `len`.\n\nIf `len` defines a substring that would extend beyond the length of the input\nstring, then as many characters as possible are used.\n\nIf `pos` is negative or specifies a character index larger than any of the input\nstrings, then an `InvalidArgumentError` is thrown.\n\n`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on\nOp creation.\n\n*NOTE*: `Substr` supports broadcasting up to two dimensions. More about\nbroadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)\n\n---\n\nExamples\n\nUsing scalar `pos` and `len`:\n\n```python\ninput = [b\'Hello\', b\'World\']\nposition = 1\nlength = 3\n\noutput = [b\'ell\', b\'orl\']\n```\n\nUsing `pos` and `len` with same shape as `input`:\n\n```python\ninput = [[b\'ten\', b\'eleven\', b\'twelve\'],\n [b\'thirteen\', b\'fourteen\', b\'fifteen\'],\n [b\'sixteen\', b\'seventeen\', b\'eighteen\']]\nposition = [[1, 2, 3],\n [1, 2, 3],\n [1, 2, 3]]\nlength = [[2, 3, 4],\n [4, 3, 2],\n [5, 5, 5]]\n\noutput = [[b\'en\', b\'eve\', b\'lve\'],\n [b\'hirt\', b\'urt\', b\'te\'],\n [b\'ixtee\', b\'vente\', b\'hteen\']]\n```\n\nBroadcasting `pos` and `len` onto `input`:\n\n```\ninput = [[b\'ten\', b\'eleven\', b\'twelve\'],\n [b\'thirteen\', b\'fourteen\', b\'fifteen\'],\n [b\'sixteen\', b\'seventeen\', b\'eighteen\'],\n [b\'nineteen\', b\'twenty\', b\'twentyone\']]\nposition = [1, 2, 3]\nlength = [1, 2, 3]\n\noutput = [[b\'e\', b\'ev\', b\'lve\'],\n [b\'h\', b\'ur\', b\'tee\'],\n [b\'i\', b\'ve\', b\'hte\'],\n [b\'i\', b\'en\', b\'nty\']]\n```\n\nBroadcasting `input` onto `pos` and `len`:\n\n```\ninput = b\'thirteen\'\nposition = [1, 5, 7]\nlength = [3, 2, 1]\n\noutput = [b\'hir\', b\'ee\', b\'n\']\n```"
}
op {
name: "Sum"
diff --git a/tensorflow/core/profiler/g3doc/command_line.md b/tensorflow/core/profiler/g3doc/command_line.md
index d41ac7290d..7ce53bc1b8 100644
--- a/tensorflow/core/profiler/g3doc/command_line.md
+++ b/tensorflow/core/profiler/g3doc/command_line.md
@@ -62,7 +62,7 @@ Note: this feature is not well maintained now.
```shell
# Build the tool.
-bazel build --config opt tensorflow/core/profiler:profiler
+bazel build --config opt third_party/tensorflow/core/profiler:profiler
# Help information, including detail 'option' instructions.
bazel-bin/tensorflow/core/profiler/profiler help
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index ccb861c93a..9ba3a509c3 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 3
+#define TF_MINOR_VERSION 4
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-dev"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 7ebf5c4a2c..04cd462848 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -35,7 +35,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for Mac OS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.3.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.4.0-dev.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index b991fd0f93..b7fa1fe39a 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -35,7 +35,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.3.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.4.0-dev.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 2adcd4da73..e1200dde12 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -34,7 +34,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.3.0</version>
+ <version>1.4.0-dev</version>
</dependency>
```
@@ -63,7 +63,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.3.0</version>
+ <version>1.4.0-dev</version>
</dependency>
</dependencies>
</project>
@@ -122,7 +122,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or Mac OS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.3.0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.4.0-dev.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -141,7 +141,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.3.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.4.0-dev.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -149,10 +149,10 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.3.0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.4.0-dev.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.3.0.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.4.0-dev.zip).
3. Extract this .zip file.
@@ -200,7 +200,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.3.0.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.4.0-dev.jar HelloTF.java</b></pre>
### Running
@@ -214,11 +214,11 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and Mac OS X:
-<pre><b>java -cp libtensorflow-1.3.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.4.0-dev.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.3.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.4.0-dev.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index bd4550fc46..576099f054 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -172,7 +172,7 @@ Take the following steps to install TensorFlow with Virtualenv:
virtualenv environment:
<pre>(tensorflow)$ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common_installation_problems).
@@ -277,7 +277,7 @@ take the following steps:
<pre>
$ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl</b>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl</b>
</pre>
If this step fails, see
@@ -464,7 +464,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -632,14 +632,14 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.4.0dev-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -651,14 +651,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.4.0dev-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -670,14 +670,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.4.0dev-cp35-cp35m-linux_x86_64.whl
</pre>
@@ -689,14 +689,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.4.0dev-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.4.0dev-cp36-cp36m-linux_x86_64.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 5a53def707..b6daeb0dd6 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -109,7 +109,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0dev-py2-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -230,7 +230,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0dev-py2-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
@@ -339,7 +339,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0dev-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -512,7 +512,7 @@ This section documents the relevant values for Mac OS installations.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0dev-py2-none-any.whl
</pre>
@@ -520,7 +520,7 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.4.0dev-py3-none-any.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index d58db00a4c..d8925d3909 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -342,10 +342,10 @@ Invoke `pip install` to install that pip package.
The filename of the `.whl` file depends on your platform.
For example, the following command will install the pip package
-for TensorFlow 1.3.0 on Linux:
+for TensorFlow 1.4.0dev on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.3.0-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.4.0dev-py2-none-any.whl</b>
</pre>
## Validate your installation
diff --git a/tensorflow/docs_src/performance/xla/tfcompile.md b/tensorflow/docs_src/performance/xla/tfcompile.md
index 60aff2f633..f57ca3948d 100644
--- a/tensorflow/docs_src/performance/xla/tfcompile.md
+++ b/tensorflow/docs_src/performance/xla/tfcompile.md
@@ -47,7 +47,7 @@ This section details high level steps for generating an executable binary with
Identify the feeds and fetches that correspond to the input and output
arguments for the generated function. Then configure the `feeds` and `fetches`
-in a [`tensorflow.tfcompile.Config`](https://www.tensorflow.org/code/tensorflow/compiler/aot/tfcompile.proto)
+in a [`tensorflow.tf2xla.Config`](https://www.tensorflow.org/code/tensorflow/compiler/tf2xla/tf2xla.proto)
proto.
```textproto
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 1673b52289..7dbe115dd6 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -119,27 +119,6 @@ func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr)
return op.Output(0)
}
-// Checks whether a resource handle-based variable has been initialized.
-//
-// Arguments:
-// resource: the input resource handle.
-//
-// Returns a scalar boolean which is true if the variable has been
-// initialized.
-func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "VarIsInitializedOp",
- Input: []tf.Input{
- resource,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Assigns a new value to a variable.
//
// Any ReadVariableOp with a control dependency on this op is guaranteed to return
@@ -7757,6 +7736,100 @@ func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shar
return op.Output(0)
}
+// Saves input tensors slices to disk.
+//
+// This is like `Save` except that tensors can be listed in the saved file as being
+// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the
+// larger tensor and the slice that this tensor covers. `shapes_and_slices` must
+// have as many elements as `tensor_names`.
+//
+// Elements of the `shapes_and_slices` input must either be:
+//
+// * The empty string, in which case the corresponding tensor is
+// saved normally.
+// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the
+// `dimI` are the dimensions of the larger tensor and `slice-spec`
+// specifies what part is covered by the tensor to save.
+//
+// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1`
+// where each `sliceI` is either:
+//
+// * The string `-` meaning that the slice covers all indices of this dimension
+// * `start,length` where `start` and `length` are integers. In that
+// case the slice covers `length` indices starting at `start`.
+//
+// See also `Save`.
+//
+// Arguments:
+// filename: Must have a single element. The name of the file to which we write the
+// tensor.
+// tensor_names: Shape `[N]`. The names of the tensors to be saved.
+// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when
+// saving the tensors.
+// data: `N` tensors to save.
+//
+// Returns the created operation.
+func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SaveSlices",
+ Input: []tf.Input{
+ filename, tensor_names, shapes_and_slices, tf.OutputList(data),
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints.
+type MergeV2CheckpointsAttr func(optionalAttr)
+
+// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value.
+//
+// value: see above.
+// If not specified, defaults to true
+func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr {
+ return func(m optionalAttr) {
+ m["delete_old_dirs"] = value
+ }
+}
+
+// V2 format specific: merges the metadata files of sharded checkpoints. The
+//
+// result is one logical checkpoint, with one physical metadata file and renamed
+// data files.
+//
+// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup.
+//
+// If delete_old_dirs is true, attempts to delete recursively the dirname of each
+// path in the input checkpoint_prefixes. This is useful when those paths are non
+// user-facing temporary locations.
+//
+// Arguments:
+// checkpoint_prefixes: prefixes of V2 checkpoints to merge.
+// destination_prefix: scalar. The desired final prefix. Allowed to be the same
+// as one of the checkpoint_prefixes.
+//
+// Returns the created operation.
+func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MergeV2Checkpoints",
+ Input: []tf.Input{
+ checkpoint_prefixes, destination_prefix,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// AudioSpectrogramAttr is an optional argument to AudioSpectrogram.
type AudioSpectrogramAttr func(optionalAttr)
@@ -9378,52 +9451,6 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.
return op.Output(0)
}
-// Saves input tensors slices to disk.
-//
-// This is like `Save` except that tensors can be listed in the saved file as being
-// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the
-// larger tensor and the slice that this tensor covers. `shapes_and_slices` must
-// have as many elements as `tensor_names`.
-//
-// Elements of the `shapes_and_slices` input must either be:
-//
-// * The empty string, in which case the corresponding tensor is
-// saved normally.
-// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the
-// `dimI` are the dimensions of the larger tensor and `slice-spec`
-// specifies what part is covered by the tensor to save.
-//
-// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1`
-// where each `sliceI` is either:
-//
-// * The string `-` meaning that the slice covers all indices of this dimension
-// * `start,length` where `start` and `length` are integers. In that
-// case the slice covers `length` indices starting at `start`.
-//
-// See also `Save`.
-//
-// Arguments:
-// filename: Must have a single element. The name of the file to which we write the
-// tensor.
-// tensor_names: Shape `[N]`. The names of the tensors to be saved.
-// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when
-// saving the tensors.
-// data: `N` tensors to save.
-//
-// Returns the created operation.
-func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SaveSlices",
- Input: []tf.Input{
- filename, tensor_names, shapes_and_slices, tf.OutputList(data),
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Returns the rank of a tensor.
//
// This operation returns an integer representing the rank of `input`.
@@ -13132,33 +13159,6 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s
return op.Output(0)
}
-// Saves the input tensors to disk.
-//
-// The size of `tensor_names` must match the number of tensors in `data`. `data[i]`
-// is written to `filename` with name `tensor_names[i]`.
-//
-// See also `SaveSlices`.
-//
-// Arguments:
-// filename: Must have a single element. The name of the file to which we write
-// the tensor.
-// tensor_names: Shape `[N]`. The names of the tensors to be saved.
-// data: `N` tensors to save.
-//
-// Returns the created operation.
-func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Save",
- Input: []tf.Input{
- filename, tensor_names, tf.OutputList(data),
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Shuffle dimensions of x according to a permutation.
//
// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
@@ -15248,6 +15248,80 @@ func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Restores tensors from a V2 checkpoint.
+//
+// For backward compatibility with the V1 format, this Op currently allows
+// restoring from a V1 checkpoint as well:
+// - This Op first attempts to find the V2 index file pointed to by "prefix", and
+// if found proceed to read it as a V2 checkpoint;
+// - Otherwise the V1 read path is invoked.
+// Relying on this behavior is not recommended, as the ability to fall back to read
+// V1 might be deprecated and eventually removed.
+//
+// By default, restores the named tensors in full. If the caller wishes to restore
+// specific slices of stored tensors, "shape_and_slices" should be non-empty
+// strings and correspondingly well-formed.
+//
+// Callers must ensure all the named tensors are indeed stored in the checkpoint.
+//
+// Arguments:
+// prefix: Must have a single element. The prefix of a V2 checkpoint.
+// tensor_names: shape {N}. The names of the tensors to be restored.
+// shape_and_slices: shape {N}. The slice specs of the tensors to be restored.
+// Empty strings indicate that they are non-partitioned tensors.
+// dtypes: shape {N}. The list of expected dtype for the tensors. Must match
+// those stored in the checkpoint.
+//
+// Returns shape {N}. The restored tensors, whose shapes are read from the
+// checkpoint directly.
+func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtypes": dtypes}
+ opspec := tf.OpSpec{
+ Type: "RestoreV2",
+ Input: []tf.Input{
+ prefix, tensor_names, shape_and_slices,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil {
+ scope.UpdateErr("RestoreV2", err)
+ return
+ }
+ return tensors
+}
+
+// Returns x / y element-wise for integer types.
+//
+// Truncation designates that negative numbers will round fractional quantities
+// toward zero. I.e. -7 / 5 = 1. This matches C semantics but it is different
+// than Python semantics. See `FloorDiv` for a division function that matches
+// Python Semantics.
+//
+// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TruncateDiv",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// MaxPoolGradGradV2Attr is an optional argument to MaxPoolGradGradV2.
type MaxPoolGradGradV2Attr func(optionalAttr)
@@ -16448,6 +16522,27 @@ func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Ou
return op.Output(0)
}
+// Checks whether a resource handle-based variable has been initialized.
+//
+// Arguments:
+// resource: the input resource handle.
+//
+// Returns a scalar boolean which is true if the variable has been
+// initialized.
+func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "VarIsInitializedOp",
+ Input: []tf.Input{
+ resource,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResizeAreaAttr is an optional argument to ResizeArea.
type ResizeAreaAttr func(optionalAttr)
@@ -18087,7 +18182,7 @@ func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Out
// position = [1, 5, 7]
// length = [3, 2, 1]
//
-// output = [b'hir', b'ee', b'n"]
+// output = [b'hir', b'ee', b'n']
// ```
//
// Arguments:
@@ -18762,54 +18857,6 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f
return op.Output(0), op.Output(1), op.Output(2)
}
-// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints.
-type MergeV2CheckpointsAttr func(optionalAttr)
-
-// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value.
-//
-// value: see above.
-// If not specified, defaults to true
-func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr {
- return func(m optionalAttr) {
- m["delete_old_dirs"] = value
- }
-}
-
-// V2 format specific: merges the metadata files of sharded checkpoints. The
-//
-// result is one logical checkpoint, with one physical metadata file and renamed
-// data files.
-//
-// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup.
-//
-// If delete_old_dirs is true, attempts to delete recursively the dirname of each
-// path in the input checkpoint_prefixes. This is useful when those paths are non
-// user-facing temporary locations.
-//
-// Arguments:
-// checkpoint_prefixes: prefixes of V2 checkpoints to merge.
-// destination_prefix: scalar. The desired final prefix. Allowed to be the same
-// as one of the checkpoint_prefixes.
-//
-// Returns the created operation.
-func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MergeV2Checkpoints",
- Input: []tf.Input{
- checkpoint_prefixes, destination_prefix,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// UnpackAttr is an optional argument to Unpack.
type UnpackAttr func(optionalAttr)
@@ -19448,7 +19495,7 @@ func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr {
// `rhs` is a tensor of shape `[..., M, K]`.
//
// The output is a tensor of shape `[..., M, K]`. If `adjoint` is
-// `True` then the innermost matrices in output` satisfy matrix equations
+// `True` then the innermost matrices in `output` satisfy matrix equations
// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
// If `adjoint` is `False` then the strictly then the innermost matrices in
// `output` satisfy matrix equations
@@ -22188,101 +22235,6 @@ func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// RecordInputAttr is an optional argument to RecordInput.
-type RecordInputAttr func(optionalAttr)
-
-// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value.
-//
-// value: Random seeds used to produce randomized records.
-// If not specified, defaults to 301
-func RecordInputFileRandomSeed(value int64) RecordInputAttr {
- return func(m optionalAttr) {
- m["file_random_seed"] = value
- }
-}
-
-// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value.
-//
-// value: Shifts the list of files after the list is randomly
-// shuffled.
-// If not specified, defaults to 0
-func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr {
- return func(m optionalAttr) {
- m["file_shuffle_shift_ratio"] = value
- }
-}
-
-// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value.
-//
-// value: The randomization shuffling buffer.
-// If not specified, defaults to 10000
-func RecordInputFileBufferSize(value int64) RecordInputAttr {
- return func(m optionalAttr) {
- m["file_buffer_size"] = value
- }
-}
-
-// RecordInputFileParallelism sets the optional file_parallelism attribute to value.
-//
-// value: How many sstables are opened and concurrently iterated over.
-// If not specified, defaults to 16
-func RecordInputFileParallelism(value int64) RecordInputAttr {
- return func(m optionalAttr) {
- m["file_parallelism"] = value
- }
-}
-
-// RecordInputBatchSize sets the optional batch_size attribute to value.
-//
-// value: The batch size.
-// If not specified, defaults to 32
-func RecordInputBatchSize(value int64) RecordInputAttr {
- return func(m optionalAttr) {
- m["batch_size"] = value
- }
-}
-
-// Emits randomized records.
-//
-// Arguments:
-// file_pattern: Glob pattern for the data files.
-//
-// Returns A tensor of shape [batch_size].
-func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"file_pattern": file_pattern}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RecordInput",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Rounds the values of a tensor to the nearest integer, element-wise.
-//
-// Rounds half to even. Also known as bankers rounding. If you want to round
-// according to the current system rounding mode use std::cint.
-func Round(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Round",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Generates values in an interval.
//
// A sequence of `num` evenly-spaced values are generated beginning at `start`.
@@ -23637,6 +23589,143 @@ func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// Saves the input tensors to disk.
+//
+// The size of `tensor_names` must match the number of tensors in `data`. `data[i]`
+// is written to `filename` with name `tensor_names[i]`.
+//
+// See also `SaveSlices`.
+//
+// Arguments:
+// filename: Must have a single element. The name of the file to which we write
+// the tensor.
+// tensor_names: Shape `[N]`. The names of the tensors to be saved.
+// data: `N` tensors to save.
+//
+// Returns the created operation.
+func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Save",
+ Input: []tf.Input{
+ filename, tensor_names, tf.OutputList(data),
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// QrAttr is an optional argument to Qr.
+type QrAttr func(optionalAttr)
+
+// QrFullMatrices sets the optional full_matrices attribute to value.
+//
+// value: If true, compute full-sized `q` and `r`. If false
+// (the default), compute only the leading `P` columns of `q`.
+// If not specified, defaults to false
+func QrFullMatrices(value bool) QrAttr {
+ return func(m optionalAttr) {
+ m["full_matrices"] = value
+ }
+}
+
+// Computes the QR decompositions of one or more matrices.
+//
+// Computes the QR decomposition of each inner matrix in `tensor` such that
+// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
+//
+// ```python
+// # a is a tensor.
+// # q is a tensor of orthonormal matrices.
+// # r is a tensor of upper triangular matrices.
+// q, r = qr(a)
+// q_full, r_full = qr(a, full_matrices=True)
+// ```
+//
+// Arguments:
+// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+//
+// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then
+// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
+// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is
+// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
+func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Qr",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// AudioSummaryAttr is an optional argument to AudioSummary.
+type AudioSummaryAttr func(optionalAttr)
+
+// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value.
+//
+// value: Max number of batch elements to generate audio for.
+// If not specified, defaults to 3
+//
+// REQUIRES: value >= 1
+func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr {
+ return func(m optionalAttr) {
+ m["max_outputs"] = value
+ }
+}
+
+// Outputs a `Summary` protocol buffer with audio.
+//
+// DEPRECATED at GraphDef version 15: Use AudioSummaryV2.
+//
+// The summary has up to `max_outputs` summary values containing audio. The
+// audio is built from `tensor` which must be 3-D with shape `[batch_size,
+// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
+// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
+//
+// The `tag` argument is a scalar `Tensor` of type `string`. It is used to
+// build the `tag` of the summary values:
+//
+// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
+// * If `max_outputs` is greater than 1, the summary value tags are
+// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
+//
+// Arguments:
+// tag: Scalar. Used to build the `tag` attribute of the summary values.
+// tensor: 2-D of shape `[batch_size, frames]`.
+// sample_rate: The sample rate of the signal in hertz.
+//
+// Returns Scalar. Serialized `Summary` protocol buffer.
+func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"sample_rate": sample_rate}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AudioSummary",
+ Input: []tf.Input{
+ tag, tensor,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// BiasAddAttr is an optional argument to BiasAdd.
type BiasAddAttr func(optionalAttr)
@@ -23805,6 +23894,101 @@ func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...Approx
return op.Output(0)
}
+// RecordInputAttr is an optional argument to RecordInput.
+type RecordInputAttr func(optionalAttr)
+
+// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value.
+//
+// value: Random seeds used to produce randomized records.
+// If not specified, defaults to 301
+func RecordInputFileRandomSeed(value int64) RecordInputAttr {
+ return func(m optionalAttr) {
+ m["file_random_seed"] = value
+ }
+}
+
+// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value.
+//
+// value: Shifts the list of files after the list is randomly
+// shuffled.
+// If not specified, defaults to 0
+func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr {
+ return func(m optionalAttr) {
+ m["file_shuffle_shift_ratio"] = value
+ }
+}
+
+// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value.
+//
+// value: The randomization shuffling buffer.
+// If not specified, defaults to 10000
+func RecordInputFileBufferSize(value int64) RecordInputAttr {
+ return func(m optionalAttr) {
+ m["file_buffer_size"] = value
+ }
+}
+
+// RecordInputFileParallelism sets the optional file_parallelism attribute to value.
+//
+// value: How many sstables are opened and concurrently iterated over.
+// If not specified, defaults to 16
+func RecordInputFileParallelism(value int64) RecordInputAttr {
+ return func(m optionalAttr) {
+ m["file_parallelism"] = value
+ }
+}
+
+// RecordInputBatchSize sets the optional batch_size attribute to value.
+//
+// value: The batch size.
+// If not specified, defaults to 32
+func RecordInputBatchSize(value int64) RecordInputAttr {
+ return func(m optionalAttr) {
+ m["batch_size"] = value
+ }
+}
+
+// Emits randomized records.
+//
+// Arguments:
+// file_pattern: Glob pattern for the data files.
+//
+// Returns A tensor of shape [batch_size].
+func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"file_pattern": file_pattern}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RecordInput",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Rounds the values of a tensor to the nearest integer, element-wise.
+//
+// Rounds half to even. Also known as bankers rounding. If you want to round
+// according to the current system rounding mode use std::cint.
+func Round(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Round",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
//
// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
@@ -25476,116 +25660,6 @@ func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate t
return op.Output(0)
}
-// QrAttr is an optional argument to Qr.
-type QrAttr func(optionalAttr)
-
-// QrFullMatrices sets the optional full_matrices attribute to value.
-//
-// value: If true, compute full-sized `q` and `r`. If false
-// (the default), compute only the leading `P` columns of `q`.
-// If not specified, defaults to false
-func QrFullMatrices(value bool) QrAttr {
- return func(m optionalAttr) {
- m["full_matrices"] = value
- }
-}
-
-// Computes the QR decompositions of one or more matrices.
-//
-// Computes the QR decomposition of each inner matrix in `tensor` such that
-// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
-//
-// ```python
-// # a is a tensor.
-// # q is a tensor of orthonormal matrices.
-// # r is a tensor of upper triangular matrices.
-// q, r = qr(a)
-// q_full, r_full = qr(a, full_matrices=True)
-// ```
-//
-// Arguments:
-// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
-//
-// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then
-// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
-// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is
-// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
-func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Qr",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// AudioSummaryAttr is an optional argument to AudioSummary.
-type AudioSummaryAttr func(optionalAttr)
-
-// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value.
-//
-// value: Max number of batch elements to generate audio for.
-// If not specified, defaults to 3
-//
-// REQUIRES: value >= 1
-func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr {
- return func(m optionalAttr) {
- m["max_outputs"] = value
- }
-}
-
-// Outputs a `Summary` protocol buffer with audio.
-//
-// DEPRECATED at GraphDef version 15: Use AudioSummaryV2.
-//
-// The summary has up to `max_outputs` summary values containing audio. The
-// audio is built from `tensor` which must be 3-D with shape `[batch_size,
-// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
-// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`.
-//
-// The `tag` argument is a scalar `Tensor` of type `string`. It is used to
-// build the `tag` of the summary values:
-//
-// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
-// * If `max_outputs` is greater than 1, the summary value tags are
-// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.
-//
-// Arguments:
-// tag: Scalar. Used to build the `tag` attribute of the summary values.
-// tensor: 2-D of shape `[batch_size, frames]`.
-// sample_rate: The sample rate of the signal in hertz.
-//
-// Returns Scalar. Serialized `Summary` protocol buffer.
-func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"sample_rate": sample_rate}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AudioSummary",
- Input: []tf.Input{
- tag, tensor,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Replaces the contents of the table with the specified keys and values.
//
// The tensor `keys` must be of the same type as the keys of the table.
@@ -26357,77 +26431,3 @@ func Inv(scope *Scope, x tf.Output) (y tf.Output) {
op := scope.AddOperation(opspec)
return op.Output(0)
}
-
-// Returns x / y element-wise for integer types.
-//
-// Truncation designates that negative numbers will round fractional quantities
-// toward zero. I.e. -7 / 5 = 1. This matches C semantics but it is different
-// than Python semantics. See `FloorDiv` for a division function that matches
-// Python Semantics.
-//
-// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "TruncateDiv",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Restores tensors from a V2 checkpoint.
-//
-// For backward compatibility with the V1 format, this Op currently allows
-// restoring from a V1 checkpoint as well:
-// - This Op first attempts to find the V2 index file pointed to by "prefix", and
-// if found proceed to read it as a V2 checkpoint;
-// - Otherwise the V1 read path is invoked.
-// Relying on this behavior is not recommended, as the ability to fall back to read
-// V1 might be deprecated and eventually removed.
-//
-// By default, restores the named tensors in full. If the caller wishes to restore
-// specific slices of stored tensors, "shape_and_slices" should be non-empty
-// strings and correspondingly well-formed.
-//
-// Callers must ensure all the named tensors are indeed stored in the checkpoint.
-//
-// Arguments:
-// prefix: Must have a single element. The prefix of a V2 checkpoint.
-// tensor_names: shape {N}. The names of the tensors to be restored.
-// shape_and_slices: shape {N}. The slice specs of the tensors to be restored.
-// Empty strings indicate that they are non-partitioned tensors.
-// dtypes: shape {N}. The list of expected dtype for the tensors. Must match
-// those stored in the checkpoint.
-//
-// Returns shape {N}. The restored tensors, whose shapes are read from the
-// checkpoint directly.
-func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtypes": dtypes}
- opspec := tf.OpSpec{
- Type: "RestoreV2",
- Input: []tf.Input{
- prefix, tensor_names, shape_and_slices,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil {
- scope.UpdateErr("RestoreV2", err)
- return
- }
- return tensors
-}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index bc6a88b478..524f128154 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -44,6 +44,7 @@ py_library(
"//tensorflow/contrib/learn/python/learn/datasets:__pkg__", # TODO(b/34059704): remove when fixed
"//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed
"//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed
+ "//tensorflow/tools/api/generator:__pkg__",
"//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed
],
deps = [
@@ -3242,7 +3243,9 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_gpu",
+ "no_oss",
"no_pip_gpu",
+ "notap",
],
deps = [
":array_ops",
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index b97c5a00b6..05906a405a 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -800,6 +800,7 @@ cuda_py_test(
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow:tensorflow_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py
index cf84478b4d..22d1b4b543 100644
--- a/tensorflow/python/debug/cli/analyzer_cli.py
+++ b/tensorflow/python/debug/cli/analyzer_cli.py
@@ -1397,7 +1397,7 @@ class DebugAnalyzer(object):
for i in xrange(len(all_inputs)):
inp = all_inputs[i]
- op_type = self._debug_dump.node_op_type(inp)
+ op_type = self._debug_dump.node_op_type(debug_graphs.get_node_name(inp))
if op_type in self._GRAPH_STRUCT_OP_TYPE_BLACKLIST:
continue
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 916b698964..a0c6e45d34 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@@ -498,11 +499,74 @@ def check_menu_item(tst, out, line_index, expected_begin, expected_end,
tst.assertTrue(found_menu_item)
+def create_analyzer_cli(dump):
+ """Create an analyzer CLI.
+
+ Args:
+ dump: A `DebugDumpDir` object to base the analyzer CLI on.
+
+ Returns:
+ 1) A `DebugAnalyzer` object created based on `dump`.
+ 2) A `CommandHandlerRegistry` that is based on the `DebugAnalyzer` object
+ and has the common tfdbg commands, e.g., lt, ni, li, lo, registered.
+ """
+ # Construct the analyzer.
+ analyzer = analyzer_cli.DebugAnalyzer(dump)
+
+ # Construct the handler registry.
+ registry = debugger_cli_common.CommandHandlerRegistry()
+
+ # Register command handlers.
+ registry.register_command_handler(
+ "list_tensors",
+ analyzer.list_tensors,
+ analyzer.get_help("list_tensors"),
+ prefix_aliases=["lt"])
+ registry.register_command_handler(
+ "node_info",
+ analyzer.node_info,
+ analyzer.get_help("node_info"),
+ prefix_aliases=["ni"])
+ registry.register_command_handler(
+ "list_inputs",
+ analyzer.list_inputs,
+ analyzer.get_help("list_inputs"),
+ prefix_aliases=["li"])
+ registry.register_command_handler(
+ "list_outputs",
+ analyzer.list_outputs,
+ analyzer.get_help("list_outputs"),
+ prefix_aliases=["lo"])
+ registry.register_command_handler(
+ "print_tensor",
+ analyzer.print_tensor,
+ analyzer.get_help("print_tensor"),
+ prefix_aliases=["pt"])
+ registry.register_command_handler(
+ "print_source",
+ analyzer.print_source,
+ analyzer.get_help("print_source"),
+ prefix_aliases=["ps"])
+ registry.register_command_handler(
+ "list_source",
+ analyzer.list_source,
+ analyzer.get_help("list_source"),
+ prefix_aliases=["ls"])
+ registry.register_command_handler(
+ "eval",
+ analyzer.evaluate_expression,
+ analyzer.get_help("eval"),
+ prefix_aliases=["ev"])
+
+ return analyzer, registry
+
+
class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
@classmethod
def setUpClass(cls):
cls._dump_root = tempfile.mkdtemp()
+ cls._dump_root_for_unique = tempfile.mkdtemp()
cls._is_gpu_available = test.is_gpu_available()
if cls._is_gpu_available:
@@ -536,8 +600,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
x = math_ops.add(w, w, name="simple_mul_add/add")
cls._x_line_number = line_number_above()
+ a = variables.Variable([1, 3, 3, 7], name="a")
+
u.initializer.run()
v.initializer.run()
+ a.initializer.run()
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(
@@ -548,53 +615,16 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Invoke Session.run().
run_metadata = config_pb2.RunMetadata()
- sess.run(x, options=run_options, run_metadata=run_metadata)
-
- cls._debug_dump = debug_data.DebugDumpDir(
- cls._dump_root, partition_graphs=run_metadata.partition_graphs)
-
- # Construct the analyzer.
- cls._analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
-
- # Construct the handler registry.
- cls._registry = debugger_cli_common.CommandHandlerRegistry()
-
- # Register command handlers.
- cls._registry.register_command_handler(
- "list_tensors",
- cls._analyzer.list_tensors,
- cls._analyzer.get_help("list_tensors"),
- prefix_aliases=["lt"])
- cls._registry.register_command_handler(
- "node_info",
- cls._analyzer.node_info,
- cls._analyzer.get_help("node_info"),
- prefix_aliases=["ni"])
- cls._registry.register_command_handler(
- "print_tensor",
- cls._analyzer.print_tensor,
- cls._analyzer.get_help("print_tensor"),
- prefix_aliases=["pt"])
- cls._registry.register_command_handler(
- "print_source",
- cls._analyzer.print_source,
- cls._analyzer.get_help("print_source"),
- prefix_aliases=["ps"])
- cls._registry.register_command_handler(
- "list_source",
- cls._analyzer.list_source,
- cls._analyzer.get_help("list_source"),
- prefix_aliases=["ls"])
- cls._registry.register_command_handler(
- "eval",
- cls._analyzer.evaluate_expression,
- cls._analyzer.get_help("eval"),
- prefix_aliases=["ev"])
+ sess.run([x], options=run_options, run_metadata=run_metadata)
+ cls._debug_dump = debug_data.DebugDumpDir(
+ cls._dump_root, partition_graphs=run_metadata.partition_graphs)
+ cls._analyzer, cls._registry = create_analyzer_cli(cls._debug_dump)
@classmethod
def tearDownClass(cls):
# Tear down temporary dump directory.
shutil.rmtree(cls._dump_root)
+ shutil.rmtree(cls._dump_root_for_unique)
def testMeasureTensorListColumnWidthsGivesRightAnswerForEmptyData(self):
timestamp_col_width, dump_size_col_width, op_type_col_width = (
@@ -1461,6 +1491,37 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
self.assertEqual("ps uncompiled.py -b 6",
out.font_attr_segs[6][0][2][1].content)
+ def testListInputInvolvingNodesWithMultipleOutputs(self):
+ """List an input tree containing tensors from non-:0 output slot."""
+
+ with session.Session(config=no_rewrite_session_config()) as sess:
+ x = variables.Variable([1, 3, 3, 7], name="x")
+ _, idx = array_ops.unique(x, name="x_unique")
+ idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
+ sess.run(x.initializer)
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.watch_graph(
+ run_options,
+ sess.graph,
+ debug_ops=["DebugIdentity"],
+ debug_urls="file://%s" % self._dump_root_for_unique)
+ run_metadata = config_pb2.RunMetadata()
+ self.assertAllEqual(
+ [0, 2, 2, 4],
+ sess.run(idx_times_two,
+ options=run_options,
+ run_metadata=run_metadata))
+ debug_dump = debug_data.DebugDumpDir(
+ self._dump_root_for_unique,
+ partition_graphs=run_metadata.partition_graphs)
+ _, registry = create_analyzer_cli(debug_dump)
+
+ out = registry.dispatch_command("li", ["idx_times_two"])
+ self.assertEqual(
+ ["Inputs to node \"idx_times_two\" (Depth limit = 1):",
+ "|- (1) x_unique:1"], out.lines[:2])
+
class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase):
@@ -1486,18 +1547,8 @@ class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase):
cls._debug_dump = debug_data.DebugDumpDir(
cls._dump_root, partition_graphs=run_metadata.partition_graphs)
- # Construct the analyzer.
- cls._analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
-
- # Construct the handler registry.
- cls._registry = debugger_cli_common.CommandHandlerRegistry()
-
- # Register command handler.
- cls._registry.register_command_handler(
- "print_tensor",
- cls._analyzer.print_tensor,
- cls._analyzer.get_help("print_tensor"),
- prefix_aliases=["pt"])
+ # Construct the analyzer and command registry.
+ cls._analyzer, cls._registry = create_analyzer_cli(cls._debug_dump)
@classmethod
def tearDownClass(cls):
@@ -1575,28 +1626,8 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
debug_dump = debug_data.DebugDumpDir(
cls._dump_root, partition_graphs=run_metadata.partition_graphs)
- # Construct the analyzer.
- analyzer = analyzer_cli.DebugAnalyzer(debug_dump)
-
- # Construct the handler registry.
- cls._registry = debugger_cli_common.CommandHandlerRegistry()
-
- # Register command handlers.
- cls._registry.register_command_handler(
- "node_info",
- analyzer.node_info,
- analyzer.get_help("node_info"),
- prefix_aliases=["ni"])
- cls._registry.register_command_handler(
- "list_inputs",
- analyzer.list_inputs,
- analyzer.get_help("list_inputs"),
- prefix_aliases=["li"])
- cls._registry.register_command_handler(
- "list_outputs",
- analyzer.list_outputs,
- analyzer.get_help("list_outputs"),
- prefix_aliases=["lo"])
+ # Construct the analyzer and command handler registry.
+ _, cls._registry = create_analyzer_cli(debug_dump)
@classmethod
def tearDownClass(cls):
@@ -1911,18 +1942,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
cls._debug_dump = debug_data.DebugDumpDir(
cls._dump_root, partition_graphs=run_metadata.partition_graphs)
- cls._analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
- cls._registry = debugger_cli_common.CommandHandlerRegistry()
- cls._registry.register_command_handler(
- "list_tensors",
- cls._analyzer.list_tensors,
- cls._analyzer.get_help("list_tensors"),
- prefix_aliases=["lt"])
- cls._registry.register_command_handler(
- "print_tensor",
- cls._analyzer.print_tensor,
- cls._analyzer.get_help("print_tensor"),
- prefix_aliases=["pt"])
+ cls._analyzer, cls._registry = create_analyzer_cli(cls._debug_dump)
@classmethod
def tearDownClass(cls):
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 62a0c19daf..92af5f3edf 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -97,8 +97,7 @@ def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
continue
op_trace = op_to_entry[op]
o_to_e[op] = op_trace
- for i in op_trace.inputs:
- it = ops.tensor_id(i)
+ for it in op_trace.input_ids:
if it in tensor_usage_counts:
tensor_usage_counts[it] += 1
else:
@@ -220,7 +219,7 @@ def imperative_grad(
if not tape._tape_stack.stack: # pylint: disable=protected-access
raise RuntimeError("Computing a gradient with no tape present")
bp_tape = tape.pop_tape()
- tensor_to_op, op_to_entry, output_to_shape_dtype = bp_tape.export()
+ tensor_to_op, op_to_entry = bp_tape.export()
# This overwrites the op_to_entry variable, which will release all memory used
# to keep traces that are irrelevant to the gradient computation we're doing
# here.
@@ -238,8 +237,7 @@ def imperative_grad(
for i in range(len(out_gradients)):
if out_gradients[i] is None:
# TODO(apassos) this should be in the right device
- out_gradients[i] = array_ops.zeros(
- *output_to_shape_dtype[op_trace.output_ids[i]])
+ out_gradients[i] = array_ops.zeros(*op_trace.output_shape_and_dtype[i])
else:
out_gradients[i] = _aggregate_grads(out_gradients[i])
@@ -250,14 +248,15 @@ def imperative_grad(
ops.IndexedSlices,
type(None)))
else in_gradients)
- for i, t in enumerate(op_trace.inputs):
+ for i, t in enumerate(op_trace.input_ids):
if in_gradients[i] is not None:
- gradients[ops.tensor_id(t)].append(in_gradients[i])
- if tensor_usage_counts.get(ops.tensor_id(t), 0) > 0:
- tensor_usage_counts[ops.tensor_id(t)] -= 1
- if ops.tensor_id(t) in tensor_to_op and tensor_usage_counts[
- ops.tensor_id(t)] == 0 and ops.tensor_id(t) not in id_sources:
- in_op = tensor_to_op[ops.tensor_id(t)]
+ gradients[t].append(in_gradients[i])
+ if tensor_usage_counts.get(t, 0) > 0:
+ tensor_usage_counts[t] -= 1
+ if (t in tensor_to_op
+ and tensor_usage_counts[t] == 0
+ and t not in id_sources):
+ in_op = tensor_to_op[t]
if in_op is None:
continue
if op_missing_tensor.get(in_op, 0) > 0:
@@ -350,6 +349,109 @@ _gradient_functions_lock = threading.Lock()
_tracing = False
+# TODO(apassos) replace this with a mechanism which can happen at the op
+# gradient function registration site, to be less error-prone
+# TODO(apassos) add ops other than those in nn_grad and math_grad
+_ops_which_dont_need_outputs = set([
+ "MatMul",
+ "Conv2DBackpropInput",
+ "Conv2DBackpropFilter",
+ "Conv3D",
+ "Conv3DBackpropInputV2",
+ "AvgPool3D",
+ "AvgPool3DGrad",
+ "MaxPool3D",
+ "MaxPool3DGrad",
+ "MaxPool3DGradGrad",
+ "BiasAdd",
+ "BiasAddV1",
+ "BiasAddGrad",
+ "Relu6",
+ "Softplus",
+ "SoftplusGrad",
+ "Softsign",
+ "ReluGrad",
+ "Conv2D",
+ "DepthwiseConv2dNative",
+ "Dilation2D",
+ "AvgPool",
+ "AvgPoolGrad",
+ "BatchNormWithGlobalNormalization",
+ "L2Loss",
+ "Sum",
+ "Prod",
+ "SegmentSum",
+ "SegmentMean",
+ "SparseSegmentSum",
+ "SparseSegmentMean",
+ "SparseSegmentSqrtN",
+ "SegmentMin",
+ "SegmentMax",
+ "UnsortedSegmentSum",
+ "UnsortedSegmentMax",
+ "Abs",
+ "Neg",
+ "ReciprocalGrad",
+ "Square",
+ "Expm1",
+ "Log",
+ "Log1p",
+ "TanhGrad",
+ "SigmoidGrad",
+ "Sign",
+ "Sin",
+ "Cos",
+ "Tan",
+ "Add",
+ "Sub",
+ "Mul",
+ "Div",
+ "RealDiv",
+ "Pow",
+ "Maximum",
+ "Minimum",
+ "SquaredDifference",
+ "Select",
+ "SparseMatMul",
+ "BatchMatMul",
+ "Complex",
+ "Real",
+ "Imag",
+ "Angle",
+ "Conj",
+ "Cast",
+ "Cross",
+ "Cumsum",
+ "Cumprod",
+ "ReadVariableOp",
+ "VarHandleOp",
+ "Shape",
+])
+
+_ops_which_dont_need_inputs = set([
+ "Softmax",
+ "LogSoftmax",
+ "BiasAdd",
+ "Relu",
+ "Elu",
+ "Selu",
+ "SparseSoftmaxCrossEntropyWithLogits",
+ "Neg",
+ "Inv",
+ "Reciprocal",
+ "Sqrt",
+ "Exp",
+ "Tanh",
+ "Sigmoid",
+ "Real",
+ "Imag",
+ "Conj",
+ "ReadVariableOp",
+ "VarHandleOp",
+ "Shape",
+])
+
+
def _record_gradient(op_name, inputs, attrs, results, name):
"""Records gradients for a TensorFlow operation.
@@ -368,13 +470,32 @@ def _record_gradient(op_name, inputs, attrs, results, name):
Raises:
An exception on error.
"""
+ if not tape.could_possibly_record():
+ return
+
+ if op_name in _ops_which_dont_need_outputs:
+ op_outputs = None
+ else:
+ # TODO(apassos) this line creates a weak circular reference where the
+ # backprop function keeps an output alive which in turn keeps the tape entry
+ # alive which keeps the backprop function alive. Figure out how to break
+ # this up without breaking second derivatives of ops like Exp whose
+ # gradients depend only on the outputs.
+ op_outputs = results
+
+ if op_name in _ops_which_dont_need_inputs:
+ op_inputs = None
+ else:
+ op_inputs = inputs
+
+ num_inputs = len(inputs)
def grad_fn(*orig_outputs):
"""Generated gradient function."""
- result = _magic_gradient_function(op_name, attrs, len(inputs),
- inputs, results, orig_outputs)
+ result = _magic_gradient_function(op_name, attrs, num_inputs,
+ op_inputs, op_outputs, orig_outputs)
if _tracing:
- print("Gradient for", (name if name else op_name), "inputs", inputs,
+ print("Gradient for", (name if name else op_name), "inputs", op_inputs,
"output_grads", orig_outputs, "gradients", result)
return result
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 73bfbc8031..38555ae2fa 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.platform import app
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
GRAPH_MODE = 0
EAGER_MODE = 1
@@ -406,10 +407,19 @@ def run(main=None, argv=None):
def enable_eager_execution():
"""Enables, for the rest of the lifetime of this program, eager execution.
- If not called immediately on startup risks creating breakage and bugs.
+ If not called immediately on startup risks creating breakage and bugs. Calling
+ this method more than once in the same process will lead to an exception.
+
+ Raises:
+ ValueError: If this method has already been invoked in the current process.
"""
global _default_mode
- assert _default_mode == GRAPH_MODE
+ if _default_mode == EAGER_MODE:
+ func_name = (
+ "tfe." + tf_inspect.getframeinfo(tf_inspect.currentframe()).function)
+ raise ValueError(
+ "Do not call %s more than once in the same process. Note eager-mode "
+ "methods such as tfe.run() also call %s." % (func_name, func_name))
_default_mode = EAGER_MODE
diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py
index 64c615fb63..b6e7d53ee8 100644
--- a/tensorflow/python/eager/core.py
+++ b/tensorflow/python/eager/core.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.eager import context
from tensorflow.python.eager import memory_trace
from tensorflow.python.framework import errors
@@ -57,8 +56,7 @@ def enable_tracing():
WARNING: tracing is not thread-safe.
"""
global _active_trace
- _active_trace = memory_trace.MemoryTrace(
- len(context.get_default_context().devices()))
+ _active_trace = memory_trace.MemoryTrace()
def flush_trace():
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
index 52f3f07708..dbd5a50924 100644
--- a/tensorflow/python/eager/custom_gradient.py
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -43,7 +43,9 @@ def custom_gradient(f):
"""Decorated function with custom gradient."""
input_tensors = [x for x in args
if isinstance(x, tf_ops.Tensor)]
- result, grad_fn = f(*args, **kwargs)
+
+ with tape.stop_recording():
+ result, grad_fn = f(*args, **kwargs)
# TODO(apassos): naive uses of custom_gradient will not get the correct
# second derivative this way if they capture any output tensors. Change the
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 07d37cc500..50a23f585b 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -73,12 +73,11 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
tensors = [ops._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access
# TODO(alive, cais): Use the execution callback mechanism.
if core.active_trace() is not None:
- trace_name = name if name else op_name
for t in tensors:
# pylint: disable=protected-access
- core.active_trace().record_tensor(trace_name,
+ core.active_trace().record_tensor(op_name,
ops.tensor_id(t),
- t._device_name(),
+ t.device,
t.shape.num_elements())
# pylint: enable=protected-access
diff --git a/tensorflow/python/eager/memory_trace.py b/tensorflow/python/eager/memory_trace.py
index 0baf922408..094bcab9e2 100644
--- a/tensorflow/python/eager/memory_trace.py
+++ b/tensorflow/python/eager/memory_trace.py
@@ -29,29 +29,30 @@ TensorData = collections.namedtuple(
class MemoryTrace(object):
"""Records a trace of memory usage over operation execution."""
- def __init__(self, n_devices):
+ def __init__(self):
self.trace = []
self.tensor_to_data = {}
- self.current_device_mem_usage = [0] * n_devices
+ self.current_device_mem_usage = collections.defaultdict(int)
def record_tensor(self, op_name, tensor_id, device, size):
self.current_device_mem_usage[device] += size
self.tensor_to_data[tensor_id] = TensorData(op_name, size, device)
self.trace.append(TraceEntry(op_name,
tensor_id,
- self.current_device_mem_usage[:],
+ dict(self.current_device_mem_usage.items()),
device,
size))
def delete_tensor(self, tensor_id):
if tensor_id not in self.tensor_to_data:
return
- data = self.tensor_to_data.pop(tensor_id)
+ data = self.tensor_to_data.pop(tensor_id, None)
+ if data is None: return
self.current_device_mem_usage[data.device] -= data.tensor_size
self.trace.append(TraceEntry(data.op_name,
tensor_id,
- self.current_device_mem_usage[:],
+ dict(self.current_device_mem_usage.items()),
data.device,
-data.tensor_size))
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index e33c52a1b2..7ba7d0e7ec 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import threading
from tensorflow.python.util import tf_contextlib
@@ -30,7 +31,8 @@ def tid(tensor):
class TapeEntry(
collections.namedtuple("TapeEntry", [
- "output_ids", "inputs", "side_outputs", "backward_function"
+ "output_ids", "input_ids", "side_outputs", "backward_function",
+ "output_shape_and_dtype",
])):
"""Entry in the gradient tape.
@@ -39,11 +41,13 @@ class TapeEntry(
Args:
output_ids: tensor_id(t) for each output tensor T
- inputs: input tensors
- side_outputs: optional tensors which need to be provided to the backward
- function.
+ input_ids: tensor_id(t) for each input tensor T
+ side_outputs: optional tensors (not IDs) which need to be provided to the
+ backward function.
backward_function: function to be called with the downstream gradients and
side outputs as arguments which computes the backward pass.
+ output_shape_and_dtype: a list of (shape_tuple, dtype) for every output
+ tensor_id
"""
@@ -57,8 +61,9 @@ class Tape(object):
def __init__(self):
# _tensor_tape maps from tensor IDs to their operation IDs
self._tensor_tape = {}
- # maps output tensor IDs to their shapes and dtypes
- self._shape_dtype = {}
+ # maps from tensor ID to usage count. Triggers garbage collection when this
+ # goes to zero.
+ self._tensor_usage = {}
# maps from operation ID to TapeEntry
self._op_tape = {}
# next operation ID
@@ -81,8 +86,10 @@ class Tape(object):
def watch(self, tensor):
"""Adds a tensor to the tape."""
- if tid(tensor) not in self._tensor_tape:
- self._tensor_tape[tid(tensor)] = None
+ i = tid(tensor)
+ if i not in self._tensor_tape:
+ self._tensor_tape[i] = None
+ self._tensor_usage[i] = 1
self._watched.append(tensor)
def watch_variable(self, v):
@@ -95,25 +102,40 @@ class Tape(object):
if not self.should_record(input_tensors):
return output_tensors
for t in output_tensors:
- self._tensor_tape[tid(t)] = self._next_op_id
- self._shape_dtype[tid(t)] = (_tensor_shape(t), t.dtype)
-
+ i = tid(t)
+ self._tensor_tape[i] = self._next_op_id
+ self._tensor_usage[i] = 1
+ for t in input_tensors:
+ i = tid(t)
+ self._tensor_usage[i] = self._tensor_usage.get(i, 0) + 1
self._op_tape[self._next_op_id] = TapeEntry(
[tid(t) for t in output_tensors],
- input_tensors,
+ [tid(t) for t in input_tensors],
side_outputs,
- backward_function)
+ backward_function,
+ [(_tensor_shape(t), t.dtype) for t in output_tensors])
self._next_op_id += 1
+ def _delete_tensor_id(self, i):
+ if i in self._tensor_usage:
+ self._tensor_usage[i] -= 1
+ if self._tensor_usage[i] == 0:
+ del self._tensor_usage[i]
+ op_id = self._tensor_tape.pop(i, None)
+ if op_id is None:
+ return
+ op = self._op_tape[op_id]
+ if not any(tensor_id in self._tensor_usage
+ for tensor_id in op.output_ids):
+ del self._op_tape[op_id]
+ for tensor_id in op.input_ids:
+ # TODO(apassos) this recursion might come to bite us. Consider
+ # adding an explicit stack if this ever gets out of hand
+ self._delete_tensor_id(tensor_id)
+
def delete_trace(self, tensor):
"""Deletes any trace we have for this tensor."""
- if tid(tensor) in self._tensor_tape:
- op = self._tensor_tape[tid(tensor)]
- del self._tensor_tape[tid(tensor)]
- if op in self._op_tape:
- if not any(
- x in self._tensor_tape for x in self._op_tape[op].output_ids):
- del self._op_tape[op]
+ self._delete_tensor_id(tid(tensor))
def export(self):
"""Exports the internal state of this tape.
@@ -122,10 +144,8 @@ class Tape(object):
tensor_tape: a map from tensor_id(tensor) to <identifier for op>
responsible for generating that tensor.
op_tape: a map from <identifier for op> to TapeEntry for that op.
- output_to_shape_dtype: a map from tensor_id(tensor) to its shape and
- dtype, for tensors which are outputs
"""
- return self._tensor_tape, self._op_tape, self._shape_dtype
+ return self._tensor_tape, self._op_tape
class _TapeStack(threading.local):
@@ -188,6 +208,16 @@ def pop_tape():
return None
+@contextlib.contextmanager
+def stop_recording():
+ old = _tape_stack.stack
+ _tape_stack._stack = [] # pylint: disable=protected-access
+ try:
+ yield
+ finally:
+ _tape_stack._stack = old # pylint: disable=protected-access
+
+
def should_record(tensors):
"""Returns true if any tape in the stach watches any of these tensors."""
if not _tape_stack.stack:
@@ -219,3 +249,8 @@ def top_tape_watched_tensors():
def top_tape_watched_variables():
t = _tape_stack.stack[-1]
return t._watched_variables # pylint: disable=protected-access
+
+
+def could_possibly_record():
+ """Returns True if any tape is active."""
+ return len(_tape_stack.stack) > 0 # pylint: disable=g-explicit-length-test
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 5541087f5d..2df833175b 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -49,6 +50,16 @@ def two_outputs(a, b):
return [mm, r], grad
+@custom_gradient.custom_gradient
+def gradient_is_constant(x):
+ result = x * x
+
+ def grad(dr):
+ return [dr]
+
+ return result, grad
+
+
class TapeTest(test.TestCase):
def testMultiOutput(self):
@@ -155,6 +166,25 @@ class TapeTest(test.TestCase):
g, = backprop.gradients_function(fn, [0])(t)
self.assertEqual(g.numpy(), 1.0)
+ def testTapeGC(self):
+ # TODO(apassos) figure out how to test this without using tape internal
+ # APIs.
+ tape.push_new_tape()
+
+ def f():
+ x = constant_op.constant(1.0)
+ tape.watch(x)
+ x = gradient_is_constant(x)
+ x = gradient_is_constant(x)
+ x = gradient_is_constant(x)
+
+ f()
+ t = tape.pop_tape()
+ tensor_tape, op_tape = t.export()
+ self.assertEqual(len(tensor_tape), 1) # The watched tensor will remain on
+ # the tape
+ self.assertEqual(len(op_tape), 0) # No operations should remain on the tape
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 99fb1e0c90..60a015119f 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -259,6 +259,7 @@ class _DefinedFunction(object):
python_grad_func=None,
out_names=None,
shape_func=None,
+ capture_by_value=False,
**kwargs):
"""Creates _DefinedFunction.
@@ -277,6 +278,8 @@ class _DefinedFunction(object):
names.
shape_func: An optional function mapping an op to a list of static
output shapes.
+ capture_by_value: Boolean (defaults to False). If True, captured values
+ will be copied into the function body.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
@@ -291,6 +294,7 @@ class _DefinedFunction(object):
self._python_grad_func = python_grad_func
self._out_names = out_names
self._shape_func = shape_func
+ self._capture_by_value = capture_by_value
self._extra_kwargs = kwargs
self._definition = None # Constructed lazily.
self._c_func = None # Constructed with definition.
@@ -344,12 +348,16 @@ class _DefinedFunction(object):
def _create_definition_if_needed(self):
"""Creates the function definition if it's not created yet."""
+ with context.graph_mode():
+ self._create_definition_if_needed_impl()
+ def _create_definition_if_needed_impl(self):
+ """This is not what you want, see _create_definition_if_needed."""
if self._definition is not None:
return
# Create the func_def object.
- temp_graph = _FuncGraph()
+ temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
with temp_graph.as_default():
# List of placeholders for the function_def.
inputs = []
@@ -613,8 +621,9 @@ class _FuncGraph(ops.Graph):
function argument and the caller passes in the captured tensor.
"""
- def __init__(self, *args, **kwargs):
+ def __init__(self, capture_by_value, *args, **kwargs):
super(_FuncGraph, self).__init__(*args, **kwargs)
+ self._capture_by_value = capture_by_value
self._building_function = True
self._outer_graph = ops.get_default_graph()
self._vscope = vs.get_variable_scope()
@@ -672,6 +681,8 @@ class _FuncGraph(ops.Graph):
if x in self._captured:
# Captured already.
inputs[i] = self._captured[x]
+ elif self._capture_by_value:
+ inputs[i] = self._add_tensor_and_parents(x)
else:
# Substitute with a placeholder.
self.extra_inputs.append(x)
@@ -685,6 +696,35 @@ class _FuncGraph(ops.Graph):
return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
**kwargs)
+ def _add_tensor_and_parents(self, tensor):
+ op = self._add_op_and_parents(tensor.op)
+ return op.outputs[tensor.value_index]
+
+ def _add_op_and_parents(self, op):
+ # pylint: disable=protected-access
+ op_def = graph_to_function_def._get_op_def(op)
+ # pylint: enable=protected-access
+ if op_def.is_stateful:
+ raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
+ "by value." % (op.name, op.type))
+ elif op.type in ("Placeholder", "PlaceholderV2"):
+ raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
+ "by value." % (op.name, op.type))
+
+ captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
+
+ captured_op = self.create_op(
+ op.type,
+ captured_inputs, [o.dtype for o in op.outputs],
+ name=op.name,
+ attrs=op.node_def.attr,
+ op_def=op_def)
+
+ for t, captured_t in zip(op.outputs, captured_op.outputs):
+ self._captured[t] = captured_t
+
+ return captured_op
+
def _call(sig, *inputs, **kwargs):
"""Adds a node calling a function.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 40205ddf05..73ccac31e2 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -128,7 +128,7 @@ class FunctionTestMethods(object):
@function.Defun(dtypes.float32, dtypes.float32)
def APlus2B(a, b):
print(a + b * 2) # Create some ops to have nodes in the body
- # Using 'print' to make lint happy
+ # Using 'print' to make lint happy
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError,
@@ -1213,6 +1213,35 @@ class FunctionOverloadTest(test.TestCase):
"Successor of x.")
+class FunctionCaptureByValueTest(test.TestCase):
+
+ def testCaptureByValue(self):
+ g = ops.Graph()
+ with g.as_default():
+ w = constant_op.constant([[1.0]])
+ b = constant_op.constant([2.0])
+
+ # Foo() captures w and b.
+ @function.Defun(dtypes.float32, capture_by_value=True)
+ def Foo(x):
+
+ # Plus() captures b.
+ @function.Defun(dtypes.float32, capture_by_value=True)
+ def Plus(y):
+ return y + b
+
+ self.assertEqual(0, len(Plus.captured_inputs))
+
+ return Plus(math_ops.matmul(w, x))
+
+ y = Foo(constant_op.constant([[10.]]))
+
+ self.assertEqual(0, len(Foo.captured_inputs))
+
+ with self.test_session(graph=g):
+ self.assertAllEqual(y.eval(), [[12.0]])
+
+
class UnrollLSTMTest(test.TestCase):
BATCH_SIZE = 16
LSTM_DIMS = 32
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 3388f882ac..ad9368b599 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -472,7 +472,7 @@ class TensorFlowTestCase(googletest.TestCase):
def tearDown(self):
for thread in self._threads:
- self.assertFalse(thread.is_alive(), "A checkedThread did not terminate")
+ thread.check_termination()
self._ClearCachedSession()
@@ -726,6 +726,8 @@ class TensorFlowTestCase(googletest.TestCase):
self._thread = threading.Thread(target=self._protected_run)
self._exception = None
+ self._is_thread_joined = False
+
def _protected_run(self):
"""Target for the wrapper thread. Sets self._exception on failure."""
try:
@@ -748,6 +750,7 @@ class TensorFlowTestCase(googletest.TestCase):
self._testcase.failureException: If the thread terminates with due to
an exception.
"""
+ self._is_thread_joined = True
self._thread.join()
if self._exception is not None:
self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
@@ -763,6 +766,28 @@ class TensorFlowTestCase(googletest.TestCase):
"""
return self._thread.is_alive()
+ def check_termination(self):
+ """Returns whether the checked thread was properly used and did terminate.
+
+ Every checked thread should be "join"ed after starting, and before the
+ test tears down. If it is not joined, it is possible the thread will hang
+ and cause flaky failures in tests.
+
+ Raises:
+ self._testcase.failureException: If check_termination was called before
+ thread was joined.
+
+ RuntimeError: If the thread is not terminated. This means thread was not
+ joined with the main thread.
+ """
+ if self._is_thread_joined:
+ if self.is_alive():
+ raise RuntimeError(
+ "Thread was not joined with main thread, and is still running "
+ "when the test finished.")
+ else:
+ self._testcase.fail("A checked thread was not joined.")
+
def checkedThread(self, target, args=None, kwargs=None):
"""Returns a Thread wrapper that asserts 'target' completes successfully.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 15e02ab0f4..1c6b2a87c3 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2301,7 +2301,7 @@ cuda_py_test(
cuda_py_test(
name = "fft_ops_test",
- size = "medium",
+ size = "large",
srcs = ["fft_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index e641d5511f..317b8dc05b 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -19,13 +19,24 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
-from tensorflow.python.platform import test
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test as test_lib
-class MatrixBandPartTest(test.TestCase):
+def _AddTest(test, op_name, testcase_name, fn):
+ test_name = "_".join(["test", op_name, testcase_name])
+ if hasattr(test, test_name):
+ raise RuntimeError("Test %s defined more than once" % test_name)
+ setattr(test, test_name, fn)
+
+
+class MatrixBandPartTest(test_lib.TestCase):
pass # Filled in below
@@ -34,23 +45,23 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
def Test(self):
mat = np.ones(shape_).astype(dtype_)
batch_mat = np.tile(mat, batch_shape_ + (1, 1))
- with self.test_session(use_gpu=True):
- for lower in -1, 0, 1, shape_[-2] - 1:
- for upper in -1, 0, 1, shape_[-1] - 1:
- band_np = mat
- if lower >= 0:
- band_np = np.triu(band_np, -lower)
- if upper >= 0:
- band_np = np.tril(band_np, upper)
- if batch_shape_ is not ():
- band_np = np.tile(band_np, batch_shape + (1, 1))
+ for lower in -1, 0, 1, shape_[-2] - 1:
+ for upper in -1, 0, 1, shape_[-1] - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+ if batch_shape_ is not ():
+ band_np = np.tile(band_np, batch_shape_ + (1, 1))
+ with self.test_session(use_gpu=False):
band = array_ops.matrix_band_part(batch_mat, lower, upper)
self.assertAllEqual(band_np, band.eval())
return Test
-class MatrixBandPartGradTest(test.TestCase):
+class MatrixBandPartGradTest(test_lib.TestCase):
pass # Filled in below
@@ -59,7 +70,7 @@ def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_):
def Test(self):
shape = batch_shape_ + shape_
x = constant_op.constant(np.random.rand(*shape), dtype=dtype_)
- with self.test_session(use_gpu=True):
+ with self.test_session(use_gpu=False):
for lower in -1, 0, 1, shape_[-2] - 1:
for upper in -1, 0, 1, shape_[-1] - 1:
y = array_ops.matrix_band_part(x, lower, upper)
@@ -70,18 +81,77 @@ def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_):
return Test
-if __name__ == '__main__':
- for dtype in (
- np.int32, np.int64, np.float32, np.float64, np.complex64, np.complex128):
+class MatrixBandPartBenchmark(test_lib.Benchmark):
+
+ shapes = [
+ (10, 16, 16),
+ (10, 101, 101),
+ (10, 256, 256),
+ (10, 1000, 1000),
+ (10, 1024, 1024),
+ (10, 2048, 2048),
+ (10, 10, 4, 4),
+ (10, 10, 10, 10),
+ (10, 10, 16, 16),
+ (10, 10, 101, 101),
+ (10, 10, 256, 256),
+ (10, 10, 1000, 1000),
+ (10, 10, 1024, 1024),
+ (10, 10, 2048, 2048),
+ ]
+
+ def benchmarkMatrixBandPartOp(self):
+ for shape_ in self.shapes:
+ for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/cpu:0"):
+ matrix = variables.Variable(array_ops.ones(shape_))
+ band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
+ variables.global_variables_initializer().run()
+ self.run_op_benchmark(
+ sess,
+ control_flow_ops.group(band),
+ min_iters=10,
+ name="matrix_band_part_cpu_{shape}_{limits}".format(
+ shape=shape_, limits=limits))
+
+ if test_lib.is_gpu_available(True):
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/gpu:0"):
+ matrix = variables.Variable(array_ops.ones(shape_))
+ band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
+ variables.global_variables_initializer().run()
+ self.run_op_benchmark(
+ sess,
+ control_flow_ops.group(band),
+ min_iters=10,
+ name="matrix_band_part_gpu_{shape}_{limits}".format(
+ shape=shape_, limits=limits))
+
+
+if __name__ == "__main__":
+ dtypes = (np.bool, np.int32, np.int64, np.float32, np.float64, np.complex64,
+ np.complex128)
+ for dtype in dtypes:
for batch_shape in ((), (2,), (1, 3, 2)):
for rows in 1, 2, 7:
for cols in 1, 2, 7:
shape = (rows, cols)
- name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
- setattr(MatrixBandPartTest, 'testMatrixBandPart_' + name,
- _GetMatrixBandPartTest(dtype, batch_shape, shape))
- if dtype == np.float32 or dtype == np.float64:
- setattr(MatrixBandPartGradTest, 'testMatrixBandPartGrad_' + name,
- _GetMatrixBandPartGradTest(dtype, batch_shape, shape))
-
- test.main()
+ name = "%s_%s" % (dtype.__name__,
+ "_".join(map(str, batch_shape + shape)))
+ _AddTest(MatrixBandPartTest, "MatrixBandPart", name,
+ _GetMatrixBandPartTest(dtype, batch_shape, shape))
+
+ for dtype in (np.float32, np.float64):
+ for batch_shape in ((), (2,)):
+ for rows in 1, 2, 7:
+ for cols in 1, 2, 7:
+ shape = (rows, cols)
+ name = "%s_%s" % (dtype.__name__,
+ "_".join(map(str, batch_shape + shape)))
+ _AddTest(MatrixBandPartGradTest, "MatrixBandPartGrad", name,
+ _GetMatrixBandPartGradTest(dtype, batch_shape, shape))
+
+ test_lib.main()
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
index 08650aaf36..32a623e74a 100644
--- a/tensorflow/python/kernel_tests/svd_op_test.py
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -48,7 +48,7 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_, use_gpu_):
# The gpu version returns results that are much less precise
precision_factor = 100 if use_gpu_ else 1
- tol = precision_factor * (1e-4 if is_single else 1e-12)
+ tol = precision_factor * (3e-4 if is_single else 1e-12)
def CompareSingularValues(self, x, y):
self.assertAllClose(x, y, atol=(x[0] + y[0]) * tol)
@@ -68,7 +68,7 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_, use_gpu_):
sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
x *= phases
- self.assertAllClose(x, y, atol=tol)
+ self.assertAllClose(x, y, atol=2 * tol)
def CheckApproximation(self, a, u, s, v, full_matrices):
# Tests that a ~= u*diag(s)*transpose(v).
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 53b8996c0c..334488b2a9 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -806,8 +806,10 @@ def fused_batch_norm(
mean = constant_op.constant([])
if variance is None:
variance = constant_op.constant([])
- # Add 1e-12 to epsilon when epsilon <= 1e-5 to prevent CUDNN exception.
- epsilon = epsilon if epsilon > 1e-5 else epsilon + 1e-12
+ # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
+ # prevent exception (see cudnn.h).
+ min_epsilon = 1.001e-5
+ epsilon = epsilon if epsilon > min_epsilon else min_epsilon
# pylint: disable=protected-access
y, batch_mean, batch_var, _, _ = gen_nn_ops._fused_batch_norm(
x,
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
index be29a0f6b1..a7c648ce2f 100644
--- a/tensorflow/python/saved_model/signature_def_utils.py
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""SignatureDef utility functions.
-Utility functions for constructing SignatureDef protos.
+Utility functions for building and inspecting SignatureDef protos.
"""
from __future__ import absolute_import
from __future__ import division
@@ -26,13 +26,7 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- "build_signature_def",
- "classification_signature_def",
- "predict_signature_def",
- "regression_signature_def",
-]
-remove_undocumented(__name__, _allowed_symbols)
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index 0559fb415e..7a3fb16825 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
@@ -146,3 +148,81 @@ def predict_signature_def(inputs, outputs):
signature_constants.PREDICT_METHOD_NAME)
return signature_def
+
+
+def _get_shapes_from_tensor_info_dict(tensor_info_dict):
+ """Returns a map of keys to TensorShape objects.
+
+ Args:
+ tensor_info_dict: map with TensorInfo proto as values.
+
+ Returns:
+ Map with corresponding TensorShape objects as values.
+ """
+ return {
+ key: tensor_shape.TensorShape(tensor_info.tensor_shape)
+ for key, tensor_info in tensor_info_dict.items()
+ }
+
+
+def _get_types_from_tensor_info_dict(tensor_info_dict):
+ """Returns a map of keys to DType objects.
+
+ Args:
+ tensor_info_dict: map with TensorInfo proto as values.
+
+ Returns:
+ Map with corresponding DType objects as values.
+ """
+ return {
+ key: dtypes.DType(tensor_info.dtype)
+ for key, tensor_info in tensor_info_dict.items()
+ }
+
+
+def get_signature_def_input_shapes(signature):
+ """Returns map of parameter names to their shapes.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to TensorShape objects.
+ """
+ return _get_shapes_from_tensor_info_dict(signature.inputs)
+
+
+def get_signature_def_input_types(signature):
+ """Returns map of output names to their types.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to DType objects.
+ """
+ return _get_types_from_tensor_info_dict(signature.inputs)
+
+
+def get_signature_def_output_shapes(signature):
+ """Returns map of output names to their shapes.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to TensorShape objects.
+ """
+ return _get_shapes_from_tensor_info_dict(signature.outputs)
+
+
+def get_signature_def_output_types(signature):
+ """Returns map of output names to their types.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to DType objects.
+ """
+ return _get_types_from_tensor_info_dict(signature.outputs)
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index 5859496cf3..6627602849 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -24,10 +24,23 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import signature_def_utils_impl
from tensorflow.python.saved_model import utils
+def _make_signature(inputs, outputs, name=None):
+ input_info = {
+ input_name: utils.build_tensor_info(tensor)
+ for input_name, tensor in inputs.items()
+ }
+ output_info = {
+ output_name: utils.build_tensor_info(tensor)
+ for output_name, tensor in outputs.items()
+ }
+ return signature_def_utils_impl.build_signature_def(input_info, output_info,
+ name)
+
+
class SignatureDefUtilsTest(test.TestCase):
def testBuildSignatureDef(self):
@@ -41,8 +54,8 @@ class SignatureDefUtilsTest(test.TestCase):
outputs = dict()
outputs["foo-output"] = y_tensor_info
- signature_def = signature_def_utils.build_signature_def(inputs, outputs,
- "foo-method-name")
+ signature_def = signature_def_utils_impl.build_signature_def(
+ inputs, outputs, "foo-method-name")
self.assertEqual("foo-method-name", signature_def.method_name)
# Check inputs in signature def.
@@ -63,8 +76,8 @@ class SignatureDefUtilsTest(test.TestCase):
def testRegressionSignatureDef(self):
input1 = constant_op.constant("a", name="input-1")
output1 = constant_op.constant("b", name="output-1")
- signature_def = signature_def_utils.regression_signature_def(input1,
- output1)
+ signature_def = signature_def_utils_impl.regression_signature_def(
+ input1, output1)
self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
signature_def.method_name)
@@ -89,9 +102,8 @@ class SignatureDefUtilsTest(test.TestCase):
input1 = constant_op.constant("a", name="input-1")
output1 = constant_op.constant("b", name="output-1")
output2 = constant_op.constant("c", name="output-2")
- signature_def = signature_def_utils.classification_signature_def(input1,
- output1,
- output2)
+ signature_def = signature_def_utils_impl.classification_signature_def(
+ input1, output1, output2)
self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
signature_def.method_name)
@@ -122,7 +134,7 @@ class SignatureDefUtilsTest(test.TestCase):
input2 = constant_op.constant("b", name="input-2")
output1 = constant_op.constant("c", name="output-1")
output2 = constant_op.constant("d", name="output-2")
- signature_def = signature_def_utils.predict_signature_def({
+ signature_def = signature_def_utils_impl.predict_signature_def({
"input-1": input1,
"input-2": input2
}, {"output-1": output1,
@@ -153,6 +165,44 @@ class SignatureDefUtilsTest(test.TestCase):
self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
+ def testGetShapeAndTypes(self):
+ inputs = {
+ "input-1": constant_op.constant(["a", "b"]),
+ "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
+ }
+ outputs = {
+ "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
+ "output-2": constant_op.constant([["b"]]),
+ }
+ signature_def = _make_signature(inputs, outputs)
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_input_shapes(signature_def),
+ {"input-1": [2], "input-2": [10, 11]})
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_output_shapes(signature_def),
+ {"output-1": [10, 32], "output-2": [1, 1]})
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_input_types(signature_def),
+ {"input-1": dtypes.string, "input-2": dtypes.float32})
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_output_types(signature_def),
+ {"output-1": dtypes.float32, "output-2": dtypes.string})
+
+ def testGetNonFullySpecifiedShapes(self):
+ outputs = {
+ "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
+ "output-2": array_ops.sparse_placeholder(dtypes.float32),
+ }
+ signature_def = _make_signature({}, outputs)
+ shapes = signature_def_utils_impl.get_signature_def_output_shapes(
+ signature_def)
+ self.assertEqual(len(shapes), 2)
+ # Must compare shapes with as_list() since 2 equivalent non-fully defined
+ # shapes are not equal to each other.
+ self.assertEqual(shapes["output-1"].as_list(), [None, 10, None])
+ # Must compare `dims` since its an unknown shape.
+ self.assertEqual(shapes["output-2"].dims, None)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 977b0df08b..9ed125704b 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -76,6 +76,10 @@ def getcallargs(func, *positional, **named):
return call_args
+def getframeinfo(*args, **kwargs):
+ return _inspect.getframeinfo(*args, **kwargs)
+
+
def getdoc(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.getdoc.
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 5c156e7ee2..54649dab01 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -980,6 +980,7 @@ check_deps = rule(
def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[]):
cuda_deps = [
clean_dep("//tensorflow/core:stream_executor_headers_lib"),
+ "@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart_static",
]
deps = deps + tf_custom_op_library_additional_deps()
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
new file mode 100644
index 0000000000..3896a21b99
--- /dev/null
+++ b/tensorflow/tools/api/generator/BUILD
@@ -0,0 +1,53 @@
+# Description:
+# Scripts used to generate TensorFlow Python API.
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_binary(
+ name = "create_python_api",
+ srcs = ["create_python_api.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "create_python_api_test",
+ srcs = ["create_python_api_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":create_python_api",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+genrule(
+ name = "python_api_gen",
+ # List of API files. This list should include file name for
+ # every module exported using tf_export. For e.g. if an op is decorated with
+ # @tf_export('module1.module2', 'module3'). Then, outs should include
+ # api/module1/module2/__init__.py and api/module3/__init__.py.
+ outs = ["api/__init__.py"],
+ cmd = "$(location create_python_api) $(OUTS)",
+ tools = ["create_python_api"],
+)
+
+py_library(
+ name = "python_api",
+ srcs = [":python_api_gen"],
+ srcs_version = "PY2AND3",
+)
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
new file mode 100644
index 0000000000..5f1286aaf6
--- /dev/null
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -0,0 +1,179 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Generates and prints out imports and constants for new TensorFlow python api.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import collections
+import os
+import sys
+
+# This import is needed so that we can traverse over TensorFlow modules.
+import tensorflow as tf # pylint: disable=unused-import
+from tensorflow.python.util import tf_decorator
+
+
+_API_CONSTANTS_ATTR = '_tf_api_constants'
+_API_NAMES_ATTR = '_tf_api_names'
+_API_DIR = '/api/'
+_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
+
+This file is MACHINE GENERATED! Do not edit.
+Generated by: tensorflow/tools/api/generator/create_python_api.py script.
+\"\"\"
+"""
+
+
+def format_import(source_module_name, source_name, dest_name):
+ """Formats import statement.
+
+ Args:
+ source_module_name: (string) Source module to import from.
+ source_name: (string) Source symbol name to import.
+ dest_name: (string) Destination alias name.
+
+ Returns:
+ An import statement string.
+ """
+ if source_name == dest_name:
+ return 'from %s import %s' % (source_module_name, source_name)
+ else:
+ return 'from %s import %s as %s' % (
+ source_module_name, source_name, dest_name)
+
+
+def get_api_imports():
+ """Get a map from destination module to formatted imports.
+
+ Returns:
+ A dictionary where
+ key: (string) destination module (for e.g. tf or tf.consts).
+ value: List of strings representing module imports
+ (for e.g. 'from foo import bar') and constant
+ assignments (for e.g. 'FOO = 123').
+ """
+ module_imports = collections.defaultdict(list)
+ # Traverse over everything imported above. Specifically,
+ # we want to traverse over TensorFlow Python modules.
+ for module in sys.modules.values():
+ # Only look at tensorflow modules.
+ if not module or 'tensorflow.' not in module.__name__:
+ continue
+
+ for module_contents_name in dir(module):
+ attr = getattr(module, module_contents_name)
+
+ # If attr is _tf_api_constants attribute, then add the constants.
+ if module_contents_name == _API_CONSTANTS_ATTR:
+ for exports, value in attr:
+ for export in exports:
+ names = ['tf'] + export.split('.')
+ dest_module = '.'.join(names[:-1])
+ import_str = format_import(module.__name__, value, names[-1])
+ module_imports[dest_module].append(import_str)
+ continue
+
+ _, attr = tf_decorator.unwrap(attr)
+ # If attr is a symbol with _tf_api_names attribute, then
+ # add import for it.
+ if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
+ # The same op might be accessible from multiple modules.
+ # We only want to consider location where function was defined.
+ if attr.__module__ != module.__name__:
+ continue
+
+ for export in attr._tf_api_names: # pylint: disable=protected-access
+ names = ['tf'] + export.split('.')
+ dest_module = '.'.join(names[:-1])
+ import_str = format_import(
+ module.__name__, module_contents_name, names[-1])
+ module_imports[dest_module].append(import_str)
+
+ # Import all required modules in their parent modules.
+ # For e.g. if we import 'tf.foo.bar.Value'. Then, we also
+ # import 'bar' in 'tf.foo'.
+ for dest_module in module_imports.keys():
+ dest_module_split = dest_module.split('.')
+ for dest_submodule_index in range(1, len(dest_module_split)):
+ dest_submodule = '.'.join(dest_module_split[:dest_submodule_index])
+ submodule_import = format_import(
+ '', dest_module_split[dest_submodule_index],
+ dest_module_split[dest_submodule_index])
+ if submodule_import not in module_imports[dest_submodule]:
+ module_imports[dest_submodule].append(submodule_import)
+
+ return module_imports
+
+
+def create_api_files(output_files):
+ """Creates __init__.py files for the Python API.
+
+ Args:
+ output_files: List of __init__.py file paths to create.
+ Each file must be under api/ directory.
+
+ Raises:
+ ValueError: if an output file is not under api/ directory,
+ or output_files list is missing a required file.
+ """
+ module_name_to_file_path = {}
+ for output_file in output_files:
+ if _API_DIR not in output_file:
+ raise ValueError(
+ 'Output files must be in api/ directory, found %s.' % output_file)
+ # Get the module name that corresponds to output_file.
+ # First get module directory under _API_DIR.
+ module_dir = os.path.dirname(
+ output_file[output_file.rfind(_API_DIR)+len(_API_DIR):])
+ # Convert / to . and prefix with tf.
+ module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.')
+ module_name_to_file_path[module_name] = output_file
+
+ # Create file for each expected output in genrule.
+ for module, file_path in module_name_to_file_path.items():
+ if not os.path.isdir(os.path.dirname(file_path)):
+ os.makedirs(os.path.dirname(file_path))
+ open(file_path, 'a').close()
+
+ # Add imports to output files.
+ module_imports = get_api_imports()
+ missing_output_files = []
+ for module, exports in module_imports.items():
+ # Make sure genrule output file list is in sync with API exports.
+ if module not in module_name_to_file_path:
+ missing_output_files.append(module)
+ continue
+ with open(module_name_to_file_path[module], 'w') as fp:
+ fp.write(_GENERATED_FILE_HEADER + '\n'.join(exports))
+
+ if missing_output_files:
+ raise ValueError(
+ 'Missing outputs for python_api_gen genrule:\n%s' %
+ ',\n'.join(missing_output_files))
+
+
+def main(output_files):
+ create_api_files(output_files)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'outputs', metavar='O', type=str, nargs='+',
+ help='Python files that we expect this script to output.')
+ args = parser.parse_args()
+ main(args.outputs)
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py
new file mode 100644
index 0000000000..2760779e6e
--- /dev/null
+++ b/tensorflow/tools/api/generator/create_python_api_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for create_python_api."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.api.generator import create_python_api
+
+
+@tf_export('test_op', 'test_op1')
+def test_op():
+ pass
+
+
+@tf_export('TestClass', 'NewTestClass')
+class TestClass(object):
+ pass
+
+
+_TEST_CONSTANT = 5
+_MODULE_NAME = 'test.tensorflow.test_module'
+
+
+class CreatePythonApiTest(test.TestCase):
+
+ def setUp(self):
+ # Add fake op to a module that has 'tensorflow' in the name.
+ sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME)
+ setattr(sys.modules[_MODULE_NAME], 'test_op', test_op)
+ setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass)
+ test_op.__module__ = _MODULE_NAME
+ TestClass.__module__ = _MODULE_NAME
+ tf_export('consts._TEST_CONSTANT').export_constant(
+ _MODULE_NAME, '_TEST_CONSTANT')
+
+ def tearDown(self):
+ del sys.modules[_MODULE_NAME]
+
+ def testFunctionImportIsAdded(self):
+ imports = create_python_api.get_api_imports()
+ expected_import = (
+ 'from test.tensorflow.test_module import test_op as test_op1')
+ self.assertTrue(
+ expected_import in str(imports),
+ msg='%s not in %s' % (expected_import, str(imports)))
+
+ expected_import = 'from test.tensorflow.test_module import test_op'
+ self.assertTrue(
+ expected_import in str(imports),
+ msg='%s not in %s' % (expected_import, str(imports)))
+
+ def testClassImportIsAdded(self):
+ imports = create_python_api.get_api_imports()
+ expected_import = 'from test.tensorflow.test_module import TestClass'
+ self.assertTrue(
+ 'TestClass' in str(imports),
+ msg='%s not in %s' % (expected_import, str(imports)))
+
+ def testConstantIsAdded(self):
+ imports = create_python_api.get_api_imports()
+ expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
+ self.assertTrue(expected in str(imports),
+ msg='%s not in %s' % (expected, str(imports)))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 1b97c0d108..4cfaf68ef3 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -72,7 +72,7 @@ RUN mkdir /bazel && \
RUN git clone https://github.com/tensorflow/tensorflow.git && \
cd tensorflow && \
- git checkout r1.3
+ git checkout r1.4
WORKDIR /tensorflow
# TODO(craigcitro): Don't install the pip package, since it makes it
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 80b45ae704..8d7e759bb2 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -73,7 +73,7 @@ RUN mkdir /bazel && \
RUN git clone https://github.com/tensorflow/tensorflow.git && \
cd tensorflow && \
- git checkout r1.3
+ git checkout r1.4
WORKDIR /tensorflow
# Configure the build for our CUDA configuration.
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index bebf7dbc00..00dffc4d27 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -29,7 +29,7 @@ from setuptools.dist import Distribution
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.3.0'
+_VERSION = '1.4.0-dev'
REQUIRED_PACKAGES = [
'enum34 >= 1.1.6',
@@ -233,8 +233,8 @@ setup(
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
- 'Topic :: Software Development :: Libraries',
- 'Topic :: Software Development :: Libraries :: Python Modules',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords='tensorflow tensor machine learning',)