-rwxr-xr-x  configure | 814
-rw-r--r--  configure.py | 950
-rw-r--r--  tensorflow/c/c_api.cc | 134
-rw-r--r--  tensorflow/c/c_api_test.cc | 50
-rw-r--r--  tensorflow/cc/framework/gradients.cc | 4
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 5
-rw-r--r--  tensorflow/compiler/tests/BUILD | 14
-rw-r--r--  tensorflow/compiler/tests/segment_reduction_ops_test.py | 139
-rw-r--r--  tensorflow/compiler/tests/tensor_array_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 22
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/no_op.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc | 155
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 14
-rw-r--r--  tensorflow/compiler/xla/BUILD | 5
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 1117
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/reduce_precision_insertion.cc | 39
-rw-r--r--  tensorflow/compiler/xla/service/reduce_precision_insertion.h | 22
-rw-r--r--  tensorflow/compiler/xla/statusor.cc | 22
-rw-r--r--  tensorflow/compiler/xla/statusor.h | 323
-rw-r--r--  tensorflow/compiler/xla/statusor_internals.h | 245
-rw-r--r--  tensorflow/compiler/xla/statusor_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc | 163
-rw-r--r--  tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 405
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_precision_test.cc | 90
-rw-r--r--  tensorflow/compiler/xla/xla.proto | 23
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt | 1
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 29
-rw-r--r--  tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py | 17
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 66
-rw-r--r--  tensorflow/core/framework/attr_value_util.cc | 6
-rw-r--r--  tensorflow/core/framework/attr_value_util.h | 6
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc | 83
-rw-r--r--  tensorflow/core/kernels/BUILD | 43
-rw-r--r--  tensorflow/core/kernels/compare_and_bitpack_op.cc | 185
-rw-r--r--  tensorflow/core/kernels/compare_and_bitpack_op.h | 42
-rw-r--r--  tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc | 141
-rw-r--r--  tensorflow/core/kernels/population_count_op.cc | 163
-rw-r--r--  tensorflow/core/kernels/population_count_op.h | 38
-rw-r--r--  tensorflow/core/kernels/population_count_op_gpu.cu.cc | 92
-rw-r--r--  tensorflow/core/ops/bitwise_ops.cc | 16
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 56
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 58
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 96
-rw-r--r--  tensorflow/core/platform/default/build_config_root.bzl | 2
-rw-r--r--  tensorflow/core/platform/macros.h | 3
-rw-r--r--  tensorflow/core/platform/posix/error.cc | 6
-rw-r--r--  tensorflow/go/op/wrappers.go | 237
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 12
-rw-r--r--  tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py | 83
-rw-r--r--  tensorflow/python/kernel_tests/gather_op_test.py | 102
-rw-r--r--  tensorflow/python/ops/array_ops.py | 2
-rw-r--r--  tensorflow/python/ops/bitwise_ops.py | 1
-rw-r--r--  tensorflow/python/ops/bitwise_ops_test.py | 23
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 10
-rw-r--r--  tensorflow/python/ops/init_ops.py | 1
-rw-r--r--  tensorflow/python/ops/nn_test.py | 18
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 18
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt | 2
-rw-r--r--  tensorflow/tools/ci_build/builds/builds_common.sh | 1
-rwxr-xr-x  tensorflow/tools/ci_build/builds/configured | 2
-rwxr-xr-x  tensorflow/tools/ci_build/builds/pip.sh | 27
-rwxr-xr-x  tensorflow/tools/ci_build/builds/run_pip_tests.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh | 2
-rwxr-xr-x  tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh | 2
-rw-r--r--  third_party/llvm/llvm.BUILD | 4
81 files changed, 4571 insertions, 2010 deletions
diff --git a/configure b/configure
index 1eeaffaf74..c6df6992d9 100755
--- a/configure
+++ b/configure
@@ -3,816 +3,12 @@
set -e
set -o pipefail
-MIN_BAZEL_VERSION=0.4.5
-
-# Find out the absolute path to where ./configure resides
-pushd `dirname $0` > /dev/null
-SOURCE_BASE_DIR=`pwd -P`
-popd > /dev/null
-
-PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
-
-function is_linux() {
- [[ "${PLATFORM}" == "linux" ]]
-}
-
-function is_macos() {
- [[ "${PLATFORM}" == "darwin" ]]
-}
-
-function is_windows() {
- # On windows, the shell script is actually running in msys
- [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]
-}
-
-function is_ppc64le() {
- [[ "$(uname -m)" == "ppc64le" ]]
-}
-
-function sed_in_place() {
- sed -e $1 $2 > "$2.bak"
- mv "$2.bak" $2
-}
-
-function write_to_bazelrc() {
- echo "$1" >> .tf_configure.bazelrc
-}
-
-function write_action_env_to_bazelrc() {
- write_to_bazelrc "build --action_env $1=\"$2\""
-}
-
-function python_path {
- "$PYTHON_BIN_PATH" - <<END
-from __future__ import print_function
-import site
-import os
-
-try:
- input = raw_input
-except NameError:
- pass
-
-python_paths = []
-if os.getenv('PYTHONPATH') is not None:
- python_paths = os.getenv('PYTHONPATH').split(':')
-try:
- library_paths = site.getsitepackages()
-except AttributeError:
- from distutils.sysconfig import get_python_lib
- library_paths = [get_python_lib()]
-all_paths = set(python_paths + library_paths)
-
-paths = []
-for path in all_paths:
- if os.path.isdir(path):
- paths.append(path)
-
-print(",".join(paths))
-END
-}
-
-function setup_python {
- ## Set up python-related environment settings:
- while true; do
- fromuser=""
- if [ -z "$PYTHON_BIN_PATH" ]; then
- default_python_bin_path=$(which python || which python3 || true)
- read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH
- fromuser="1"
- if [ -z "$PYTHON_BIN_PATH" ]; then
- PYTHON_BIN_PATH=$default_python_bin_path
- fi
- fi
- if [ -e "$PYTHON_BIN_PATH" ]; then
- break
- fi
- echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- PYTHON_BIN_PATH=""
- # Retry
- done
-
- if [ -z "$PYTHON_LIB_PATH" ]; then
- # Split python_path into an array of paths, this allows path containing spaces
- IFS=',' read -r -a python_lib_path <<< "$(python_path)"
-
- if [ 1 = "$USE_DEFAULT_PYTHON_LIB_PATH" ]; then
- PYTHON_LIB_PATH=${python_lib_path[0]}
- echo "Using python library path: $PYTHON_LIB_PATH"
-
- else
- echo "Found possible Python library paths:"
- for x in "${python_lib_path[@]}"; do
- echo " $x"
- done
- set -- "${python_lib_path[@]}"
- echo "Please input the desired Python library path to use. Default is [$1]"
- read b || true
- if [ "$b" == "" ]; then
- PYTHON_LIB_PATH=${python_lib_path[0]}
- echo "Using python library path: $PYTHON_LIB_PATH"
- else
- PYTHON_LIB_PATH="$b"
- fi
- fi
- fi
-
- if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then
- echo "PYTHON_BIN_PATH is not executable. Is it the python binary?"
- exit 1
- fi
-
- local python_major_version
- python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);' | head -c1)
- if [ -z "$python_major_version" ]; then
- echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?"
- exit 1
- fi
-
- # Convert python path to Windows style before writing into bazel.rc
- if is_windows; then
- PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")"
- PYTHON_LIB_PATH="$(cygpath -m "$PYTHON_LIB_PATH")"
- fi
-
- # Set-up env variables used by python_configure.bzl
- write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
- write_action_env_to_bazelrc "PYTHON_LIB_PATH" "$PYTHON_LIB_PATH"
- write_to_bazelrc "build --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\""
- write_to_bazelrc "build --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\""
- write_to_bazelrc "build --force_python=py$python_major_version"
- write_to_bazelrc "build --host_force_python=py$python_major_version"
- write_to_bazelrc "build --python${python_major_version}_path=\"$PYTHON_BIN_PATH\""
- write_to_bazelrc "test --force_python=py$python_major_version"
- write_to_bazelrc "test --host_force_python=py$python_major_version"
- write_to_bazelrc "test --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\""
- write_to_bazelrc "test --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\""
- write_to_bazelrc "run --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\""
- write_to_bazelrc "run --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\""
-
- # Write tools/python_bin_path.sh
- echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh
-}
-
-function version {
- echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }';
-}
-
-
-bazel version > bazel.version
-set +e
-curr_bazel_version=$(grep -m 1 'Build label:' bazel.version | cut -d ' ' -f3)
-set -e
-rm -f bazel.version
-
-
-echo "You have bazel $curr_bazel_version installed."
-if [ -z "$curr_bazel_version" ]; then
- echo "WARNING: current bazel installation is not a release version."
- echo "Make sure you are running at least bazel $MIN_BAZEL_VERSION."
-elif [ "$(version "$MIN_BAZEL_VERSION")" -gt "$(version "$curr_bazel_version")" ]; then
- echo "Please upgrade your bazel installation to version $MIN_BAZEL_VERSION or higher to build TensorFlow!"
- echo "Exiting..."
- exit 1
-fi
-
-# This file contains customized config settings.
-rm -f .tf_configure.bazelrc
-touch .tf_configure.bazelrc
-if [[ ! -e .bazelrc ]]; then
- if [[ -e "${HOME}/.bazelrc" ]]; then
- echo "import ${HOME}/.bazelrc" >.bazelrc
- else
- touch .bazelrc
- fi
-fi
-sed_in_place "/tf_configure/d" .bazelrc
-echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc
-
-# Delete any leftover BUILD files from the Makefile build, which would interfere
-# with Bazel parsing.
-MAKEFILE_DOWNLOAD_DIR=tensorflow/contrib/makefile/downloads
-if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then
- find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete
-fi
-
-setup_python
-
-## Set up MKL related environment settings
-write_to_bazelrc 'build:mkl --define with_mkl_support=true'
-write_to_bazelrc 'build:mkl --define using_mkl=true'
-write_to_bazelrc 'build:mkl -c opt'
-write_to_bazelrc 'build:mkl --copt="-DEIGEN_USE_VML"'
-echo ""
-echo "Add \"--config=mkl\" to your bazel command to build with MKL support."
-echo "Please note that MKL on MacOS or windows is still not supported."
-echo "If you would like to use a local MKL instead of downloading, please "
-echo " set the environment variable \"TF_MKL_ROOT\" every time before build."
-echo ""
-## End MKL setup
-
-## Set up architecture-dependent optimization flags.
-if [ -z "$CC_OPT_FLAGS" ]; then
- if is_ppc64le; then
- # gcc on ppc64le does not support -march, use mcpu instead
- default_cc_opt_flags="-mcpu=native"
- else
- default_cc_opt_flags="-march=native"
- fi
- read -p "Please specify optimization flags to use during compilation when bazel option "\
-"\"--config=opt\" is specified [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
- if [ -z "$CC_OPT_FLAGS" ]; then
- CC_OPT_FLAGS=$default_cc_opt_flags
- fi
-fi
-
-if is_windows; then
- TF_NEED_GCP=0
- TF_NEED_HDFS=0
- TF_NEED_JEMALLOC=0
- TF_NEED_OPENCL=0
- TF_CUDA_CLANG=0
-fi
-
-if is_linux; then
- while [ "$TF_NEED_JEMALLOC" == "" ]; do
- read -p "Do you wish to use jemalloc as the malloc implementation? [Y/n] "\
- INPUT
- case $INPUT in
- [Yy]* ) echo "jemalloc enabled"; TF_NEED_JEMALLOC=1;;
- [Nn]* ) echo "jemalloc disabled"; TF_NEED_JEMALLOC=0;;
- "" ) echo "jemalloc enabled"; TF_NEED_JEMALLOC=1;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
- done
-else
- TF_NEED_JEMALLOC=0
-fi
-
-if [[ "$TF_NEED_JEMALLOC" == "1" ]]; then
- write_to_bazelrc 'build --define with_jemalloc=true'
-fi
-
-while [[ "$TF_NEED_GCP" == "" ]]; do
- read -p "Do you wish to build TensorFlow with "\
-"Google Cloud Platform support? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "Google Cloud Platform support will be enabled for "\
-"TensorFlow"; TF_NEED_GCP=1;;
- [Nn]* ) echo "No Google Cloud Platform support will be enabled for "\
-"TensorFlow"; TF_NEED_GCP=0;;
- "" ) echo "No Google Cloud Platform support will be enabled for "\
-"TensorFlow"; TF_NEED_GCP=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-if [[ "$TF_NEED_GCP" == "1" ]]; then
- write_to_bazelrc 'build --define with_gcp_support=true'
-fi
-
-while [[ "$TF_NEED_HDFS" == "" ]]; do
- read -p "Do you wish to build TensorFlow with "\
-"Hadoop File System support? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "Hadoop File System support will be enabled for "\
-"TensorFlow"; TF_NEED_HDFS=1;;
- [Nn]* ) echo "No Hadoop File System support will be enabled for "\
-"TensorFlow"; TF_NEED_HDFS=0;;
- "" ) echo "No Hadoop File System support will be enabled for "\
-"TensorFlow"; TF_NEED_HDFS=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-if [[ "$TF_NEED_HDFS" == "1" ]]; then
- write_to_bazelrc 'build --define with_hdfs_support=true'
-fi
-
-## Enable XLA.
-while [[ "$TF_ENABLE_XLA" == "" ]]; do
- read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;;
- [Nn]* ) echo "No XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
- "" ) echo "No XLA support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-if [[ "$TF_ENABLE_XLA" == "1" ]]; then
- write_to_bazelrc 'build --define with_xla_support=true'
-fi
-
-# Verbs configuration
-while [ "$TF_NEED_VERBS" == "" ]; do
- read -p "Do you wish to build TensorFlow with "\
-"VERBS support? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "VERBS support will be enabled for "\
-"TensorFlow"; TF_NEED_VERBS=1;;
- [Nn]* ) echo "No VERBS support will be enabled for "\
-"TensorFlow"; TF_NEED_VERBS=0;;
- "" ) echo "No VERBS support will be enabled for "\
-"TensorFlow"; TF_NEED_VERBS=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-if [[ "$TF_NEED_VERBS" == "1" ]]; then
- write_to_bazelrc 'build --define with_verbs_support=true'
-fi
-
-# Append CC optimization flags to bazel.rc
-for opt in $CC_OPT_FLAGS; do
- write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt"
-done
-
-# Run the gen_git_source to create links where bazel can track dependencies for
-# git hash propagation
-GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py
-chmod a+x ${GEN_GIT_SOURCE}
-"${PYTHON_BIN_PATH}" ${GEN_GIT_SOURCE} --configure "${SOURCE_BASE_DIR}"
-
-## Set up SYCL-related environment settings
-while [ "$TF_NEED_OPENCL" == "" ]; do
- read -p "Do you wish to build TensorFlow with OpenCL support? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=1;;
- [Nn]* ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
- "" ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-## Set up Cuda-related environment settings
-
-while [ "$TF_NEED_CUDA" == "" ]; do
- read -p "Do you wish to build TensorFlow with CUDA support? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "CUDA support will be enabled for TensorFlow"; TF_NEED_CUDA=1;;
- [Nn]* ) echo "No CUDA support will be enabled for TensorFlow"; TF_NEED_CUDA=0;;
- "" ) echo "No CUDA support will be enabled for TensorFlow"; TF_NEED_CUDA=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-export TF_NEED_CUDA
-write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA"
-
-export TF_NEED_OPENCL
-write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL"
-
-if [ "$TF_NEED_CUDA" == "1" ]; then
-while [[ "$TF_CUDA_CLANG" == "" ]]; do
- read -p "Do you want to use clang as CUDA compiler? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "Clang will be used as CUDA compiler"; TF_CUDA_CLANG=1;;
- [Nn]* ) echo "nvcc will be used as CUDA compiler"; TF_CUDA_CLANG=0;;
- "" ) echo "nvcc will be used as CUDA compiler"; TF_CUDA_CLANG=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-export TF_CUDA_CLANG
-write_action_env_to_bazelrc "TF_CUDA_CLANG" "$TF_CUDA_CLANG"
-
-# Set up which clang we should use as the cuda / host compiler.
-while [[ "$TF_CUDA_CLANG" == "1" ]] && true; do
- fromuser=""
- if [ -z "$CLANG_CUDA_COMPILER_PATH" ]; then
- default_clang_host_compiler_path=$(which clang || true)
- read -p "Please specify which clang should be used as device and host compiler. [Default is $default_clang_host_compiler_path]: " CLANG_CUDA_COMPILER_PATH
- fromuser="1"
- if [ -z "$CLANG_CUDA_COMPILER_PATH" ]; then
- CLANG_CUDA_COMPILER_PATH="$default_clang_host_compiler_path"
- fi
- fi
- if [ -e "$CLANG_CUDA_COMPILER_PATH" ]; then
- export CLANG_CUDA_COMPILER_PATH
- write_action_env_to_bazelrc "CLANG_CUDA_COMPILER_PATH" "$CLANG_CUDA_COMPILER_PATH"
- break
- fi
- echo "Invalid clang path. ${CLANG_CUDA_COMPILER_PATH} cannot be found" 1>&2
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- CLANG_CUDA_COMPILER_PATH=""
- # Retry
-done
-
-# Find out where the CUDA toolkit is installed
-while true; do
- # Configure the Cuda SDK version to use.
- if [ -z "$TF_CUDA_VERSION" ]; then
- read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 8.0]: " TF_CUDA_VERSION
- fi
- # Set default CUDA version if not set
- TF_CUDA_VERSION=${TF_CUDA_VERSION:-8.0}
-
- fromuser=""
- if [ -z "$CUDA_TOOLKIT_PATH" ]; then
- default_cuda_path=/usr/local/cuda
- if is_windows; then
- if [ -z "$CUDA_PATH" ]; then
- default_cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0"
- else
- default_cuda_path="$(cygpath -m "$CUDA_PATH")"
- fi
- elif is_linux; then
- # If the default doesn't exist, try an alternative default.
- if [ ! -d $default_cuda_path ] && [ -d /opt/cuda ]; then
- default_cuda_path=/opt/cuda
- fi
- fi
- read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH
- fromuser="1"
- if [ -z "$CUDA_TOOLKIT_PATH" ]; then
- CUDA_TOOLKIT_PATH="$default_cuda_path"
- fi
- fi
-
- if [[ -z "$TF_CUDA_VERSION" ]]; then
- TF_CUDA_EXT=""
- else
- TF_CUDA_EXT=".$TF_CUDA_VERSION"
- fi
-
- if is_windows; then
- CUDA_RT_LIB_PATH="lib/x64/cudart.lib"
- elif is_linux; then
- CUDA_RT_LIB_PATH="lib64/libcudart.so${TF_CUDA_EXT}"
- elif is_macos; then
- CUDA_RT_LIB_PATH="lib/libcudart${TF_CUDA_EXT}.dylib"
- fi
-
- if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then
- export CUDA_TOOLKIT_PATH
- write_action_env_to_bazelrc "CUDA_TOOLKIT_PATH" "$CUDA_TOOLKIT_PATH"
- export TF_CUDA_VERSION
- break
- fi
- echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH} cannot be found"
-
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- # Retry
- TF_CUDA_VERSION=""
- CUDA_TOOLKIT_PATH=""
-done
-
-export TF_CUDA_VERSION
-write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION"
-
-# Set up which gcc nvcc should use as the host compiler
-# No need to set this on Windows
-while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do
- fromuser=""
- if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
- default_gcc_host_compiler_path=$(which gcc || true)
- cuda_bin_symlink="$CUDA_TOOLKIT_PATH/bin/gcc"
- if [ -L "$cuda_bin_symlink" ]; then
- default_gcc_host_compiler_path=$(readlink $cuda_bin_symlink)
- fi
- read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH
- fromuser="1"
- if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
- GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path"
- fi
- fi
- if [ -e "$GCC_HOST_COMPILER_PATH" ]; then
- export GCC_HOST_COMPILER_PATH
- write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH"
- break
- fi
- echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- GCC_HOST_COMPILER_PATH=""
- # Retry
-done
-
-# Find out where the cuDNN library is installed
-while true; do
- # Configure the cuDNN version to use.
- if [ -z "$TF_CUDNN_VERSION" ]; then
- read -p "Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 6.0]: " TF_CUDNN_VERSION
- fi
- # Set default CUDNN version if not set
- TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-6}
-
- fromuser=""
- if [ -z "$CUDNN_INSTALL_PATH" ]; then
- default_cudnn_path=${CUDA_TOOLKIT_PATH}
- read -p "Please specify the location where cuDNN $TF_CUDNN_VERSION library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH
- fromuser="1"
- if [ -z "$CUDNN_INSTALL_PATH" ]; then
- CUDNN_INSTALL_PATH=$default_cudnn_path
- fi
- # Result returned from "read" will be used unexpanded. That make "~" unusable.
- # Going through one more level of expansion to handle that.
- CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"`
- if is_windows; then
- CUDNN_INSTALL_PATH="$(cygpath -m "$CUDNN_INSTALL_PATH")"
- fi
- fi
-
- if [[ -z "$TF_CUDNN_VERSION" ]]; then
- TF_CUDNN_EXT=""
- else
- TF_CUDNN_EXT=".$TF_CUDNN_VERSION"
- fi
-
- if is_windows; then
- CUDA_DNN_LIB_PATH="lib/x64/cudnn.lib"
- CUDA_DNN_LIB_ALT_PATH="lib/x64/cudnn.lib"
- elif is_linux; then
- CUDA_DNN_LIB_PATH="lib64/libcudnn.so${TF_CUDNN_EXT}"
- CUDA_DNN_LIB_ALT_PATH="libcudnn.so${TF_CUDNN_EXT}"
- elif is_macos; then
- CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}.dylib"
- CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib"
- fi
-
- if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" ] || [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
- export TF_CUDNN_VERSION
- write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION"
- export CUDNN_INSTALL_PATH
- write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH"
- break
- fi
-
- if is_linux; then
- if ! type ldconfig > /dev/null 2>&1; then
- LDCONFIG_BIN=/sbin/ldconfig
- else
- LDCONFIG_BIN=ldconfig
- fi
- CUDNN_PATH_FROM_LDCONFIG="$($LDCONFIG_BIN -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')"
- if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then
- export TF_CUDNN_VERSION
- export CUDNN_INSTALL_PATH
- CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
- write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH"
- break
- fi
- fi
- echo "Invalid path to cuDNN ${CUDNN_VERSION} toolkit. Neither of the following two files can be found:"
- echo "${CUDNN_INSTALL_PATH}/${CUDA_DNN_LIB_PATH}"
- echo "${CUDNN_INSTALL_PATH}/${CUDA_DNN_LIB_ALT_PATH}"
- if is_linux; then
- echo "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}"
- fi
-
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- # Retry
- TF_CUDNN_VERSION=""
- CUDNN_INSTALL_PATH=""
-done
-
-export TF_CUDNN_VERSION
-write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION"
-
-# Configure the compute capabilities that TensorFlow builds for.
-# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
-function get_native_cuda_compute_capabilities {
- device_query_bin="$CUDA_TOOLKIT_PATH/extras/demo_suite/deviceQuery" # Also works on Windows without .exe
- "$device_query_bin" | grep 'Capability' | grep -o '[0-9]*\.[0-9]*' | sed ':a;{N;s/\n/,/};ba'
- exit 0 # ensure that this function always exit success even if device detection fails, to prevent the whole configure from aborting
-}
-while true; do
- fromuser=""
- native_cuda_compute_capabilities=$(get_native_cuda_compute_capabilities)
- default_cuda_compute_capabilities=${native_cuda_compute_capabilities:-"3.5,5.2"}
- if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
-cat << EOF
-Please specify a list of comma-separated Cuda compute capabilities you want to build with.
-You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
-Please note that each additional compute capability significantly increases your build time and binary size.
-EOF
- read -p "[Default is: \"$default_cuda_compute_capabilities\"]: " TF_CUDA_COMPUTE_CAPABILITIES
- fromuser=1
- fi
- if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
- TF_CUDA_COMPUTE_CAPABILITIES=$default_cuda_compute_capabilities
- fi
- # Check whether all capabilities from the input is valid
- COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
- ALL_VALID=1
- for CAPABILITY in $COMPUTE_CAPABILITIES; do
- if [[ ! "$CAPABILITY" =~ [0-9]+.[0-9]+ ]]; then
- echo "Invalid compute capability: " $CAPABILITY
- ALL_VALID=0
- break
- fi
- done
- if [ "$ALL_VALID" == "0" ]; then
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- else
- export TF_CUDA_COMPUTE_CAPABILITIES
- write_action_env_to_bazelrc "TF_CUDA_COMPUTE_CAPABILITIES" "$TF_CUDA_COMPUTE_CAPABILITIES"
- break
- fi
- TF_CUDA_COMPUTE_CAPABILITIES=""
-done
-
-if is_windows; then
- # The following three variables are needed for MSVC toolchain configuration in Bazel
- export CUDA_PATH="$CUDA_TOOLKIT_PATH"
- export CUDA_COMPUTE_CAPABILITIES="$TF_CUDA_COMPUTE_CAPABILITIES"
- export NO_WHOLE_ARCHIVE_OPTION=1
- write_action_env_to_bazelrc "CUDA_PATH" "$CUDA_PATH"
- write_action_env_to_bazelrc "CUDA_COMPUTE_CAPABILITIES" "$CUDA_COMPUTE_CAPABILITIES"
- write_action_env_to_bazelrc "NO_WHOLE_ARCHIVE_OPTION" "1"
- write_to_bazelrc "build --config=win-cuda"
- write_to_bazelrc "test --config=win-cuda"
-else
- # If CUDA is enabled, always use GPU during build and test.
- if [ "$TF_CUDA_CLANG" == "1" ]; then
- write_to_bazelrc "build --config=cuda_clang"
- write_to_bazelrc "test --config=cuda_clang"
- else
- write_to_bazelrc "build --config=cuda"
- write_to_bazelrc "test --config=cuda"
- fi
+if [ -z "$PYTHON_BIN_PATH" ]; then
+ PYTHON_BIN_PATH=$(which python || which python3 || true)
fi
-# end of if "$TF_NEED_CUDA" == "1"
-fi
-
-# OpenCL configuration
-
-if [ "$TF_NEED_OPENCL" == "1" ]; then
-
-# Determine which C++ compiler should be used as the host compiler
-while true; do
- fromuser=""
- if [ -z "$HOST_CXX_COMPILER" ]; then
- default_cxx_host_compiler=$(which g++ || true)
- read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER
- fromuser="1"
- if [ -z "$HOST_CXX_COMPILER" ]; then
- HOST_CXX_COMPILER=$default_cxx_host_compiler
- fi
- fi
- if [ -e "$HOST_CXX_COMPILER" ]; then
- export HOST_CXX_COMPILER
- write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER"
- break
- fi
- echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- HOST_CXX_COMPILER=""
- # Retry
-done
-
-# Determine which C compiler should be used as the host compiler
-while true; do
- fromuser=""
- if [ -z "$HOST_C_COMPILER" ]; then
- default_c_host_compiler=$(which gcc || true)
- read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER
- fromuser="1"
- if [ -z "$HOST_C_COMPILER" ]; then
- HOST_C_COMPILER=$default_c_host_compiler
- fi
- fi
- if [ -e "$HOST_C_COMPILER" ]; then
- export HOST_C_COMPILER
- write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER"
- break
- fi
- echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- HOST_C_COMPILER=""
- # Retry
-done
-
-while true; do
- # Configure the OPENCL version to use.
- TF_OPENCL_VERSION="1.2"
-
- # Point to ComputeCpp root
- if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
- default_computecpp_toolkit_path=/usr/local/computecpp
- read -p "Please specify the location where ComputeCpp for SYCL $TF_OPENCL_VERSION is installed. [Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH
- fromuser="1"
- if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
- COMPUTECPP_TOOLKIT_PATH=$default_computecpp_toolkit_path
- fi
- fi
-
- if is_linux; then
- SYCL_RT_LIB_PATH="lib/libComputeCpp.so"
- fi
-
- if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
- export COMPUTECPP_TOOLKIT_PATH
- write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH"
- break
- fi
- echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found"
-
- if [ -z "$fromuser" ]; then
- exit 1
- fi
- # Retry
- TF_OPENCL_VERSION=""
- COMPUTECPP_TOOLKIT_PATH=""
-done
-
-# end of if "$TF_NEED_OPENCL" == "1"
-fi
-
-
-while [ "$TF_NEED_MPI" == "" ]; do
- read -p "Do you wish to build TensorFlow with "\
-"MPI support? [y/N] " INPUT
- case $INPUT in
- [Yy]* ) echo "MPI support will be enabled for "\
-"TensorFlow"; TF_NEED_MPI=1;;
- [Nn]* ) echo "MPI support will not be enabled for "\
-"TensorFlow"; TF_NEED_MPI=0;;
- "" ) echo "MPI support will not be enabled for "\
-"TensorFlow"; TF_NEED_MPI=0;;
- * ) echo "Invalid selection: " $INPUT;;
- esac
-done
-
-# Find out where the MPI toolkit is installed
-while true; do
- if [ "$TF_NEED_MPI" == "0" ]; then
- break;
- fi
-
- fromuser=""
- if [ -z "$MPI_HOME" ]; then
- #Get the base folder by removing the bin path
- default_mpi_path=$(dirname $(dirname $(which mpirun)) || dirname $(dirname $(which mpiexec)) || true)
- read -p "Please specify the MPI toolkit folder. [Default is $default_mpi_path]: " MPI_HOME
- fromuser="1"
- if [ -z "$MPI_HOME" ]; then
- MPI_HOME=$default_mpi_path
- fi
- fi
-
- #Check that the include and library folders are where we expect them to be
- if [ -e "$MPI_HOME/include" ] && [ -e "$MPI_HOME/lib" ]; then
- break
- fi
-
- echo "Invalid path to the MPI Toolkit. ${MPI_HOME}/include or ${MPI_HOME}/lib cannot be found."
- if [ -z "$fromuser" ]; then
- exit 1
- fi
-
- # Retry
- MPI_HOME=""
-done
-
-
-if [ "$TF_NEED_MPI" == "1" ]; then
- write_to_bazelrc 'build --define with_mpi_support=true'
-
- #Link the MPI header files
- ln -sf "${MPI_HOME}/include/mpi.h" third_party/mpi/mpi.h
-
-
- #Determine if we use OpenMPI or MVAPICH, these require different header files
- #to be included here to make bazel dependency checker happy
-
- if [ -e "${MPI_HOME}/include/mpi_portable_platform.h" ]; then
- #OpenMPI
- ln -sf "${MPI_HOME}/include/mpi_portable_platform.h" third_party/mpi/
- sed -i -e "s/MPI_LIB_IS_OPENMPI=False/MPI_LIB_IS_OPENMPI=True/" third_party/mpi/mpi.bzl
- else
- #MVAPICH / MPICH
- ln -sf "${MPI_HOME}/include/mpio.h" third_party/mpi/
- ln -sf "${MPI_HOME}/include/mpicxx.h" third_party/mpi/
- sed -i -e "s/MPI_LIB_IS_OPENMPI=True/MPI_LIB_IS_OPENMPI=False/" third_party/mpi/mpi.bzl
- fi
-
-
- if [ -e "${MPI_HOME}/lib/libmpi.so" ]; then
- ln -sf "${MPI_HOME}/lib/libmpi.so" third_party/mpi/
- else
- echo "Cannot find the MPI library file in ${MPI_HOME}/lib "
- exit 1
- fi
-fi
+# Set all env variables
+$PYTHON_BIN_PATH configure.py
-echo "Configuration finished"
+echo "Configuration finished" \ No newline at end of file
diff --git a/configure.py b/configure.py
new file mode 100644
index 0000000000..fac00d1b74
--- /dev/null
+++ b/configure.py
@@ -0,0 +1,950 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""configure script to get build parameters from user."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import errno
+import os
+import platform
+import re
+import site
+import subprocess
+import sys
+
+_TF_BAZELRC = '.tf_configure.bazelrc'
+_DEFAULT_CUDA_VERSION = '8.0'
+_DEFAULT_CUDNN_VERSION = '6'
+_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
+_DEFAULT_CUDA_PATH = '/usr/local/cuda'
+_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
+_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
+ 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
+_TF_OPENCL_VERSION = '1.2'
+_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
+
+
+def is_windows():
+ return platform.system() == 'Windows'
+
+
+def is_linux():
+ return platform.system() == 'Linux'
+
+
+def is_macos():
+ return platform.system() == 'Darwin'
+
+
+def is_ppc64le():
+ return platform.machine() == 'ppc64le'
+
+
+def get_input(question):
+ try:
+ try:
+ answer = raw_input(question)
+ except NameError:
+ answer = input(question) # pylint: disable=bad-builtin
+ except EOFError:
+ answer = ''
+ return answer
+
+
+def symlink_force(target, link_name):
+ """Force symlink, equivalent of 'ln -sf'.
+
+ Args:
+ target: items to link to.
+ link_name: name of the link.
+ """
+ try:
+ os.symlink(target, link_name)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ os.remove(link_name)
+ os.symlink(target, link_name)
+ else:
+ raise e
+
+
+def sed_in_place(filename, old, new):
+ """Replace old string with new string in file.
+
+ Args:
+ filename: string for filename.
+ old: string to replace.
+ new: new string to replace to.
+ """
+ with open(filename, 'r') as f:
+ filedata = f.read()
+ newdata = filedata.replace(old, new)
+ with open(filename, 'w') as f:
+ f.write(newdata)
+
+
+def remove_line_with(filename, token):
+ """Remove lines that contain token from file.
+
+ Args:
+ filename: string for filename.
+ token: string token to check if to remove a line from file or not.
+ """
+ with open(filename, 'r') as f:
+ filedata = f.read()
+
+ with open(filename, 'w') as f:
+ for line in filedata.strip().split('\n'):
+ if token not in line:
+ f.write(line + '\n')
+
+
+def write_to_bazelrc(line):
+ with open(_TF_BAZELRC, 'a') as f:
+ f.write(line + '\n')
+
+
+def write_action_env_to_bazelrc(var_name, var):
+ write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
+
+
+def run_shell(cmd):
+ return subprocess.check_output(cmd, shell=True).decode('UTF-8').strip()
+
+
+def cygpath(path):
+ """Convert path from posix to windows."""
+ return run_shell('cygpath -m "%s"' % path)
+
+
+def get_python_path(environ_cp):
+ """Get the python site package paths."""
+ python_paths = []
+ if environ_cp.get('PYTHONPATH'):
+ python_paths = environ_cp.get('PYTHONPATH').split(':')
+ try:
+ library_paths = site.getsitepackages()
+ except AttributeError:
+ from distutils.sysconfig import get_python_lib # pylint: disable=g-import-not-at-top
+ library_paths = [get_python_lib()]
+ all_paths = set(python_paths + library_paths)
+
+ paths = []
+ for path in all_paths:
+ if os.path.isdir(path):
+ paths.append(path)
+ return paths
+
+
+def setup_python(environ_cp):
+ """Setup python related env variables."""
+ # Get PYTHON_BIN_PATH, default is the current running python.
+ default_python_bin_path = sys.executable
+ ask_python_bin_path = ('Please specify the location of python. [Default is '
+ '%s]: ') % default_python_bin_path
+ while True:
+ python_bin_path = get_from_env_or_user_or_default(
+ environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
+ default_python_bin_path)
+ # Check if the path is valid
+ if (os.path.isfile(python_bin_path) and os.access(
+ python_bin_path, os.X_OK)) or (os.path.isdir(python_bin_path)):
+ break
+ elif not os.path.exists(python_bin_path):
+ print('Invalid python path: %s cannot be found.' % python_bin_path)
+ else:
+ print('%s is not executable. Is it the python binary?' % python_bin_path)
+ environ_cp['PYTHON_BIN_PATH'] = ''
+
+ # Get PYTHON_LIB_PATH
+ python_lib_path = environ_cp.get('PYTHON_LIB_PATH')
+ if not python_lib_path:
+ python_lib_paths = get_python_path(environ_cp)
+ if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1':
+ environ_cp['PYTHON_LIB_PATH'] = python_lib_paths[0]
+ else:
+ print('Found possible Python library paths:\n%s' %
+ '\n'.join(python_lib_paths))
+ default_python_lib_path = python_lib_paths[0]
+ python_lib_path = get_input(
+ 'Please input the desired Python library path to use. Default is %s'
+ % python_lib_paths[0])
+ if not python_lib_path:
+ python_lib_path = default_python_lib_path
+ environ_cp['PYTHON_LIB_PATH'] = python_lib_path
+
+ python_major_version = sys.version_info[0]
+ # Convert python path to Windows style before writing into bazel.rc
+ if is_windows():
+ python_bin_path = cygpath(python_bin_path)
+ python_lib_path = cygpath(python_lib_path)
+
+ # Set-up env variables used by python_configure.bzl
+ write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
+ write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
+ write_to_bazelrc('build --define PYTHON_BIN_PATH="%s"' % python_bin_path)
+ write_to_bazelrc('build --define PYTHON_LIB_PATH="%s"' % python_lib_path)
+ write_to_bazelrc('build --force_python=py%s' % python_major_version)
+ write_to_bazelrc('build --host_force_python=py%s' % python_major_version)
+ write_to_bazelrc('build --python%s_path=\"%s"' % (python_major_version,
+ python_bin_path))
+ write_to_bazelrc('test --force_python=py%s' % python_major_version)
+ write_to_bazelrc('test --host_force_python=py%s' % python_major_version)
+ write_to_bazelrc('test --define PYTHON_BIN_PATH="%s"' % python_bin_path)
+ write_to_bazelrc('test --define PYTHON_LIB_PATH="%s"' % python_lib_path)
+ write_to_bazelrc('run --define PYTHON_BIN_PATH="%s"' % python_bin_path)
+ write_to_bazelrc('run --define PYTHON_LIB_PATH="%s"' % python_lib_path)
+ environ_cp['PYTHON_BIN_PATH'] = python_bin_path
+
+ # Write tools/python_bin_path.sh
+ with open('tools/python_bin_path.sh', 'w') as f:
+ f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
+
+
+def reset_tf_configure_bazelrc():
+ """Reset file that contains customized config settings."""
+ open(_TF_BAZELRC, 'w').close()
+
+ home = os.path.expanduser('~')
+ if not os.path.exists('.bazelrc'):
+ if os.path.exists(os.path.join(home, '.bazelrc')):
+ with open('.bazelrc', 'a') as f:
+ f.write('import %s/.bazelrc\n' % home)
+ else:
+ open('.bazelrc', 'w').close()
+
+ remove_line_with('.bazelrc', 'tf_configure')
+ with open('.bazelrc', 'a') as f:
+ f.write('import %workspace%/.tf_configure.bazelrc\n')
+
+
+def run_gen_git_source(environ_cp):
+ """Run the gen_git_source to create links.
+
+ The links are for bazel to track dependencies for git hash propagation.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ """
+ cmd = '%s tensorflow/tools/git/gen_git_source.py --configure %s' % (
+ environ_cp.get('PYTHON_BIN_PATH'), os.getcwd())
+ os.system(cmd)
+
+
+def cleanup_makefile():
+ """Delete any leftover BUILD files from the Makefile build.
+
+ These files could interfere with Bazel parsing.
+ """
+ makefile_download_dir = 'tensorflow/contrib/makefile/downloads'
+ if os.path.isdir(makefile_download_dir):
+ for root, _, filenames in os.walk(makefile_download_dir):
+ for f in filenames:
+ if f.endswith('BUILD'):
+ os.remove(os.path.join(root, f))
+
+
+def get_var(environ_cp,
+ var_name,
+ query_item,
+ enabled_by_default,
+ question=None,
+ yes_reply=None,
+ no_reply=None):
+ """Get boolean input from user.
+
+ If var_name is not set in env, ask user to enable query_item or not. If the
+ response is empty, use the default.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
+ query_item: string for feature related to the variable, e.g. "Hadoop File
+ System".
+ enabled_by_default: boolean for default behavior.
+ question: optional string for how to ask for user input.
+    yes_reply: optional string for reply when feature is enabled.
+ no_reply: optional string for reply when feature is disabled.
+
+ Returns:
+ boolean value of the variable.
+ """
+ if not question:
+ question = 'Do you wish to build TensorFlow with %s support?' % query_item
+ if not yes_reply:
+ yes_reply = '%s support will be enabled for TensorFlow.' % query_item
+ if not no_reply:
+ no_reply = 'No %s' % yes_reply
+
+ yes_reply += '\n'
+ no_reply += '\n'
+
+ if enabled_by_default:
+ question += ' [Y/n]: '
+ else:
+ question += ' [y/N]: '
+
+ var = environ_cp.get(var_name)
+ while var is None:
+ user_input_origin = get_input(question)
+ user_input = user_input_origin.strip().lower()
+ if user_input == 'y':
+ print(yes_reply)
+ var = True
+ elif user_input == 'n':
+ print(no_reply)
+ var = False
+ elif not user_input:
+ if enabled_by_default:
+ print(yes_reply)
+ var = True
+ else:
+ print(no_reply)
+ var = False
+ else:
+ print('Invalid selection: %s' % user_input_origin)
+ return var
+
+
+def set_build_var(environ_cp, var_name, query_item, option_name,
+ enabled_by_default):
+ """Set if query_item will be enabled for the build.
+
+ Ask user if query_item will be enabled. Default is used if no input is given.
+ Set subprocess environment variable and write to .bazelrc if enabled.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
+ query_item: string for feature related to the variable, e.g. "Hadoop File
+ System".
+ option_name: string for option to define in .bazelrc.
+ enabled_by_default: boolean for default behavior.
+ """
+
+ var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
+ environ_cp[var_name] = var
+ if var == '1':
+ write_to_bazelrc('build --define %s=true' % option_name)
+
+
+def set_action_env_var(environ_cp,
+ var_name,
+ query_item,
+ enabled_by_default,
+ question=None,
+ yes_reply=None,
+ no_reply=None):
+ """Set boolean action_env variable.
+
+ Ask user if query_item will be enabled. Default is used if no input is given.
+ Set environment variable and write to .bazelrc.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
+ query_item: string for feature related to the variable, e.g. "Hadoop File
+ System".
+ enabled_by_default: boolean for default behavior.
+ question: optional string for how to ask for user input.
+    yes_reply: optional string for reply when feature is enabled.
+ no_reply: optional string for reply when feature is disabled.
+ """
+ var = int(
+ get_var(environ_cp, var_name, query_item, enabled_by_default, question,
+ yes_reply, no_reply))
+
+ write_action_env_to_bazelrc(var_name, var)
+ environ_cp[var_name] = str(var)
+
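Note (not part of the diff): the prompt helpers above (get_var, set_build_var, set_action_env_var) are presumably driven by a main() routine outside this excerpt. A hedged sketch of how they would be invoked, with the feature names and --define flags taken from the prompts in the old configure script:

    # Hypothetical call sites; the real main() is not shown in this excerpt.
    environ_cp = dict(os.environ)
    set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
                  'with_jemalloc', True)
    set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
                  'with_gcp_support', False)
    set_action_env_var(environ_cp, 'TF_NEED_OPENCL', 'OpenCL', False)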
+
+def check_bazel_version(min_version):
+ """Check installed bezel version is at least min_version.
+
+ Args:
+ min_version: string for minimum bazel version.
+ """
+ try:
+ curr_version = run_shell('bazel version')
+ except subprocess.CalledProcessError:
+ print('Cannot find bazel. Please install bazel.')
+ sys.exit(0)
+
+ for line in curr_version.split('\n'):
+ if 'Build label: ' in line:
+ curr_version = line.split('Build label: ')[1]
+ break
+
+ min_version_segments = min_version.split('.')
+ curr_version_segments = curr_version.split('.')
+
+ # Check if current bazel version can be detected properly.
+ for seg in curr_version_segments:
+ if not seg.isdigit():
+ print('WARNING: current bazel installation is not a release version.')
+ print('Make sure you are running at least bazel %s' % min_version)
+ return
+
+ min_version_str = ''.join(['%03d' % int(seg) for seg in min_version_segments])
+ curr_version_str = ''.join(
+ ['%03d' % int(seg) for seg in curr_version_segments])
+ if int(curr_version_str) < int(min_version_str):
+ print('Please upgrade your bazel installation to version %s or higher to '
+ 'build TensorFlow!' % min_version)
+ sys.exit(0)
+
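Note (not part of the diff): check_bazel_version compares versions by zero-padding every segment to three digits and comparing the concatenations numerically, the same scheme the deleted bash version() helper expressed with awk's "%03d%03d%03d". A small illustrative example:

    # '0.4.5' -> '000004005' and '0.5.2' -> '000005002', so 0.5.2 >= 0.4.5 holds.
    min_version_str = ''.join('%03d' % int(seg) for seg in '0.4.5'.split('.'))
    curr_version_str = ''.join('%03d' % int(seg) for seg in '0.5.2'.split('.'))
    assert int(curr_version_str) >= int(min_version_str)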
+
+def set_cc_opt_flags(environ_cp):
+ """Set up architecture-dependent optimization flags.
+
+  Also append CC optimization flags to bazel.rc.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ """
+ if is_ppc64le():
+ # gcc on ppc64le does not support -march, use mcpu instead
+ default_cc_opt_flags = '-mcpu=native'
+ else:
+ default_cc_opt_flags = '-march=native'
+ question = ('Please specify optimization flags to use during compilation when'
+ ' bazel option "--config=opt" is specified [Default is %s]: '
+ ) % default_cc_opt_flags
+ cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
+ question, default_cc_opt_flags)
+ for opt in cc_opt_flags.split():
+ write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt))
+
+
+def set_tf_cuda_clang(environ_cp):
+ """set TF_CUDA_CLANG action_env.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ """
+ question = 'Do you want to use clang as CUDA compiler?'
+ yes_reply = 'Clang will be used as CUDA compiler.'
+ no_reply = 'nvcc will be used as CUDA compiler.'
+ set_action_env_var(
+ environ_cp,
+ 'TF_CUDA_CLANG',
+ None,
+ False,
+ question=question,
+ yes_reply=yes_reply,
+ no_reply=no_reply)
+
+
+def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
+ var_default):
+ """Get var_name either from env, or user or default.
+
+ If var_name has been set as environment variable, use the preset value, else
+ ask for user input. If no input is provided, the default is used.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
+ ask_for_var: string for how to ask for user input.
+ var_default: default value string.
+
+ Returns:
+ string value for var_name
+ """
+ var = environ_cp.get(var_name)
+ if not var:
+ var = get_input(ask_for_var)
+ if not var:
+ var = var_default
+ return var
+
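Note (not part of the diff): the helper above establishes the resolution order used throughout this script — a preset environment variable wins, then the interactive answer, then the hard-coded default. A brief illustration with made-up values:

    # With the variable preset, no prompt is shown and the env value is returned.
    environ_cp = {'TF_CUDA_VERSION': '9.0'}
    version = get_from_env_or_user_or_default(
        environ_cp, 'TF_CUDA_VERSION', 'Which CUDA version? ', '8.0')
    assert version == '9.0'
    # With the variable unset and an empty answer, the default '8.0' is returned.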
+
+def set_clang_cuda_compiler_path(environ_cp):
+ """Set CLANG_CUDA_COMPILER_PATH."""
+ default_clang_path = run_shell('which clang || true')
+ ask_clang_path = ('Please specify which clang should be used as device and '
+ 'host compiler. [Default is %s]: ') % default_clang_path
+
+ while True:
+ clang_cuda_compiler_path = get_from_env_or_user_or_default(
+ environ_cp, 'CLANG_CUDA_COMPILER_PATH', ask_clang_path,
+ default_clang_path)
+ if os.path.exists(clang_cuda_compiler_path):
+ break
+
+ # Reset and retry
+ print('Invalid clang path: %s cannot be found.' % clang_cuda_compiler_path)
+ environ_cp['CLANG_CUDA_COMPILER_PATH'] = ''
+
+ # Set CLANG_CUDA_COMPILER_PATH
+ environ_cp['CLANG_CUDA_COMPILER_PATH'] = clang_cuda_compiler_path
+ write_action_env_to_bazelrc('CLANG_CUDA_COMPILER_PATH',
+ clang_cuda_compiler_path)
+
+
+def set_gcc_host_compiler_path(environ_cp):
+ """Set GCC_HOST_COMPILER_PATH."""
+ default_gcc_host_compiler_path = run_shell('which gcc || true')
+ cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH')
+
+ if os.path.islink(cuda_bin_symlink):
+ # os.readlink is only available in linux
+ default_gcc_host_compiler_path = run_shell('readlink %s' % cuda_bin_symlink)
+
+ ask_gcc_path = (
+ 'Please specify which gcc should be used by nvcc as the '
+ 'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path
+ while True:
+ gcc_host_compiler_path = get_from_env_or_user_or_default(
+ environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path,
+ default_gcc_host_compiler_path)
+
+ if os.path.exists(gcc_host_compiler_path):
+ break
+
+ # Reset and retry
+ print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path)
+ environ_cp['GCC_HOST_COMPILER_PATH'] = ''
+
+ # Set GCC_HOST_COMPILER_PATH
+ environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path
+ write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path)
+
+
+def set_tf_cuda_version(environ_cp):
+ """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
+ ask_cuda_version = (
+ 'Please specify the CUDA SDK version you want to use, '
+ 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
+
+ while True:
+ # Configure the Cuda SDK version to use.
+ tf_cuda_version = get_from_env_or_user_or_default(
+ environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
+
+ # Find out where the CUDA toolkit is installed
+ default_cuda_path = _DEFAULT_CUDA_PATH
+ if is_windows():
+ default_cuda_path = cygpath(
+ environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
+ elif is_linux():
+ # If the default doesn't exist, try an alternative default.
+ if (not os.path.exists(default_cuda_path)
+ ) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX):
+ default_cuda_path = _DEFAULT_CUDA_PATH_LINUX
+ ask_cuda_path = ('Please specify the location where CUDA %s toolkit is'
+ ' installed. Refer to README.md for more details. '
+ '[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
+ cuda_toolkit_path = get_from_env_or_user_or_default(
+ environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
+
+ if is_windows():
+ cuda_rt_lib_path = 'lib/x64/cudart.lib'
+ elif is_linux():
+ cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version
+ elif is_macos():
+ cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version
+
+ cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path)
+ if os.path.exists(cuda_toolkit_path_full):
+ break
+
+ # Reset and retry
+ print('Invalid path to CUDA %s toolkit. %s cannot be found' %
+ (tf_cuda_version, cuda_toolkit_path_full))
+ environ_cp['TF_CUDA_VERSION'] = ''
+ environ_cp['CUDA_TOOLKIT_PATH'] = ''
+
+ # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION
+ environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path
+ write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path)
+ environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
+ write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
+
+
+def set_tf_cunn_version(environ_cp):
+ """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
+ ask_cudnn_version = (
+ '"Please specify the cuDNN version you want to use. '
+ '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION
+
+ while True:
+ tf_cudnn_version = get_from_env_or_user_or_default(
+ environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
+ _DEFAULT_CUDNN_VERSION)
+
+ default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
+ ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
+ 'installed. Refer to README.md for more details. [Default'
+ ' is %s]:') % (tf_cudnn_version, default_cudnn_path)
+ cudnn_install_path = get_from_env_or_user_or_default(
+ environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
+
+    # The value entered by the user is used unexpanded; that makes "~" unusable.
+    # Go through one more level of expansion to handle that.
+ cudnn_install_path = os.path.realpath(
+ os.path.expanduser(cudnn_install_path))
+ if is_windows():
+ cudnn_install_path = cygpath(cudnn_install_path)
+
+ if is_windows():
+ cuda_dnn_lib_path = 'lib/x64/cudnn.lib'
+ cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib'
+ elif is_linux():
+ cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version
+ cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version
+ elif is_macos():
+ cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version
+ cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version
+
+ cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path)
+ cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path,
+ cuda_dnn_lib_alt_path)
+ if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists(
+ cuda_dnn_lib_alt_path_full):
+ break
+
+ # Try another alternative for Linux
+ if is_linux():
+ if subprocess.call(['which', 'ldconfig']):
+ ldconfig_bin = '/sbin/ldconfig'
+ else:
+ ldconfig_bin = 'ldconfig'
+ cudnn_path_from_ldconfig = run_shell(
+ r'%s -p | sed -n "s/.*libcudnn.so .* => \(.*\)/\\1/p"' % ldconfig_bin)
+ if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
+ cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
+ break
+
+ # Reset and Retry
+ print(
+ 'Invalid path to cuDNN %s toolkit. None of the following files can be '
+ 'found:' % tf_cudnn_version)
+ print(cuda_dnn_lib_path_full)
+ print(cuda_dnn_lib_alt_path_full)
+ if is_linux():
+ print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version))
+
+ environ_cp['TF_CUDNN_VERSION'] = ''
+
+ # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION
+ environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path
+ write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path)
+ environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
+ write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
+
+
+def get_native_cuda_compute_capabilities(environ_cp):
+ """Get native cuda compute capabilities.
+
+ Args:
+ environ_cp: copy of the os.environ.
+ Returns:
+ string of native cuda compute capabilities, separated by comma.
+ """
+ device_query_bin = os.path.join(
+ environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery')
+ cmd = (r'"%s" | grep "Capability" | grep -o "[0-9]*\.[0-9]*" | sed '
+ '":a;{N;s/\\n/,/};ba"') % device_query_bin
+ try:
+ output = run_shell(cmd)
+ except subprocess.CalledProcessError:
+ output = ''
+ return output
+
+
+def set_tf_cuda_compute_capabilities(environ_cp):
+ """Set TF_CUDA_COMPUTE_CAPABILITIES."""
+ while True:
+ native_cuda_compute_capabilities = get_native_cuda_compute_capabilities(
+ environ_cp)
+ if not native_cuda_compute_capabilities:
+ default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ else:
+ default_cuda_compute_capabilities = native_cuda_compute_capabilities
+
+ ask_cuda_compute_capabilities = (
+ 'Please specify a list of comma-separated '
+ 'Cuda compute capabilities you want to '
+ 'build with.\nYou can find the compute '
+ 'capability of your device at: '
+ 'https://developer.nvidia.com/cuda-gpus.\nPlease'
+ ' note that each additional compute '
+ 'capability significantly increases your '
+ 'build time and binary size. [Default is: %s]' %
+ default_cuda_compute_capabilities)
+ tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
+ environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
+ ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
+ # Check whether all capabilities from the input is valid
+ all_valid = True
+ for compute_capability in tf_cuda_compute_capabilities.split(','):
+ if not re.match('[0-9]+.[0-9]+', compute_capability):
+        print('Invalid compute capability: %s' % compute_capability)
+ all_valid = False
+
+ if all_valid:
+ break
+
+ # Reset and Retry
+ environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = ''
+
+ # Set TF_CUDA_COMPUTE_CAPABILITIES
+ environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities
+ write_action_env_to_bazelrc('TF_CUDA_COMPUTE_CAPABILITIES',
+ tf_cuda_compute_capabilities)
+
+
+def set_other_cuda_vars(environ_cp):
+ """Set other CUDA related variables."""
+ if is_windows():
+ # The following three variables are needed for MSVC toolchain configuration
+ # in Bazel
+ environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH')
+ environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get(
+ 'TF_CUDA_COMPUTE_CAPABILITIES')
+ environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1
+ write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH'))
+    write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIES',
+                                environ_cp.get('CUDA_COMPUTE_CAPABILITIES'))
+ write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION',
+ environ_cp.get('NO_WHOLE_ARCHIVE_OPTION'))
+ write_to_bazelrc('build --config=win-cuda')
+ write_to_bazelrc('test --config=win-cuda')
+ else:
+ # If CUDA is enabled, always use GPU during build and test.
+ if environ_cp.get('TF_CUDA_CLANG') == '1':
+ write_to_bazelrc('build --config=cuda_clang')
+ write_to_bazelrc('test --config=cuda_clang')
+ else:
+ write_to_bazelrc('build --config=cuda')
+ write_to_bazelrc('test --config=cuda')
+
+
+def set_host_cxx_compiler(environ_cp):
+ """Set HOST_CXX_COMPILER."""
+ default_cxx_host_compiler = run_shell('which g++ || true')
+ ask_cxx_host_compiler = (
+ 'Please specify which C++ compiler should be used as'
+ ' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler
+
+ while True:
+ host_cxx_compiler = get_from_env_or_user_or_default(
+ environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler,
+ default_cxx_host_compiler)
+ if os.path.exists(host_cxx_compiler):
+ break
+
+ # Reset and retry
+ print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler)
+ environ_cp['HOST_CXX_COMPILER'] = ''
+
+ # Set HOST_CXX_COMPILER
+ environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler
+ write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler)
+
+
+def set_host_c_compiler(environ_cp):
+ """Set HOST_C_COMPILER."""
+ default_c_host_compiler = run_shell('which gcc || true')
+ ask_c_host_compiler = (
+ 'Please specify which C compiler should be used as the'
+ ' host C compiler. [Default is %s]: ') % default_c_host_compiler
+
+ while True:
+ host_c_compiler = get_from_env_or_user_or_default(
+ environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler,
+ default_c_host_compiler)
+ if os.path.exists(host_c_compiler):
+ break
+
+ # Reset and retry
+ print('Invalid C compiler path. %s cannot be found' % host_c_compiler)
+ environ_cp['HOST_C_COMPILER'] = ''
+
+ # Set HOST_C_COMPILER
+ environ_cp['HOST_C_COMPILER'] = host_c_compiler
+ write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
+
+
+def set_computecpp_toolkit_path(environ_cp):
+ """Set COMPUTECPP_TOOLKIT_PATH."""
+ ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp '
+ 'for SYCL %s is installed. [Default is %s]: '
+ ) % (_TF_OPENCL_VERSION,
+ _DEFAULT_COMPUTECPP_TOOLKIT_PATH)
+
+ while True:
+ computecpp_toolkit_path = get_from_env_or_user_or_default(
+ environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path,
+ _DEFAULT_COMPUTECPP_TOOLKIT_PATH)
+ if is_linux():
+ sycl_rt_lib_path = 'lib/libComputeCpp.so'
+ else:
+ sycl_rt_lib_path = ''
+
+ sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path,
+ sycl_rt_lib_path)
+ if os.path.exists(sycl_rt_lib_path_full):
+ break
+
+ print('Invalid SYCL %s library path. %s cannot be found' %
+ (_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
+ environ_cp['COMPUTECPP_TOOLKIT_PATH'] = ''
+
+ # Set COMPUTECPP_TOOLKIT_PATH
+ environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path
+ write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
+ computecpp_toolkit_path)
+
+
+def set_mpi_home(environ_cp):
+ """Set MPI_HOME."""
+ cmd = ('dirname $(dirname $(which mpirun)) || dirname $(dirname $(which '
+ 'mpiexec)) || true')
+ default_mpi_home = run_shell(cmd)
+ ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: '
+ ) % default_mpi_home
+ while True:
+ mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME',
+ ask_mpi_home, default_mpi_home)
+
+ if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists(
+ os.path.join(mpi_home, 'lib')):
+ break
+
+ print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
+ (os.path.join(mpi_home, 'include'),
+ os.path.join(mpi_home, 'lib')))
+ environ_cp['MPI_HOME'] = ''
+
+ # Set MPI_HOME
+ environ_cp['MPI_HOME'] = str(mpi_home)
+
+
+def set_other_mpi_vars(environ_cp):
+ """Set other MPI related variables."""
+ # Link the MPI header files
+ mpi_home = environ_cp.get('MPI_HOME')
+ symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h')
+
+ # Determine whether we use OpenMPI or MVAPICH; these require different header
+ # files to be included here to keep the bazel dependency checker happy.
+ if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')):
+ symlink_force(
+ os.path.join(mpi_home, 'include/mpi_portable_platform.h'),
+ 'third_party/mpi/mpi_portable_platform.h')
+ # TODO(gunan): avoid editing files in configure
+ sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=False',
+ 'MPI_LIB_IS_OPENMPI=True')
+ else:
+ # MVAPICH / MPICH
+ symlink_force(
+ os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h')
+ symlink_force(
+ os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h')
+ # TODO(gunan): avoid editing files in configure
+ sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=True',
+ 'MPI_LIB_IS_OPENMPI=False')
+
+ if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
+ symlink_force(
+ os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
+ else:
+ raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
+
+
+def set_mkl():
+ write_to_bazelrc('build:mkl --define with_mkl_support=true')
+ write_to_bazelrc('build:mkl --define using_mkl=true')
+ write_to_bazelrc('build:mkl -c opt')
+ write_to_bazelrc('build:mkl --copt="-DEIGEN_USE_VML"')
+ print(
+ 'Add "--config=mkl" to your bazel command to build with MKL '
+ 'support.\nPlease note that MKL on macOS or Windows is still not '
+ 'supported.\nIf you would like to use a local MKL instead of '
+ 'downloading, please set the environment variable \"TF_MKL_ROOT\" '
+ 'before every build.')
+
+
+def main():
+ # Make a copy of os.environ so it is clear when functions are getting and
+ # setting environment variables.
+ environ_cp = dict(os.environ)
+
+ check_bazel_version('0.4.5')
+
+ reset_tf_configure_bazelrc()
+ cleanup_makefile()
+ setup_python(environ_cp)
+ run_gen_git_source(environ_cp)
+
+ if is_windows():
+ environ_cp['TF_NEED_GCP'] = '0'
+ environ_cp['TF_NEED_HDFS'] = '0'
+ environ_cp['TF_NEED_JEMALLOC'] = '0'
+ environ_cp['TF_NEED_OPENCL'] = '0'
+ environ_cp['TF_CUDA_CLANG'] = '0'
+
+ if is_macos():
+ environ_cp['TF_NEED_JEMALLOC'] = '0'
+
+ set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
+ 'with_jemalloc', True)
+ set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
+ 'with_gcp_support', False)
+ set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
+ 'with_hdfs_support', False)
+ set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
+ False)
+ set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
+ False)
+
+ set_action_env_var(environ_cp, 'TF_NEED_OPENCL', 'OpenCL', False)
+ if environ_cp.get('TF_NEED_OPENCL') == '1':
+ set_host_cxx_compiler(environ_cp)
+ set_host_c_compiler(environ_cp)
+ set_computecpp_toolkit_path(environ_cp)
+
+ set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
+ if environ_cp.get('TF_NEED_CUDA') == '1':
+ set_tf_cuda_version(environ_cp)
+ set_tf_cunn_version(environ_cp)
+ set_tf_cuda_compute_capabilities(environ_cp)
+
+ set_tf_cuda_clang(environ_cp)
+ if environ_cp.get('TF_CUDA_CLANG') == '1':
+ # Set up which clang we should use as the cuda / host compiler.
+ set_clang_cuda_compiler_path(environ_cp)
+ else:
+ # Set up which gcc nvcc should use as the host compiler
+ # No need to set this on Windows
+ if not is_windows():
+ set_gcc_host_compiler_path(environ_cp)
+ set_other_cuda_vars(environ_cp)
+
+ set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
+ if environ_cp.get('TF_NEED_MPI') == '1':
+ set_mpi_home(environ_cp)
+ set_other_mpi_vars(environ_cp)
+
+ set_cc_opt_flags(environ_cp)
+ set_mkl()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 3e69134c50..371264ef6c 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -56,21 +56,16 @@ limitations under the License.
// The implementation below is at the top level instead of the
// brain namespace because we are defining 'extern "C"' functions.
-using tensorflow::error::Code;
-using tensorflow::errors::InvalidArgument;
-using tensorflow::gtl::ArraySlice;
-using tensorflow::strings::StrCat;
using tensorflow::AllocationDescription;
using tensorflow::DataType;
using tensorflow::Graph;
using tensorflow::GraphDef;
-using tensorflow::mutex_lock;
using tensorflow::NameRangeMap;
using tensorflow::NameRangesForNode;
using tensorflow::NewSession;
using tensorflow::Node;
-using tensorflow::NodeDef;
using tensorflow::NodeBuilder;
+using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
using tensorflow::PartialTensorShape;
@@ -83,6 +78,11 @@ using tensorflow::TensorBuffer;
using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
+using tensorflow::error::Code;
+using tensorflow::errors::InvalidArgument;
+using tensorflow::gtl::ArraySlice;
+using tensorflow::mutex_lock;
+using tensorflow::strings::StrCat;
extern "C" {
@@ -258,24 +258,27 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
return sz;
}
-size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
- size_t* dst_len, TF_Status* status) {
+static Status TF_StringDecode_Impl(const char* src, size_t src_len,
+ const char** dst, size_t* dst_len) {
tensorflow::uint64 len64 = 0;
const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
if (p == nullptr) {
- status->status =
- InvalidArgument("invalid string encoding or truncated src buffer");
- return 0;
+ return InvalidArgument("invalid string encoding or truncated src buffer");
}
if (len64 > std::numeric_limits<size_t>::max()) {
- status->status =
- InvalidArgument("encoded string is ", len64,
- "-bytes, which is too large for this architecture");
- return 0;
+ return InvalidArgument("encoded string is ", len64,
+ "-bytes, which is too large for this architecture");
}
*dst = p;
*dst_len = static_cast<size_t>(len64);
- return static_cast<size_t>(p - src) + *dst_len;
+ return Status::OK();
+}
+
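+// Returns the number of bytes consumed from src: the size of the varint length
+// prefix plus the length of the decoded string.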
+size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
+ size_t* dst_len, TF_Status* status) {
+ status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
+ if (!status->status.ok()) return 0;
+ return static_cast<size_t>(*dst - src) + *dst_len;
}
size_t TF_StringEncodedSize(size_t len) {
@@ -391,16 +394,20 @@ void TF_Reset(const TF_SessionOptions* opt, const char** containers,
namespace tensorflow {
-// Non-static for testing.
-bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) {
+Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
+ if (src->dtype != TF_STRING) {
+ *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
+ return Status::OK();
+ }
+ // TF_STRING tensors require copying since Tensor class expects a sequence of
+ // string objects.
const tensorflow::int64 num_elements = src->shape.num_elements();
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
const size_t src_size = TF_TensorByteSize(src);
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
- status->status = InvalidArgument(
+ return InvalidArgument(
"Malformed TF_STRING tensor; too short to hold number of elements");
- return false;
}
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
@@ -411,24 +418,30 @@ bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) {
tensorflow::uint64 offset =
reinterpret_cast<const tensorflow::uint64*>(input)[i];
if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
- status->status = InvalidArgument("Malformed TF_STRING tensor; element ",
- i, " out of range");
- return false;
+ return InvalidArgument("Malformed TF_STRING tensor; element ", i,
+ " out of range");
}
size_t len;
const char* p;
const char* srcp = data_start + offset;
- TF_StringDecode(srcp, limit - srcp, &p, &len, status);
- if (!status->status.ok()) {
- return false;
- }
+ Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
+ if (!status.ok()) return status;
dstarray(i).assign(p, len);
}
- return true;
+ return Status::OK();
}
// Non-static for testing.
-TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) {
+TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) {
+ if (src.dtype() != DT_STRING) {
+ TensorBuffer* buf = TensorCApi::Buffer(src);
+ buf->Ref();
+ return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
+ buf};
+ }
+ // DT_STRING tensors require copying since TF_Tensor.buffer expects a flatly
+ // encoded sequence of strings.
+
// Compute bytes needed for encoding.
size_t size = 0;
const auto& srcarray = src.flat<tensorflow::string>();
@@ -507,16 +520,8 @@ static bool TF_Run_Inputs(
TF_Status* status) {
const int ninputs = input_pairs->size();
for (int i = 0; i < ninputs; ++i) {
- TF_Tensor* src = c_inputs[i];
- if (c_inputs[i]->dtype != TF_STRING) {
- (*input_pairs)[i].second = tensorflow::TensorCApi::MakeTensor(
- src->dtype, src->shape, src->buffer);
- } else if (!tensorflow::TF_Tensor_DecodeStrings(
- src, &(*input_pairs)[i].second, status)) {
- // TF_STRING tensors require copying since Tensor class expects
- // a sequence of string objects.
- return false;
- }
+ status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
+ if (!status->status.ok()) return false;
}
return true;
}
@@ -574,15 +579,7 @@ static void TF_Run_Helper(
static_cast<TF_DataType>(src.dtype()), src.shape());
continue;
}
- if (src.dtype() != tensorflow::DT_STRING) {
- // Share the underlying buffer.
- TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src);
- buf->Ref();
- c_outputs[i] = new TF_Tensor{static_cast<TF_DataType>(src.dtype()),
- src.shape(), buf};
- } else {
- c_outputs[i] = tensorflow::TF_Tensor_EncodeStrings(src);
- }
+ c_outputs[i] = TF_TensorFromTensor(src);
}
}
@@ -1062,20 +1059,9 @@ void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
TF_Tensor* value, TF_Status* status) {
- status->status = Status::OK();
Tensor t;
- bool ok = true;
-
- if (value->dtype != TF_STRING) {
- t = tensorflow::TensorCApi::MakeTensor(value->dtype, value->shape,
- value->buffer);
- } else {
- // TF_STRING tensors require copying since Tensor class expects
- // a sequence of string objects.
- ok = tensorflow::TF_Tensor_DecodeStrings(value, &t, status);
- }
-
- if (ok) desc->node_builder.Attr(attr_name, t);
+ status->status = TF_TensorToTensor(value, &t);
+ if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
@@ -1084,21 +1070,14 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
status->status = Status::OK();
std::vector<Tensor> t;
t.reserve(num_values);
- bool ok = true;
- for (int i = 0; i < num_values && ok; ++i) {
- if (values[i]->dtype != TF_STRING) {
- t.emplace_back(tensorflow::TensorCApi::MakeTensor(
- values[i]->dtype, values[i]->shape, values[i]->buffer));
- } else {
- t.emplace_back(::tensorflow::DT_STRING);
- // TF_STRING tensors require copying since Tensor class expects
- // a sequence of string objects.
- ok = tensorflow::TF_Tensor_DecodeStrings(values[i], &t.back(), status);
- }
+ for (int i = 0; i < num_values && status->status.ok(); ++i) {
+ Tensor v;
+ status->status = TF_TensorToTensor(values[i], &v);
+ t.emplace_back(v);
}
- if (ok) desc->node_builder.Attr(attr_name, t);
+ if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
@@ -1555,9 +1534,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (!status->status.ok()) return;
- *value = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(),
- tensorflow::TensorCApi::Buffer(t)};
- (*value)->buffer->Ref();
+ *value = TF_TensorFromTensor(t);
}
void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@@ -1568,10 +1545,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
- const Tensor& t = ts[i];
- values[i] = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(),
- tensorflow::TensorCApi::Buffer(t)};
- values[i]->buffer->Ref();
+ values[i] = TF_TensorFromTensor(ts[i]);
}
}
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index d6debe3b99..25b6cbd8e7 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -45,9 +45,8 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
-
-bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status);
-TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src);
+Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
@@ -146,19 +145,16 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
src.flat<string>()(i) = data[i];
}
- TF_Tensor* dst = TF_Tensor_EncodeStrings(src);
+ TF_Tensor* dst = TF_TensorFromTensor(src);
// Convert back to a C++ Tensor and ensure we get expected output.
- TF_Status* status = TF_NewStatus();
Tensor output;
- ASSERT_TRUE(TF_Tensor_DecodeStrings(dst, &output, status)) << line;
- ASSERT_EQ(TF_OK, TF_GetCode(status)) << line;
+ ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line;
ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
ASSERT_EQ(data[i], output.flat<string>()(i)) << line;
}
- TF_DeleteStatus(status);
TF_DeleteTensor(dst);
}
}
@@ -918,7 +914,7 @@ TEST(CAPI, SavedModel) {
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
- csession.SetInputs({{input_op, TF_Tensor_EncodeStrings(input)}});
+ csession.SetInputs({{input_op, TF_TensorFromTensor(input)}});
const tensorflow::string output_op_name =
tensorflow::ParseTensorName(output_name).first.ToString();
@@ -1636,6 +1632,39 @@ TEST_F(CApiAttributesTest, Tensor) {
TF_DeleteTensor(value);
}
+TEST_F(CApiAttributesTest, StringTensor) {
+ // Create the string-Tensor "attribute" value.
+ char encoded[] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, // array[uint64] offsets
+ 1, // varint encoded string length
+ 'A',
+ };
+ auto deallocator = [](void* data, size_t len, void* arg) {};
+ unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0],
+ sizeof(encoded), deallocator, nullptr),
+ TF_DeleteTensor);
+
+ // Create a TF_Operation with the attribute t_in
+ auto desc = init("tensor");
+ TF_SetAttrTensor(desc, "v", t_in.get(), s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ auto oper = TF_FinishOperation(desc, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Fetch the attribute back.
+ EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
+ TF_Tensor* t_out = nullptr;
+ TF_OperationGetAttrTensor(oper, "v", &t_out, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
+ EXPECT_EQ(0, TF_NumDims(t_out));
+ ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
+ EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out),
+ TF_TensorByteSize(t_out)));
+ TF_DeleteTensor(t_out);
+}
+
TEST_F(CApiAttributesTest, TensorList) {
const char tensor1[] = {5, 7};
const int64_t dims1[] = {1, 2};
@@ -1647,7 +1676,8 @@ TEST_F(CApiAttributesTest, TensorList) {
auto desc = init("list(tensor)");
TF_Tensor* tmp[] = {
- Int8Tensor(dims1, ndims1, tensor1), Int8Tensor(dims2, ndims2, tensor2),
+ Int8Tensor(dims1, ndims1, tensor1),
+ Int8Tensor(dims2, ndims2, tensor2),
};
TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_);
for (int i = 0; i < TF_ARRAYSIZE(tmp); ++i) {
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index cec3ebc0ad..66a943410e 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -356,7 +356,7 @@ Status SymbolicGradientBuilder::AddGradients() {
// Check if any input nodes still have pending gradients and have not been
// processed yet. This happens if not all outputs of a node are in 'inputs_'.
std::unordered_map<Node*, int> requested_grads;
- for (Output nout : inputs_) {
+ for (const Output& nout : inputs_) {
if (pending_[nout.node()->id()] > 0) {
DCHECK_GT(nout.node()->num_outputs(), 1);
int idx = input_nodes_[nout];
@@ -365,7 +365,7 @@ Status SymbolicGradientBuilder::AddGradients() {
++requested_grads[nout.node()];
}
}
- for (auto& p : requested_grads) {
+ for (const auto& p : requested_grads) {
int num_requested_inputs = p.first->num_outputs() - pending_[p.first->id()];
CHECK_EQ(num_requested_inputs, p.second);
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 7eab7bb28f..77b45aa11e 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -257,6 +257,11 @@ Status MarkForCompilationPass::Run(
&registration)) {
return false;
}
+
+ // Don't compile control trigger nodes. We won't preserve their deadness
+ // semantics correctly, so it's safest not to compile them.
+ if (node->IsControlTrigger()) return false;
+
// If this device requires a JIT, we must say yes.
if (registration->requires_compilation) return true;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 4f0137e8d9..c693f58f8b 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -354,6 +354,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "segment_reduction_ops_test",
+ size = "small",
+ srcs = ["segment_reduction_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "spacetobatch_op_test",
size = "medium",
srcs = ["spacetobatch_op_test.py"],
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
new file mode 100644
index 0000000000..260a04421b
--- /dev/null
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for segment reduction ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class SegmentReductionOpsTest(XLATestCase):
+ """Test cases for segment reduction ops."""
+
+ def UnsortedSegmentSum(self, data, indices, num_segments):
+ with self.test_session() as sess, self.test_scope():
+ d = array_ops.placeholder(data.dtype, shape=data.shape)
+ if isinstance(indices, int):
+ i = array_ops.placeholder(np.int32, shape=[])
+ else:
+ i = array_ops.placeholder(indices.dtype, shape=indices.shape)
+ return sess.run(
+ math_ops.unsorted_segment_sum(d, i, num_segments),
+ {d: data,
+ i: indices})
+
+ def testUnsortedSegmentSum0DIndices1DData(self):
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array(
+ [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5],
+ [0, 0, 0, 0, 0, 0]],
+ dtype=dtype),
+ self.UnsortedSegmentSum(
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4))
+
+ def testUnsortedSegmentSum1DIndices1DData(self):
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array([1, 3, 2, 9], dtype=dtype),
+ self.UnsortedSegmentSum(
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))
+
+ def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
+ [50, 51, 52, 53]],
+ dtype=dtype)
+ indices = np.array([8, 1, 0, 3, 7], dtype=np.int32)
+ num_segments = 10
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0],
+ [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
+ [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
+ [50, 51, 52, 53]],
+ dtype=dtype)
+ indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
+ num_segments = 4
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33],
+ [0, 0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum2DIndices3DData(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
+ [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
+ [310, 311, 312]]],
+ dtype=dtype)
+ indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
+ num_segments = 8
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[210, 211, 212], [110, 111, 112], [310, 311, 312],
+ [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301,
+ 302], [0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum1DIndices3DData(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
+ [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
+ [310, 311, 312]]],
+ dtype=dtype)
+ indices = np.array([3, 0, 2, 5], dtype=np.int32)
+ num_segments = 6
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]],
+ [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]],
+ [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSumShapeError(self):
+ for dtype in self.numeric_types:
+ data = np.ones((4, 8, 7), dtype=dtype)
+ indices = np.ones((3, 2), dtype=np.int32)
+ num_segments = 4
+ self.assertRaises(ValueError,
+ functools.partial(self.UnsortedSegmentSum, data,
+ indices, num_segments))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index f277314352..ac039e0162 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -57,11 +57,13 @@ class TensorArrayTest(xla_test.XLATestCase):
r0 = w2.read(0)
r1 = w2.read(1)
r2 = w2.read(2)
+ flow = w2.flow
- d0, d1, d2 = session.run([r0, r1, r2])
+ d0, d1, d2, flow_val = session.run([r0, r1, r2, flow])
self.assertAllEqual([[4.0, 5.0]], d0)
self.assertAllEqual([[1.0, 3.0]], d1)
self.assertAllEqual([[7.0, -8.5]], d2)
+ self.assertAllEqual([], flow_val.shape)
def _testTensorArrayWritePack(self, tf_dtype):
with self.test_session(), self.test_scope():
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index faa88ecfe2..1c7a2046aa 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -323,12 +323,26 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
for (Arg& arg : frame->args) {
if (!arg.is_loop_invariant) {
// Follow the edge from the Enter to Merge.
- if (arg.enter->out_edges().size() != 1) {
+ const Edge* enter_merge = nullptr;
+ for (const Edge* e : arg.enter->out_edges()) {
+ // Ignore control-edges to the sink node. These are allowed by the
+ // graph invariants, although probably they should have been stripped
+ // off earlier.
+ if (e->IsControlEdge() && e->dst()->IsSink()) {
+ continue;
+ }
+ if (enter_merge != nullptr) {
+ return errors::Internal(
+ "Enter node for loop-varying argument ", arg.enter->name(),
+ " has multiple successors: ", enter_merge->dst()->name(), " and ",
+ e->dst()->name());
+ }
+ enter_merge = e;
+ }
+ if (enter_merge == nullptr) {
return errors::Internal("Enter node for loop-varying argument ",
- arg.enter->name(),
- " does not have exactly one successor");
+ arg.enter->name(), " has zero successors");
}
- const Edge* enter_merge = *arg.enter->out_edges().begin();
arg.merge = enter_merge->dst();
if (!IsMerge(arg.merge)) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index 2fb1cc0454..914c8999a6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -96,6 +96,14 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
TF_EXPECT_OK(scope.ToGraph(&graph));
}
+ // Regression test: control edges from an Enter node to the graph sink should
+ // be ignored.
+ for (Node* n : graph.nodes()) {
+ if (n->name() == "while/Enter") {
+ graph.AddControlEdge(n, graph.sink_node());
+ }
+ }
+
FunctionLibraryDefinition library(OpRegistry::Global(), {});
TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 35bc6b5a24..546e9be864 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -47,6 +47,7 @@ tf_kernel_library(
"reshape_op.cc",
"retval_op.cc",
"reverse_op.cc",
+ "segment_reduction_ops.cc",
"select_op.cc",
"sequence_ops.cc",
"shape_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc
index b8f0c0b9fe..8c8a9bbe78 100644
--- a/tensorflow/compiler/tf2xla/kernels/no_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc
@@ -23,4 +23,9 @@ namespace tensorflow {
// dummy operator using CompilationOnly().
REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp);
+// We register ControlTrigger as a no-op. This is correct since nodes seen
+// by the XLA compiler are never dead. This may need rethinking when we add
+// support for conditionals to XLA.
+REGISTER_XLA_OP(Name("ControlTrigger"), NoOp);
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
new file mode 100644
index 0000000000..6a0ce775dc
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sstream>
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+namespace {
+
+class UnsortedSegmentSum : public XlaOpKernel {
+ public:
+ explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // output = unsorted_segment_sum(data, indices, num_segments)
+ // Compute a tensor such that:
+ // output[i] = sum over {j where indices[j] == i} of data[j]
+ // output[i] == 0 if i does not appear in indices
+ //
+ // Contrast with segment_sum(), which assumes indices are sorted and that
+ // max(indices)+1 is the desired size of the output.
+ //
+ // The returned output tensor has the same type as data, and the same shape
+ // as data, with the first indices.rank dimensions replaced by a single
+ // dimension of size num_segments.
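+ // For example, data = [1, 2, 3, 4] with indices = [0, 2, 0, 2] and
+ // num_segments = 4 produces output = [1 + 3, 0, 2 + 4, 0] = [4, 0, 6, 0].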
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ auto data = ctx->Input(0);
+ auto data_shape = ctx->InputShape(0);
+
+ auto indices = ctx->Input(1);
+ auto indices_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
+ errors::InvalidArgument(
+ "UnsortedSegmentSum requires that indices' rank be"
+ " less than or equal to data's rank."));
+ // Validate that indices.shape is a prefix of data.shape.
+ for (int d = 0; d < indices_shape.dims(); ++d) {
+ OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
+ errors::InvalidArgument(
+ "UnsortedSegmentSum requires indices shape to be prefix"
+ " of data_shape, but dimension ",
+ d, " differs ", data_shape.dim_size(d), " vs. ",
+ indices_shape.dim_size(d)));
+ }
+
+ int64 num_segments;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
+
+ // Flatten the indices into 1-D.
+ auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()});
+
+ // flatten data for dynamic indexing.
+ int64 out_tensor_dims = data_shape.dims() - indices_shape.dims();
+ std::vector<int64> flat_shape(1 + out_tensor_dims);
+ flat_shape[0] = indices_shape.num_elements();
+ for (int64 k = 0; k < out_tensor_dims; ++k) {
+ flat_shape[1 + k] = data_shape.dim_size(indices_shape.dims() + k);
+ }
+ auto data_flat = builder->Reshape(data, flat_shape);
+
+ // output shape; same as data_shape, but dimension 0 is num_segments.
+ std::vector<int64> out_shape(flat_shape);
+ out_shape[0] = num_segments;
+
+ // Pad the output array dims to rank >= 3 to work around lowering issues.
+ // TODO(b/37575001) This is awkward, and could be improved.
+ int64 extra_dims = 0;
+ if (out_shape.size() < 3) {
+ extra_dims = 3u - out_shape.size();
+ }
+ std::vector<int64> rshape(extra_dims + out_shape.size(), 1);
+ for (unsigned k = 0; k < out_shape.size(); ++k) {
+ rshape[extra_dims + k] = out_shape[k];
+ }
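+ // e.g. an out_shape of {6} becomes rshape {1, 1, 6}.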
+ auto output = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), rshape);
+
+ auto zero = builder->ConstantR1<int32>({0});
+
+ for (int64 i = 0; i < indices_shape.num_elements(); ++i) {
+ // output[indices[i]] += data[i]
+
+ std::vector<int64> data_start_indices(flat_shape.size());
+ data_start_indices[0] = i;
+ for (unsigned d = 1; d < flat_shape.size(); ++d) {
+ data_start_indices[d] = 0;
+ }
+ std::vector<int64> data_limit_indices(flat_shape);
+ data_limit_indices[0] = i + 1;
+ std::vector<int64> stride(flat_shape.size(), 1);
+
+ auto data_slice = builder->Slice(data_flat, data_start_indices,
+ data_limit_indices, stride);
+
+ // Reshape the sliced data into the R3+ shape to match output array.
+ std::vector<int64> rdata_shape(extra_dims + flat_shape.size());
+ for (int64 k = 0; k <= extra_dims; ++k) {
+ rdata_shape[k] = 1;
+ }
+ for (unsigned k = 1; k < data_limit_indices.size(); ++k) {
+ rdata_shape[extra_dims + k] = data_limit_indices[k];
+ }
+ auto rdata_slice = builder->Reshape(data_slice, rdata_shape);
+
+ auto index = builder->Slice(indices_1d, {i}, {i + 1}, {1});
+
+ // Construct the index into the R3+ output array 0, ..., <index>, 0, ...
+ std::vector<xla::ComputationDataHandle> out_start_index_parts(
+ extra_dims + flat_shape.size(), zero);
+ out_start_index_parts[extra_dims] = builder->Reshape(index, {1});
+ auto out_start_indices = builder->ConcatInDim(out_start_index_parts, 0);
+
+ std::vector<int64> slice_size(rshape);
+ slice_size[extra_dims] = 1;
+
+ auto out_slice =
+ builder->DynamicSlice(output, out_start_indices, slice_size);
+ auto sumval = builder->Add(out_slice, rdata_slice);
+ output = builder->DynamicUpdateSlice(output, sumval, out_start_indices);
+ }
+ auto reshaped_output = builder->Reshape(output, out_shape);
+ ctx->SetOutput(0, reshaped_output);
+ }
+
+ private:
+ DataType dtype_;
+};
+
+REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index bdd52b7f8e..34cc8b2315 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -182,7 +182,10 @@ class TensorArrayOp : public XlaOpKernel {
dtype_, value, &var));
var->tensor_array_size = size;
ctx->SetResourceOutput(0, var);
- ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+
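+ // Emit a well-formed scalar as the flow output (Tensor(DT_FLOAT) alone is a
+ // 1-D, 0-element tensor rather than a scalar).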
+ Tensor flow(DT_FLOAT, TensorShape({}));
+ flow.scalar<float>()() = 0.0f;
+ ctx->SetConstantOutput(1, flow);
}
private:
@@ -216,6 +219,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
xla::ComputationDataHandle value = ctx->Input(2);
+ xla::ComputationDataHandle flow = ctx->Input(3);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims());
@@ -228,7 +232,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
resource->value = written;
- ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ ctx->SetOutput(0, flow);
}
private:
@@ -369,6 +373,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
xla::ComputationDataHandle ta = resource->value;
const xla::ComputationDataHandle value = ctx->Input(2);
+ const xla::ComputationDataHandle flow = ctx->Input(3);
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
@@ -394,7 +399,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
resource->value = ta;
- ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ ctx->SetOutput(0, flow);
}
private:
@@ -489,6 +494,7 @@ class TensorArraySplitOp : public XlaOpKernel {
lengths.size(), " vs. ", resource->tensor_array_size, ")"));
const xla::ComputationDataHandle value = ctx->Input(1);
+ const xla::ComputationDataHandle flow = ctx->Input(3);
OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
errors::InvalidArgument("mismatched element count ",
@@ -497,7 +503,7 @@ class TensorArraySplitOp : public XlaOpKernel {
resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
- ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ ctx->SetOutput(0, flow);
}
private:
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 5eef45b11d..e0a03a78f1 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -132,7 +132,10 @@ cc_library(
cc_library(
name = "statusor",
srcs = ["statusor.cc"],
- hdrs = ["statusor.h"],
+ hdrs = [
+ "statusor.h",
+ "statusor_internals.h",
+ ],
visibility = ["//visibility:public"],
deps = [
":status",
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4c2705b007..bd040c166d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1943,6 +1943,7 @@ cc_library(
":buffer_liveness",
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index b351861425..4837402c15 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1488,9 +1488,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// We cannot insert bitcasts if the layouts will not be compatible.
// TODO(b/33178038): Consider inserting a transpose if a bitcast would be
// invalid.
- if (!valid_bitcast_callback_(lhs->shape(), input_shape) ||
- !valid_bitcast_callback_(rhs->shape(), new_filter_shape) ||
- !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
+ if (!valid_bitcast_callback_(input_shape, lhs->shape()) ||
+ !valid_bitcast_callback_(new_filter_shape, rhs->shape()) ||
+ !valid_bitcast_callback_(convolution_shape, dot_output_shape)) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index f8919f0caa..4295a3227a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -26,12 +26,13 @@ namespace xla {
// A pass which performs AlgebraicSimplications.
class AlgebraicSimplifier : public HloPassInterface {
public:
- // Given two shapes, determines if it is valid to bitcast between them after
- // considering platform dependent effects on layout like alignment
- // restrictions.
- // Precondition: the two shapes have layouts, the same number of
- // elements and ShapeUtil::ReshapeIsBitcast returns true.
- using ValidBitcastCallback = std::function<bool(const Shape&, const Shape&)>;
+ // Given shapes 'from_shape' and 'to_shape', determines if it is valid to
+ // bitcast from 'from_shape' to 'to_shape' after considering platform
+ // dependent effects on layout like alignment restrictions. Precondition: the
+ // two shapes have layouts, the same number of elements and
+ // ShapeUtil::ReshapeIsBitcast returns true.
+ using ValidBitcastCallback =
+ std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored. If valid_bitcast_callback
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 7248cb5f4c..2ca4af67cd 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:inliner",
+ "//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6d819355c4..b86342d0b3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -74,6 +74,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -253,6 +254,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
HloPassPipeline pipeline("CPU");
pipeline.AddInvariantChecker<HloVerifier>();
+ for (const auto& reduce_precision_options :
+ module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::BEFORE_OP_FUSION) {
+ pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
+ }
+ }
+
// TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
// where we will take this pass in future.
// pipeline.AddPass<Inliner>();
@@ -278,6 +287,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
TransposeFolding::NeverFoldTranspose);
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
+
+ for (const auto& reduce_precision_options :
+ module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::AFTER_OP_FUSION) {
+ pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
+ }
+ }
+
pipeline.AddPass<CpuLayoutAssignment>(
module->mutable_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index fa95e23499..cdd7c8187c 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -432,6 +432,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto_util",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index d60c45a5c3..2acf95084a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -56,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -123,6 +124,15 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
+
+ for (const auto& reduce_precision_options :
+ hlo_module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::BEFORE_OP_FUSION) {
+ pipeline.AddPass<ReducePrecisionInsertion>(reduce_precision_options);
+ }
+ }
+
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@@ -149,8 +159,27 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
- return fusion.Run(hlo_module).status();
+ TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
+
+ HloPassPipeline reduce_pipeline("reduce-precision");
+ for (const auto& reduce_precision_options :
+ hlo_module->config().debug_options().hlo_reduce_precision_options()) {
+ if (reduce_precision_options.pass_timing() ==
+ HloReducePrecisionOptions::AFTER_OP_FUSION) {
+ reduce_pipeline.AddPass<ReducePrecisionInsertion>(
+ reduce_precision_options);
+ }
+ }
+ StatusOr<bool> reduce_result = reduce_pipeline.Run(hlo_module);
+ TF_RETURN_IF_ERROR(reduce_result.status());
+
+ if (reduce_result.ValueOrDie()) {
+ // Do another fusion pass, with the expectation that we may be able to
+ // fuse the new ReducePrecision operations.
+ TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
+ }
}
+ return tensorflow::Status::OK();
}
// Modifies the given HLO module so that it will be accepted by IrEmitter.
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index acd26c4e31..c6202548f1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -48,37 +48,37 @@ using ::tensorflow::Env;
using ::tensorflow::gtl::nullopt;
using ::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
-using ::tensorflow::strings::Appendf;
-using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
using ::tensorflow::str_util::Join;
+using ::tensorflow::str_util::StringReplace;
using ::tensorflow::WriteStringToFile;
namespace xla {
namespace hlo_graph_dumper {
namespace {
-// Node color schemes, used by NodeColorAttributes.
-enum ColorScheme {
- kBlue,
- kBrown,
- kDarkBlue,
- kDarkGreen,
- kDarkRed,
- kGray,
- kGreen,
- kOrange,
- kPurple,
- kRed,
- kWhite,
- kYellow,
-
- // Causes the node's border to be a dashed line, and its content to be gray
- // text on a white background, suggesting that this is an "unimportant" node.
- kDashedBorder,
+// Helpers for Printf and Appendf.
+template <typename T>
+struct PrintfConvert {
+ const T& operator()(const T& t) const { return t; }
+};
+template <>
+struct PrintfConvert<string> {
+ const char* operator()(const string& s) const { return s.c_str(); }
};
+// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str()
+// on strings.
+template <typename... Ts>
+string Printf(const char* fmt, const Ts&... ts) {
+ return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...);
+}
+template <typename... Ts>
+void Appendf(string* s, const char* fmt, const Ts&... ts) {
+ tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...);
+}
+
// Used to indicate how we should treat a given HLOInstruction in the graph.
// should we treat it like normal, hide it, and so on?
enum NodeFilterResult {
@@ -92,6 +92,9 @@ enum NodeFilterResult {
// Style the node the same as kSomeOperandsOmitted, but also don't connect it
// to its operands, even if they're present in the graph.
kOmitNodeOperands,
+ // Same style as kSomeOperandsOmitted, but used to indicate that some of the
+ // node's *users* have been omitted.
+ kSomeUsersOmitted,
};
// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
@@ -118,11 +121,41 @@ class NodeFilter {
auto result = filter_(instr);
return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
}
+ bool Deemphasized(const HloInstruction* instr) const {
+ auto result = filter_(instr);
+ return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
+ result == kSomeUsersOmitted;
+ }
+
+ bool ShowFusionSubcomputation(const HloInstruction* instr) const {
+ CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
+ return Show(instr) && !SomeOrAllOperandsOmitted(instr);
+ }
private:
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
+// Node color schemes, used by NodeColorAttributes.
+enum ColorScheme {
+ kBlue,
+ kBrown,
+ kDarkBlue,
+ kDarkGreen,
+ kDarkRed,
+ kGray,
+ kGreen,
+ kOrange,
+ kPurple,
+ kRed,
+ kWhite,
+ kYellow,
+
+ // Causes the node's border to be a dashed line, and its content to be gray
+ // text on a white background, suggesting that this is an "unimportant" node.
+ kDashedBorder,
+};
+
// Given a ColorScheme, returns an attribute string for a node of that color.
// Sets the node's style and fill/stroke/text colors.
//
@@ -170,19 +203,8 @@ string NodeColorAttributes(ColorScheme color) {
// Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
// graphviz HTML-like string.
string HtmlLikeStringSanitize(tensorflow::StringPiece s) {
- return tensorflow::str_util::StringReplace(
- tensorflow::str_util::StringReplace(s, "<", "&lt;", /*replace_all=*/true),
- ">", "&gt;", /*replace_all=*/true);
-}
-
-// Returns the dot graph identifier for the given instruction.
-string InstructionId(const HloInstruction* instruction) {
- return Printf("%lld", reinterpret_cast<uint64>(instruction));
-}
-
-// Returns the dot graph identifier for the given computation.
-string ComputationId(const HloComputation* computation) {
- return Printf("%lld", reinterpret_cast<uint64>(computation));
+ return StringReplace(StringReplace(s, "<", "&lt;", /*replace_all=*/true), ">",
+ "&gt;", /*replace_all=*/true);
}
// Tries to generates a human-readable one-word description of the given
@@ -194,9 +216,15 @@ string ComputationId(const HloComputation* computation) {
// "return param0 * param1;" --> "multiply"
// "return min(param0, param1);" --> "min"
// "return max(param0, param1);" --> "max"
+// "return param0 <= param1;" --> "less-or-equal"
+// "return param0 >= param1;" --> "greater-or-equal"
+// "return param0 > param1;" --> "greater-than"
+// "return param0 < param1;" --> "less-than"
+// "return param0 == param1;" --> "equal-to"
+// "return param0 != param1;" --> "not-equal-to"
//
-// where param0 and param1 are effective scalars. Since all of the ops above
-// are commutative, we also support them with param0 and param1 swapped.
+// where param0 and param1 are effective scalars. For the ops that are
+// commutative, we also support them with param0 and param1 swapped.
//
// This is useful primarily for reduce and map nodes. These take a
// subcomputation which is almost always one of the four above, and pattern
@@ -219,6 +247,7 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
operand1->opcode() != HloOpcode::kParameter) {
return nullopt;
}
+
// Check that the two operands of root are param0 and param1. All of the
// opcodes we recognize are commutative, so we're OK with either order.
auto n0 = operand0->parameter_number();
@@ -227,6 +256,20 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
return nullopt;
}
+ // If the params are reversed, check that the operation being performed is
+ // commutative.
+ if (n0 == 1) {
+ switch (root->opcode()) {
+ case HloOpcode::kLe:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLt:
+ return nullopt;
+ default:
+ break;
+ }
+ }
+
// Check that the root and params are all effective scalars.
if (!ShapeUtil::IsEffectiveScalar(root->shape()) ||
!ShapeUtil::IsEffectiveScalar(operand0->shape()) ||
@@ -244,444 +287,542 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
return "min";
case HloOpcode::kMaximum:
return "max";
+ case HloOpcode::kLe:
+ return "less-or-equal";
+ case HloOpcode::kGe:
+ return "greater-or-equal";
+ case HloOpcode::kGt:
+ return "greater-than";
+ case HloOpcode::kLt:
+ return "less-than";
+ case HloOpcode::kEq:
+ return "equal-to";
+ case HloOpcode::kNe:
+ return "not-equal-to";
default:
return nullopt;
}
}
-// Returns the dot graph edges and nodes for the given instruction sequence.
-// Edges which extend between computations are added to the vector
-// intercomputation_edges. This is necessary because graphviz does not render
-// the graph properly unless these inter-computation edges appear after all
-// subgraph statements.
-string InstructionSequenceGraph(
- const std::list<std::unique_ptr<HloInstruction>>& instructions,
- bool show_addresses, bool show_layouts,
- std::vector<string>* intercomputation_edges,
- const HloExecutionProfile* hlo_execution_profile,
- const NodeFilter& filter) {
- string graph_body;
-
- for (auto& instruction : instructions) {
- if (!filter.Show(instruction.get())) {
- continue;
- }
+class HloDotDumper {
+ public:
+ HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
+ bool show_addresses, bool show_layouts,
+ const HloExecutionProfile* profile, NodeFilter filter)
+ : computation_(computation),
+ label_(label.ToString()),
+ show_addresses_(show_addresses),
+ show_layouts_(show_layouts),
+ profile_(profile),
+ filter_(std::move(filter)) {}
+
+ string Dump();
- // We don't display constants as separate nodes; they're merged into their
- // users.
- if (instruction->opcode() == HloOpcode::kConstant) {
+ private:
+ // Returns the dot graph identifier for the given instruction.
+ string InstructionId(const HloInstruction* instruction) {
+ return StrCat(reinterpret_cast<uint64>(instruction));
+ }
+
+ // Returns the dot graph identifier for the given computation.
+ string SubcomputationId(const HloComputation* computation) {
+ return StrCat("cluster_", reinterpret_cast<uint64>(computation));
+ }
+
+ string Header();
+ string Footer();
+
+ // Maps HloComputations we should dump to their parent instruction in the
+ // outer computation.
+ std::unordered_map<const HloComputation*, const HloInstruction*>
+ SubcomputationsToDump();
+
+ string DumpSubcomputation(const HloComputation* subcomp,
+ const HloInstruction* parent_instr);
+ string DumpComputation(const HloComputation* comp);
+ string DumpInstruction(const HloInstruction* instr);
+ ColorScheme GetInstructionColor(const HloInstruction* instr);
+ string GetInstructionNodeShape(const HloInstruction* instr);
+ string GetInstructionNodeLabel(const HloInstruction* instr);
+ string GetInstructionNodeExtraInfo(const HloInstruction* instr);
+ string GetInstructionNodeInlinedConstants(const HloInstruction* instr);
+ void AddInstructionIncomingEdges(const HloInstruction* instr);
+
+ // If instr has just one computation and it's trivial (e.g. "return param0 +
+ // param1"), returns a string you can put into the node's body that names the
+ // subcomputation, e.g. "Subcomputation: <b>add</b>".
+ string GetInstructionTrivialComputationStr(const HloInstruction* instr);
+
+ const HloComputation* computation_; // never null
+ const string label_; // overall name for the graph
+ const bool show_addresses_;
+ const bool show_layouts_;
+ const HloExecutionProfile* profile_; // may be null
+ const NodeFilter filter_;
+
+ // Edges to print from Footer(). Edges come at the end because graphviz is
+ // unhappy if an edge from a subcomputation to a node in the outer computation
+ // appears before both the inner computation and the destination node are
+ // defined.
+ std::vector<string> edges_;
+};
+
+string HloDotDumper::Dump() {
+ string g = Header();
+ for (const auto& kv : SubcomputationsToDump()) {
+ const HloComputation* subcomp = kv.first;
+ const HloInstruction* parent = kv.second;
+ StrAppend(&g, DumpSubcomputation(subcomp, parent));
+ }
+ StrAppend(&g, DumpComputation(computation_));
+ StrAppend(&g, Footer());
+ return g;
+}
+
+string HloDotDumper::Header() {
+ // DOT graphs accept a stylesheet as a URI. So naturally, an inline
+ // stylesheet is a data URI!
+ const char* fmt = R"(digraph G {
+rankdir = TB;
+compound = true;
+label = <<b>%s</b>>;
+labelloc = t;
+stylesheet="
+ data:text/css,
+ @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
+ svg text {
+ font-family: 'Roboto';
+ font-size: 12px;
+ }
+"
+
+)";
+
+ string graph_label = StrCat(label_, "<br/>", computation_->name());
+ if (profile_ != nullptr) {
+ auto cycles = profile_->total_cycles_executed(*computation_);
+ Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
+ tensorflow::strings::HumanReadableNum(cycles));
+ }
+ return Printf(fmt, graph_label);
+}
+
+string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
+
+std::unordered_map<const HloComputation*, const HloInstruction*>
+HloDotDumper::SubcomputationsToDump() {
+ // Dump the subcomputations of each instruction that's shown and doesn't have
+ // its operands omitted. If an instruction has just one subcomputation and
+ // it's trivial, omit it: We'll display that subcomputation inlined into the
+ // instruction's node when we draw it.
+ std::unordered_map<const HloComputation*, const HloInstruction*> to_dump;
+ for (const auto& instr : computation_->instructions()) {
+ if (!filter_.Show(instr.get()) ||
+ filter_.SomeOrAllOperandsOmitted(instr.get())) {
continue;
}
+ if (instr->opcode() == HloOpcode::kFusion) {
+ to_dump[instr->fused_instructions_computation()] = instr.get();
+ }
- ColorScheme color = kYellow;
- string shape = "box";
-
- // Build the first line or two of the node, containing its name and opcode
- // (if the opcode isn't redundant with the name).
- string name;
- if (instruction->opcode() == HloOpcode::kParameter) {
- // If we have a parameter, put the param number in the name.
- name = StrCat("<b>Parameter ", instruction->parameter_number(),
- "</b><br/>", HtmlLikeStringSanitize(instruction->name()));
- } else if (tensorflow::StringPiece(instruction->name())
- .starts_with(
- StrCat("%", instruction->ExtendedOpcodeStr()))) {
- // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
- // an add instruction. In this case we render just the name.
- name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), "</b>");
- } else if (instruction->opcode() == HloOpcode::kFusion &&
- tensorflow::StringPiece(instruction->name())
- .starts_with(
- StrCat("%", HloOpcodeString(instruction->opcode())))) {
- // Fusion nodes are usually named e.g. "%fusion.5". We render these as
- // e.g. "%fusion.5<br/>input fusion".
- name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()),
- "</b><br/>",
- HtmlLikeStringSanitize(instruction->ToCategory()));
- } else {
- // If the name does not contain the opcode, render both.
- name = StrCat("<b>",
- HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()),
- "</b><br/>", HtmlLikeStringSanitize(instruction->name()));
- }
-
- if (HloOpcode::kConvolution == instruction->opcode()) {
- StrAppend(
- &name, "<br/>",
- HtmlLikeStringSanitize(
- instruction->ConvolutionDimensionNumbersToString()),
- "<br/>",
- HtmlLikeStringSanitize(window_util::ToString(instruction->window())));
- }
-
- if (!instruction->metadata().op_name().empty()) {
- StrAppend(&name, "<br/>",
- HtmlLikeStringSanitize(instruction->metadata().op_name()));
- }
- if (!instruction->metadata().source_file().empty() &&
- instruction->metadata().source_line() != 0) {
- StrAppend(&name, "<br/>", instruction->metadata().source_file(), ":",
- instruction->metadata().source_line());
- }
-
- // Pick different colors or shapes for instructions which are particularly
- // expensive (eg, dot) and those which are unusual in some way or unique
- // (eg, parameter).
- switch (instruction->opcode()) {
- // "Normal" instructions. Mostly cheap and elementwise. No call to
- // embedded computations. In this case, use default color, shape and
- // label.
- case HloOpcode::kAbs:
- case HloOpcode::kAdd:
- case HloOpcode::kCeil:
- case HloOpcode::kClamp:
- case HloOpcode::kConvert:
- case HloOpcode::kCos:
- case HloOpcode::kDivide:
- case HloOpcode::kEq:
- case HloOpcode::kExp:
- case HloOpcode::kFloor:
- case HloOpcode::kGe:
- case HloOpcode::kGt:
- case HloOpcode::kIndex:
- case HloOpcode::kIsFinite:
- case HloOpcode::kLe:
- case HloOpcode::kLog:
- case HloOpcode::kLogicalAnd:
- case HloOpcode::kLogicalNot:
- case HloOpcode::kLogicalOr:
- case HloOpcode::kLt:
- case HloOpcode::kMaximum:
- case HloOpcode::kMinimum:
- case HloOpcode::kMultiply:
- case HloOpcode::kNe:
- case HloOpcode::kNegate:
- case HloOpcode::kPower:
- case HloOpcode::kRemainder:
- case HloOpcode::kSelect:
- case HloOpcode::kSign:
- case HloOpcode::kSin:
- case HloOpcode::kSlice:
- case HloOpcode::kSort:
- case HloOpcode::kSubtract:
- case HloOpcode::kTanh:
- break;
- case HloOpcode::kRng:
- StrAppend(&name, "<br/>",
- RandomDistribution_Name(instruction->random_distribution()));
- break;
- case HloOpcode::kBroadcast:
- case HloOpcode::kTranspose:
- StrAppend(&name, "<br/>", "dims={",
- Join(instruction->dimensions(), ","), "}");
- break;
- case HloOpcode::kBitcast:
- case HloOpcode::kTuple:
- case HloOpcode::kTrace:
- color = kWhite;
- break;
- case HloOpcode::kGetTupleElement:
- color = kWhite;
- StrAppend(&name, "<br/>index=", instruction->tuple_index());
- break;
- case HloOpcode::kConcatenate:
- case HloOpcode::kCopy:
- case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- case HloOpcode::kPad:
- case HloOpcode::kReshape:
- case HloOpcode::kReverse:
- case HloOpcode::kUpdate:
- color = kGreen;
- break;
- case HloOpcode::kConvolution:
- case HloOpcode::kDot:
- color = kDarkBlue;
- break;
- case HloOpcode::kParameter:
- color = kOrange;
- break;
- case HloOpcode::kBatchNormTraining:
- StrAppend(&name, " feature_index=", instruction->feature_index());
- color = kPurple;
- break;
- case HloOpcode::kBatchNormGrad:
- StrAppend(&name, " feature_index=", instruction->feature_index());
- color = kPurple;
- break;
- case HloOpcode::kReduce:
- StrAppend(&name, " dims=", Join(instruction->dimensions(), ","));
- color = kPurple;
- break;
- case HloOpcode::kSelectAndScatter:
- case HloOpcode::kReduceWindow:
- color = kPurple;
- break;
- case HloOpcode::kWhile:
- shape = "ellipse";
- color = kDarkGreen;
- break;
- case HloOpcode::kMap:
- case HloOpcode::kFusion:
- color = kGray;
- break;
- case HloOpcode::kSend:
- case HloOpcode::kRecv:
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kCrossReplicaSum:
- color = kBrown;
- break;
- case HloOpcode::kCall:
- color = kDarkGreen;
- break;
- case HloOpcode::kCustomCall:
- color = kDarkGreen;
- StrAppend(&name, "<br/>",
- "custom_call_target=", instruction->custom_call_target());
- break;
- case HloOpcode::kReducePrecision:
- // Make ReducePrecision ops a bit more visible, since typically they
- // will be inserted as modifications to an existing graph.
- color = kRed;
- break;
- case HloOpcode::kConstant:
- LOG(FATAL) << "Constants don't get their own nodes in the graph.";
- }
-
- // Create instruction node with appropriate label, shape, and color.
- // label is interpreted as an HTML-like string, so newlines must be
- // delimited with <br/>, rather than \n.
- string label =
- StrCat(name, "<br/>", ShapeUtil::HumanString(instruction->shape()));
-
- if (show_addresses) {
- Appendf(&label, "<br/>[%p]", instruction.get());
- }
- if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) {
- string layout_string;
- if (ShapeUtil::IsTuple(instruction->shape())) {
- // For tuples, emit the full shape because the layout of a tuple is not
- // represented in a single Layout field.
- layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
- } else {
- layout_string =
- Join(instruction->shape().layout().minor_to_major(), ",");
- }
- StrAppend(&label, "<br/>layout={", layout_string, "}");
- }
- if (hlo_execution_profile != nullptr) {
- auto hlo_cycles_executed =
- hlo_execution_profile->GetProfileResult(*instruction);
- auto total_cycles_executed =
- hlo_execution_profile->total_cycles_executed(*instruction->parent());
- if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
- Appendf(&label, "<br/>%% of cycles executed=%.2f",
- (static_cast<double>(hlo_cycles_executed) /
- static_cast<double>(total_cycles_executed)) *
- 100);
+ for (const HloComputation* comp : instr->called_computations()) {
+ if (!MatchTrivialComputation(comp)) {
+ to_dump[comp] = instr.get();
}
}
+ }
+ return to_dump;
+}
- // If this node's operands are omitted, style it accordingly.
- if (filter.SomeOrAllOperandsOmitted(instruction.get())) {
- color = kDashedBorder;
- }
+string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
+ const HloInstruction* parent_instr) {
+ const char* computation_fmt = R"(subgraph %s {
+%s;
+label = <%s>;
+labelloc = t;
+%s
+} // %s
- // If this node is highlighted, override its formatting.
- if (filter.Highlight(instruction.get())) {
- shape = "diamond";
- color = kDarkRed;
+)";
+
+ string id = SubcomputationId(subcomp);
+
+ string subcomp_label, style;
+ if (parent_instr->opcode() == HloOpcode::kFusion) {
+ subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(parent_instr->name()),
+ HtmlLikeStringSanitize(parent_instr->ToCategory()));
+
+ // Subcomputation's fill/stroke color is light/dark red/gray, depending on
+ // whether or not the subcomputation's fusion node is highlighted.
+ bool highlight = filter_.Highlight(parent_instr);
+ const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
+ const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
+ style = Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s")",
+ fillcolor, strokecolor);
+ } else {
+ subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(parent_instr->name()),
+ HtmlLikeStringSanitize(subcomp->name()));
+ style = "style=rounded; color=black;";
+ }
+
+ string comp_body = DumpComputation(subcomp);
+ string computation =
+ Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
+
+ // Add an edge from the subcomputation to its parent node. If subcomp
+ // belongs to a fusion node, it's drawn in place of the fusion instruction, so
+ // there's no need to link those.
+ if (parent_instr->opcode() != HloOpcode::kFusion) {
+ const char* edge_fmt = R"(%s -> %s [ltail="%s", style="dashed"];)";
+ edges_.push_back(
+ Printf(edge_fmt, InstructionId(subcomp->root_instruction()),
+ InstructionId(parent_instr), SubcomputationId(subcomp)));
+ }
+
+ return computation;
+}
+
+string HloDotDumper::DumpComputation(const HloComputation* comp) {
+ string g;
+ for (const auto& instr : comp->instructions()) {
+ if (!filter_.Show(instr.get())) {
+ continue;
}
+ StrAppend(&g, DumpInstruction(instr.get()));
+ }
+ return g;
+}
- // Create edges from the instruction's operands to the instruction.
- if (!filter.OmitOperands(instruction.get())) {
- int64 operand_number = 0;
- for (auto* operand : instruction->operands()) {
- if (!filter.Show(operand) ||
- operand->opcode() == HloOpcode::kConstant) {
- ++operand_number;
- continue;
- }
- Appendf(&graph_body, "%s -> %s", InstructionId(operand).c_str(),
- InstructionId(instruction.get()).c_str());
- if (instruction->operand_count() > 1) {
- Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]",
- operand_number);
- }
- StrAppend(&graph_body, ";\n");
- ++operand_number;
- }
+string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
+ // We don't display constants as separate nodes; they're merged into their
+ // users.
+ if (instr->opcode() == HloOpcode::kConstant) {
+ return "";
+ }
+ // Omit the fusion node if its subcomputation is drawn, since the
+ // subcomputation will be drawn inline.
+ if (instr->opcode() == HloOpcode::kFusion &&
+ filter_.ShowFusionSubcomputation(instr)) {
+ return "";
+ }
- // Fusion nodes are handled specially because they contain nested
- // expressions.
- if (instruction->opcode() == HloOpcode::kFusion) {
- string cluster_name =
- StrCat("cluster_", InstructionId(instruction.get()));
- StrAppend(&graph_body, "subgraph ", cluster_name, " {\n");
- StrAppend(&graph_body, "label=<fused expression for <b>",
- HtmlLikeStringSanitize(instruction->name()),
- "</b>>;\nstyle=\"rounded,filled\";\n"
- "color=lightgrey;\n");
- StrAppend(&graph_body,
- InstructionSequenceGraph(instruction->fused_instructions(),
- show_addresses, show_layouts,
- intercomputation_edges,
- hlo_execution_profile, NodeFilter()),
- "}\n");
- string fusion_edge = StrCat(
- InstructionId(instruction->fused_expression_root()), " -> ",
- InstructionId(instruction.get()),
- " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name,
- " ];\n");
- intercomputation_edges->push_back(fusion_edge);
- } else {
- // If instruction has just one computation and it's trivial (e.g.
- // "return param0 + param1"), put the trivial computation type (e.g.
- // "add") into instruction's label. Otherwise, add a dotted edge
- // between the instruction and its subcomputations.
- const auto& subcomputations = instruction->called_computations();
-
- bool trivial_subcomputation = false;
- if (subcomputations.size() == 1) {
- optional<string> computation_type =
- MatchTrivialComputation(subcomputations.front());
- if (computation_type) {
- trivial_subcomputation = true;
- StrAppend(&label, "<br/>Subcomputation: <b>", *computation_type,
- "</b>");
- }
- }
+ ColorScheme color = GetInstructionColor(instr);
+ string node_shape = GetInstructionNodeShape(instr);
+ string node_label = GetInstructionNodeLabel(instr);
+ string extra_info = GetInstructionNodeExtraInfo(instr);
+ string inlined_constants = GetInstructionNodeInlinedConstants(instr);
+ string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
+ AddInstructionIncomingEdges(instr);
+
+ // Override the node's styling if it should be (de-)emphasized.
+ if (filter_.Deemphasized(instr)) {
+ color = kDashedBorder;
+ }
+ if (filter_.Highlight(instr)) {
+ node_shape = "diamond";
+ color = kDarkRed;
+ }
- if (!trivial_subcomputation) {
- for (const HloComputation* computation :
- instruction->called_computations()) {
- string cluster_name =
- StrCat("cluster_", ComputationId(computation));
- string call_edge = Printf(
- "%s -> %s [ style=dashed; ltail=%s ];\n",
- InstructionId(computation->root_instruction()).c_str(),
- InstructionId(instruction.get()).c_str(), cluster_name.c_str());
- intercomputation_edges->push_back(call_edge);
- }
- }
- }
+ // Build the text that will be displayed inside the node.
+ string node_body = node_label;
+ for (const string& s :
+ {trivial_subcomputation, extra_info, inlined_constants}) {
+ if (!s.empty()) {
+ StrAppend(&node_body, "<br/>", s);
}
+ }
- // Inline constant operands into the node.
- for (int64 i = 0; i < instruction->operand_count(); ++i) {
- const HloInstruction* operand = instruction->operand(i);
- if (operand->opcode() != HloOpcode::kConstant) {
- continue;
- }
+ return Printf("%s [label=<%s>, shape=%s, %s];\n", InstructionId(instr),
+ node_body, node_shape, NodeColorAttributes(color));
+}
- StrAppend(&label, "<br/><b>operand ", i, "</b> = ");
- if (ShapeUtil::IsEffectiveScalar(operand->shape())) {
- auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
- operand->shape(), /*linear_index=*/0);
- StrAppend(&label, ShapeUtil::HumanString(operand->shape()), "{",
- operand->literal().GetAsString(elem_idx), "}");
- } else {
- if (tensorflow::StringPiece(operand->name()).starts_with("%constant")) {
- StrAppend(&label, operand->name());
- } else {
- StrAppend(&label, "constant ", operand->name());
- }
- }
+string HloDotDumper::GetInstructionNodeInlinedConstants(
+ const HloInstruction* instr) {
+ auto stringify_constant = [](const HloInstruction* constant) {
+ if (ShapeUtil::IsEffectiveScalar(constant->shape())) {
+ auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
+ constant->shape(), /*linear_index=*/0);
+ return Printf("%s{%s}", ShapeUtil::HumanString(constant->shape()),
+ constant->literal().GetAsString(elem_idx));
+ }
+ if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) {
+ return constant->name();
}
+ return StrCat("constant ", constant->name());
+ };
- Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
- InstructionId(instruction.get()).c_str(), label.c_str(),
- shape.c_str(), NodeColorAttributes(color).c_str());
+ // Special case: If instr is a parameter to a fusion node, check whether the
+ // corresponding operand to the fusion node is a constant.
+ if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
+ const HloInstruction* fusion = instr->fusion_instruction();
+ const HloInstruction* operand = fusion->operand(instr->parameter_number());
+ if (operand->opcode() != HloOpcode::kConstant) {
+ return "";
+ }
+ return stringify_constant(operand);
}
- return graph_body;
+
+ std::vector<string> lines;
+ for (int64 i = 0; i < instr->operand_count(); ++i) {
+ const HloInstruction* operand = instr->operand(i);
+ if (operand->opcode() != HloOpcode::kConstant) {
+ continue;
+ }
+ lines.push_back(
+ Printf("<b>operand %lld</b> = %s", i, stringify_constant(operand)));
+ }
+ return Join(lines, "<br/>");
}
-// DOT graphs accept a stylesheet as a URL. So naturally, an inline stylesheet
-// is a data URI!
-//
-// We don't perform any escaping on this string, so be careful not to use double
-// quotes inside.
-static const char* dot_stylesheet = R"(
-data:text/css,
-@import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
-svg text {
- font-family: 'Roboto';
- font-size: 12px;
+ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
+ // Pick different colors or shapes for instructions which are particularly
+  // expensive (e.g., dot) and those which are unusual in some way or unique
+  // (e.g., parameter).
+ switch (instr->opcode()) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClamp:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCos:
+ case HloOpcode::kDivide:
+ case HloOpcode::kEq:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kIndex:
+ case HloOpcode::kIsFinite:
+ case HloOpcode::kLe:
+ case HloOpcode::kLog:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kLogicalOr:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kNegate:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSign:
+ case HloOpcode::kSin:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kTanh:
+ case HloOpcode::kRng:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kTranspose:
+ return kYellow;
+ case HloOpcode::kBitcast:
+ case HloOpcode::kTuple:
+ case HloOpcode::kTrace:
+ case HloOpcode::kGetTupleElement:
+ return kWhite;
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kCopy:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kPad:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kUpdate:
+ return kGreen;
+ case HloOpcode::kConvolution:
+ case HloOpcode::kDot:
+ return kDarkBlue;
+ case HloOpcode::kReducePrecision:
+ return kRed;
+ case HloOpcode::kParameter:
+ return kOrange;
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kReduce:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kReduceWindow:
+ return kPurple;
+ case HloOpcode::kMap:
+ case HloOpcode::kFusion:
+ return kGray;
+ case HloOpcode::kSend:
+ case HloOpcode::kRecv:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kCrossReplicaSum:
+ return kBrown;
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kWhile:
+ case HloOpcode::kCall:
+ return kDarkGreen;
+ case HloOpcode::kConstant:
+ LOG(FATAL) << "Constants don't get their own nodes in the graph.";
+ }
}
-)";
-string ComputationToDotGraph(const HloComputation& computation,
- const string& label, bool show_addresses,
- bool show_layouts,
- const HloExecutionProfile* hlo_execution_profile,
- const NodeFilter& filter) {
- string graph_label = StrCat(label, "<br/>", computation.name());
- if (hlo_execution_profile != nullptr) {
- auto cycles = hlo_execution_profile->total_cycles_executed(computation);
- Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
- tensorflow::strings::HumanReadableNum(cycles).c_str());
- }
- string graph = Printf(
- R"(digraph G {
-rankdir=TB;
-compound=true;
-label=<<b>%s</b>>;
-labelloc=t;
-stylesheet="%s"
-)",
- graph_label.c_str(), dot_stylesheet);
+string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
+ // Give while loops a different shape so they're easier to pick out.
+ switch (instr->opcode()) {
+ case HloOpcode::kWhile:
+ return "ellipse";
+ default:
+ return "rect";
+ }
+}
- // Dump the subcomputations of each instruction that's shown and doesn't have
- // its operands omitted. If an instruction has just one subcomputation and
- // it's trivial, omit it: We'll display that subcomputation inlined into the
- // instruction's node when we draw it.
- std::unordered_set<const HloComputation*> computations_to_dump;
- for (const auto& instr : computation.instructions()) {
- if (!filter.Show(instr.get()) || filter.OmitOperands(instr.get())) {
- continue;
+string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
+ // If we have a parameter, put the param number in the name.
+ if (instr->opcode() == HloOpcode::kParameter) {
+ return Printf("<b>Parameter %lld</b>", instr->parameter_number());
+ }
+
+  // The HLO instruction name usually contains the opcode, e.g. "%add.42" is
+ // an add instruction. In this case we render just the name.
+ if (tensorflow::StringPiece(instr->name())
+ .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) {
+ return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
+ }
+
+ // If the name does not contain the opcode, render both.
+ return Printf("<b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()),
+ HtmlLikeStringSanitize(instr->name()));
+}
+
+string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
+ string opcode_specific_info = [&]() -> string {
+ switch (instr->opcode()) {
+ case HloOpcode::kRng:
+ return RandomDistribution_Name(instr->random_distribution());
+ case HloOpcode::kConvolution:
+ return StrCat(
+ HtmlLikeStringSanitize(
+ instr->ConvolutionDimensionNumbersToString()),
+ "<br/>",
+ HtmlLikeStringSanitize(window_util::ToString(instr->window())));
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kReduce:
+ return Printf("dims={%s}", Join(instr->dimensions(), ","));
+ case HloOpcode::kGetTupleElement:
+ return Printf("index=%lld", instr->tuple_index());
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormGrad:
+ return Printf("feature_index=%lld", instr->feature_index());
+ case HloOpcode::kCustomCall:
+ return Printf("custom_call_target=%s", instr->custom_call_target());
+ default:
+ return "";
}
- if (instr->opcode() == HloOpcode::kFusion) {
- computations_to_dump.insert(instr->fused_instructions_computation());
+ }();
+
+ std::vector<string> lines;
+ if (!opcode_specific_info.empty()) {
+ lines.push_back(opcode_specific_info);
+ }
+
+ // Some instructions have giant tuples as their shapes, so truncate the HLO's
+ // shape to kMaxShapeLen characters.
+ constexpr int kMaxShapeLen = 64;
+ string instr_shape = ShapeUtil::HumanString(instr->shape());
+ if (instr_shape.length() > kMaxShapeLen) {
+ instr_shape =
+ StrCat(tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),
+ "...");
+ }
+ lines.push_back(instr_shape);
+
+ if (show_addresses_) {
+ lines.push_back(Printf("[%p]", instr));
+ }
+ if (show_layouts_ && LayoutUtil::HasLayout(instr->shape())) {
+ string layout_str;
+ if (ShapeUtil::IsTuple(instr->shape())) {
+ // For tuples, emit the full shape because the layout of a tuple is not
+ // represented in a single Layout field.
+ layout_str = ShapeUtil::HumanStringWithLayout(instr->shape());
+ } else {
+ layout_str = Join(instr->shape().layout().minor_to_major(), ",");
+ }
+ lines.push_back(Printf("layout={%s}", layout_str));
+ }
+ if (profile_ != nullptr) {
+ double hlo_cycles_executed = profile_->GetProfileResult(*instr);
+ double total_cycles_executed =
+ profile_->total_cycles_executed(*instr->parent());
+ if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
+ lines.push_back(
+ Printf("%% of cycles executed=%.2f",
+ 100 * hlo_cycles_executed / total_cycles_executed));
}
+ }
+ return Join(lines, "<br/>");
+}
- const auto& subcomputations = instr->called_computations();
- if (subcomputations.size() != 1 ||
- !MatchTrivialComputation(subcomputations.front())) {
- for (const HloComputation* computation : instr->called_computations()) {
- computations_to_dump.insert(computation);
- }
+void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
+ auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
+ int64 operand_num) {
+ // Fusion nodes' subcomputations are displayed inline, so if 'from' is a
+ // fusion node and the node's subcomputation is shown, we draw our edge
+ // starting at the fusion node's root instead of at the fusion node itself.
+ if (from->opcode() == HloOpcode::kFusion &&
+ filter_.ShowFusionSubcomputation(from)) {
+ from = from->fused_expression_root();
+ }
+ if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
+ return;
}
+ string edge = Printf("%s -> %s", InstructionId(from), InstructionId(to));
+ if (instr->operand_count() > 1) {
+ Appendf(&edge, R"( [headlabel="%lld",labeldistance=2])", operand_num);
+ }
+ StrAppend(&edge, ";");
+ edges_.push_back(edge);
+ };
+
+ // Add edges from instr's operands to instr. Parameters within fusion
+ // expressions are handled specially -- we draw an edge from the corresponding
+ // operand on the fusion node itself to the parameter.
+ if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
+ const HloInstruction* fusion = instr->fusion_instruction();
+ add_edge(fusion->operand(instr->parameter_number()), instr,
+ /*operand_num=*/0);
+ } else {
+ for (int64 i = 0; i < instr->operand_count(); ++i) {
+ add_edge(instr->operand(i), instr, i);
+ }
+ }
+}
+
+string HloDotDumper::GetInstructionTrivialComputationStr(
+ const HloInstruction* instr) {
+ // called_computations() on a fusion node "inherits" any called computations
+ // of the fused root, which isn't what we want. Just ignore fusion nodes
+ // here; they're handled separately.
+ if (instr->opcode() == HloOpcode::kFusion) {
+ return "";
}
- // Emit embedded computations as subgraph clusters.
- std::vector<string> intercomputation_edges;
- for (const HloComputation* embedded :
- computation.MakeEmbeddedComputationsList()) {
- if (!computations_to_dump.count(embedded)) {
+ std::vector<string> lines;
+ for (int64 i = 0; i < instr->called_computations().size(); ++i) {
+ optional<string> computation_type =
+ MatchTrivialComputation(instr->called_computations()[i]);
+ if (!computation_type) {
continue;
}
- // Don't pass our filter down into the subcomputation -- always render the
- // whole thing.
- string graph_body = InstructionSequenceGraph(
- embedded->instructions(), show_addresses, show_layouts,
- &intercomputation_edges, hlo_execution_profile, NodeFilter());
- Appendf(&graph,
- "subgraph cluster_%s "
- "{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n",
- ComputationId(embedded).c_str(), embedded->name().c_str(),
- graph_body.c_str());
- }
- StrAppend(&graph,
- InstructionSequenceGraph(computation.instructions(), show_addresses,
- show_layouts, &intercomputation_edges,
- hlo_execution_profile, filter));
-
- // Edges between computations (subgraph clusters) must be emitted last for the
- // graph to be rendered properly for some reason.
- StrAppend(&graph, Join(intercomputation_edges, "\n"), "}\n");
-
- return graph;
+ if (instr->called_computations().size() == 1) {
+ lines.push_back(Printf("Subcomputation: <b>%s</b>",
+ HtmlLikeStringSanitize(*computation_type)));
+ } else {
+ lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i,
+ HtmlLikeStringSanitize(*computation_type)));
+ }
+ }
+ return Join(lines, "<br/>");
}
tensorflow::mutex& RendererMutex() {
@@ -750,14 +891,6 @@ class FileGraphRenderer : public GraphRendererInterface {
// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
-//
-// It's confusing to draw a node and include only some of its operands. So if
-// some but not all of a node's operands are <= radius units away from the root,
-// we include the other operands (unless there are a lot of them, as often in a
-// tuple node). These additional operands may have as inputs other nodes
-// already present in the graph, but we don't draw those edges unless *all* of
-// the inputs are present. (Otherwise we'd have the same problem we were trying
-// to solve in the first place!)
NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
// First, find the neighborhood of nodes with distance from root <= radius.
// These nodes are our initial set of "normal" nodes.
@@ -788,14 +921,25 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
}
}
- // If you're looking at node X, it's probably not interesting that node Y
- // also happens to use the same constant, so we don't traverse into
- // constants' users.
- if (instr->opcode() != HloOpcode::kConstant) {
- for (const HloInstruction* user : instr->users()) {
- if (!nodes.count(user)) {
- worklist.push_back({user, depth + 1});
- }
+ // Traverse into instr's users, unless:
+ //
+ // - there are a ton of them, in which case they're probably not
+ // interesting (and anyway, rendering them all would make the graph
+ // unreadable), or
+ // - instr is a constant, in which case its users are probably not
+ // interesting.
+ if (instr->opcode() == HloOpcode::kConstant) {
+ continue;
+ }
+ constexpr int kMaxUsersToRender = 16;
+ if (instr->user_count() > kMaxUsersToRender) {
+ // If we're going to skip this node's users, style it as such.
+ nodes[instr] = kSomeUsersOmitted;
+ continue;
+ }
+ for (const HloInstruction* user : instr->users()) {
+ if (!nodes.count(user)) {
+ worklist.push_back({user, depth + 1});
}
}
}
@@ -804,43 +948,27 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
return nodes.count(instr) > 0;
};
- // If a node has some but not all of its operands omitted, add the operands to
- // the map with type kOmitNodeOperands. Unless the node has a lot of
- // operands, in which case just mark the node as "some operands omitted".
- std::vector<const HloInstruction*> extra_operands;
+ // Mark nodes which don't have all of their operands present as "some operands
+ // omitted".
for (auto& kv : nodes) {
const HloInstruction* instr = kv.first;
NodeFilterResult& filter_result = kv.second;
const auto& operands = instr->operands();
- // Mark nodes with many operands and some omitted as "some operands omitted"
- // and carry on -- don't add their omitted operands to extra_operands.
- if (operands.size() > 4) {
- if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
- !std::all_of(operands.begin(), operands.end(), is_displayed)) {
- filter_result = kSomeOperandsOmitted;
- }
- continue;
- }
-
- if (std::any_of(operands.begin(), operands.end(), is_displayed)) {
- for (const HloInstruction* operand : operands) {
- if (!is_displayed(operand)) {
- extra_operands.push_back(operand);
- }
- }
+ // Mark nodes with some omitted as "some operands omitted".
+ if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
+ !std::all_of(operands.begin(), operands.end(), is_displayed)) {
+ filter_result = kSomeOperandsOmitted;
}
}
- for (const HloInstruction* instr : extra_operands) {
- nodes[instr] = kOmitNodeOperands;
- }
- // Some of the nodes in extra_operands may now have all of their inputs
- // present in nodes. We can promote these to normal nodes.
- for (const HloInstruction* instr : extra_operands) {
- const auto& operands = instr->operands();
- if (std::all_of(operands.begin(), operands.end(), is_displayed)) {
- nodes[instr] = kNormalNode;
+ // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
+ // users made it into the graph by other means.
+ for (auto& kv : nodes) {
+ const auto& users = kv.first->users();
+ if (kv.second == kSomeUsersOmitted &&
+ std::all_of(users.begin(), users.end(), is_displayed)) {
+ kv.second = kNormalNode;
}
}
@@ -862,6 +990,10 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
if (it != nodes.end()) {
return it->second;
}
+ // Show all nodes in subcomputations.
+ if (instr->parent() != root->parent()) {
+ return kNormalNode;
+ }
return kHideNode;
});
}
@@ -886,10 +1018,12 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_url = FileGraphRenderer().RenderGraph(
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
} else {
- graph = ComputationToDotGraph(computation, label,
- debug_options.xla_hlo_graph_addresses(),
- debug_options.xla_hlo_graph_layout(),
- hlo_execution_profile, NodeFilter());
+ graph =
+ HloDotDumper(&computation, label,
+ /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
+ /*show_layouts=*/debug_options.xla_hlo_graph_layout(),
+ hlo_execution_profile, NodeFilter())
+ .Dump();
graph_url = GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
@@ -903,11 +1037,12 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
NodeFilter filter = MakeNodeFilter(&node, radius);
- string graph = ComputationToDotGraph(
- *node.parent(), label,
- /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
- /*show_layouts=*/debug_options.xla_hlo_graph_layout(),
- /*hlo_execution_profile=*/nullptr, filter);
+ string graph =
+ HloDotDumper(node.parent(), label,
+ /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
+ /*show_layouts=*/debug_options.xla_hlo_graph_layout(),
+ /*profile=*/nullptr, filter)
+ .Dump();
return GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
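
A minimal usage sketch, not part of the patch: only the DumpNeighborhoodAround(const HloInstruction&, int) signature comes from the hunk above; the xla::hlo_graph_dumper namespace qualification and the surrounding logging are assumptions.

void DumpAround(const xla::HloInstruction& instr) {
  // MakeNodeFilter() decides which distant operands/users get drawn as
  // "omitted" instead of being expanded; radius counts graph hops from instr.
  string rendered =
      xla::hlo_graph_dumper::DumpNeighborhoodAround(instr, /*radius=*/4);
  LOG(INFO) << rendered;  // whatever the registered renderer produced
}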
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 358e611d57..8a6376b2d1 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -112,6 +112,11 @@ bool HloOpcodeIsComparison(HloOpcode opcode);
// Returns true iff the given opcode has variadic operands.
bool HloOpcodeIsVariadic(HloOpcode opcode);
+// Returns the number of HloOpcode values.
+inline const uint32_t HloOpcodeCount() {
+ return static_cast<uint32_t>(HloOpcode::kWhile) + 1;
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
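
A sketch of the dense-table pattern this count enables (the opcode marked here is arbitrary; the reduce-precision pass below applies the same idea):

#include <cstdint>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

// Build a per-opcode boolean table once, then answer queries in O(1).
bool IsMarkedOpcode(xla::HloOpcode opcode) {
  static const std::vector<bool>* table = [] {
    auto* t = new std::vector<bool>(xla::HloOpcodeCount(), false);
    (*t)[static_cast<uint32_t>(xla::HloOpcode::kAdd)] = true;  // arbitrary
    return t;
  }();
  return (*table)[static_cast<uint32_t>(opcode)];
}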
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
index dafefdc491..e083226b14 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -30,14 +31,15 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
for (auto& instruction : computation->instructions()) {
VLOG(3) << "Visited instruction: " << instruction->ToString();
- // For now, ReducePrecision is only implemented for F32 data, so this
+ // For now, ReducePrecision is only implemented for F32 arrays, so this
// ignore instructions that produce other data. In particular, this
// currently ignores instructions producing tuples, even if those tuples
- // contain F32 data inside them. The assumption is that in most cases
+ // contain F32 arrays inside them. The assumption is that in most cases
// equivalent behavior can be obtained by adding ReducePrecision
- // instructions after the instructions that pull the F32 data out of the
- // tuples.
+ // instructions after the instructions that pull the F32 arrays out of
+ // the tuples.
if (instruction->shape().element_type() == PrimitiveType::F32 &&
+ !ShapeUtil::IsScalar(instruction->shape()) &&
should_reduce_output_precision_(instruction->opcode())) {
instructions_to_suffix.push_back(instruction.get());
}
@@ -58,4 +60,33 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
return changed;
}
+ReducePrecisionInsertion::OpcodeFilterFunction
+ReducePrecisionInsertion::make_filter_function(
+ const HloReducePrecisionOptions& reduce_precision_options) {
+ // Implement the filter function with a lookup table.
+ std::vector<bool> filter(HloOpcodeCount(), false);
+ for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
+ filter[opcode] = true;
+ }
+ return [filter](const HloOpcode opcode) {
+ return filter[static_cast<unsigned int>(opcode)];
+ };
+}
+
+HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
+ const HloReducePrecisionOptions::PassTiming pass_timing,
+ const int exponent_bits, const int mantissa_bits,
+ const OpcodeFilterFunction& should_reduce_output_precision) {
+ HloReducePrecisionOptions options;
+ options.set_pass_timing(pass_timing);
+ options.set_exponent_bits(exponent_bits);
+ options.set_mantissa_bits(mantissa_bits);
+ for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) {
+ if (should_reduce_output_precision(static_cast<HloOpcode>(opcode))) {
+ options.add_opcodes_to_suffix(opcode);
+ }
+ }
+ return options;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index e9c8bba031..34b865b9ce 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -42,6 +42,17 @@ class ReducePrecisionInsertion : public HloPassInterface {
: exponent_bits_(exponent_bits),
mantissa_bits_(mantissa_bits),
should_reduce_output_precision_(should_reduce_output_precision) {}
+
+ // Version of the constructor that takes an HloReducePrecisionOptions proto
+  // rather than explicitly enumerated parameters, for convenience when
+ // creating passes based on DebugOptions.
+ explicit ReducePrecisionInsertion(
+ const HloReducePrecisionOptions& reduce_precision_options)
+ : exponent_bits_(reduce_precision_options.exponent_bits()),
+ mantissa_bits_(reduce_precision_options.mantissa_bits()),
+ should_reduce_output_precision_(
+ make_filter_function(reduce_precision_options)) {}
+
~ReducePrecisionInsertion() override{};
tensorflow::StringPiece name() const override {
@@ -52,6 +63,15 @@ class ReducePrecisionInsertion : public HloPassInterface {
// (reduce-precision instructions were inserted).
StatusOr<bool> Run(HloModule* module) override;
+ // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
+ // representation and OpcodeFilterFunction functions.
+ static OpcodeFilterFunction make_filter_function(
+ const HloReducePrecisionOptions& reduce_precision_options);
+ static HloReducePrecisionOptions make_options_proto(
+ const HloReducePrecisionOptions::PassTiming pass_timing,
+ const int exponent_bits, const int mantissa_bits,
+ const OpcodeFilterFunction& should_reduce_output_precision);
+
private:
// Parameters for the precision reduction to be added.
const int exponent_bits_;
@@ -59,7 +79,7 @@ class ReducePrecisionInsertion : public HloPassInterface {
// Function to determine (from the opcode) whether a given instruction should
// have a reduce-precision instruction inserted in its output stream.
- const OpcodeFilterFunction& should_reduce_output_precision_;
+ const OpcodeFilterFunction should_reduce_output_precision_;
};
} // namespace xla
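
A hedged usage sketch (the bit widths, the opcode choice, and `module` are illustrative; the parameter order follows the initializer list shown in the header above):

// Insert reduce-precision ops after every add, to 5 exponent / 10 mantissa bits.
xla::ReducePrecisionInsertion::OpcodeFilterFunction filter =
    [](xla::HloOpcode opcode) { return opcode == xla::HloOpcode::kAdd; };
xla::ReducePrecisionInsertion pass(/*exponent_bits=*/5, /*mantissa_bits=*/10,
                                   filter);
xla::StatusOr<bool> changed = pass.Run(module);  // `module` is an HloModule*

The same filter can be round-tripped through HloReducePrecisionOptions via make_options_proto() and make_filter_function(), which is what makes it convenient to create these passes from DebugOptions.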
diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/compiler/xla/statusor.cc
index 36f08fc99f..72ab67ff81 100644
--- a/tensorflow/compiler/xla/statusor.cc
+++ b/tensorflow/compiler/xla/statusor.cc
@@ -19,28 +19,20 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
-namespace internal {
+namespace internal_statusor {
-Status StatusOrHelper::HandleInvalidStatusCtorArg() {
+void Helper::HandleInvalidStatusCtorArg(Status* status) {
const char* kMessage =
- "Status::OK is not a valid constructor argument to StatusOr<T>";
+ "An OK status is not a valid constructor argument to StatusOr<T>";
LOG(ERROR) << kMessage;
- // In optimized builds, we will fall back to tensorflow::error::INTERNAL.
- return Status(tensorflow::error::INTERNAL, kMessage);
+ // Fall back to tensorflow::error::INTERNAL.
+ *status = ::tensorflow::errors::Internal(kMessage);
}
-Status StatusOrHelper::HandleNullObjectCtorArg() {
- const char* kMessage =
- "NULL is not a valid constructor argument to StatusOr<T*>";
- LOG(ERROR) << kMessage;
- // In optimized builds, we will fall back to tensorflow::error::INTERNAL.
- return Status(tensorflow::error::INTERNAL, kMessage);
-}
-
-void StatusOrHelper::Crash(const Status& status) {
+void Helper::Crash(const Status& status) {
LOG(FATAL) << "Attempting to fetch value instead of handling error "
<< status;
}
-} // namespace internal
+} // namespace internal_statusor
} // namespace xla
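
A small sketch of the guard's observable effect (grounded in the message above and the REQUIRES comments in statusor.h below; the checks themselves are illustrative):

void Demo() {
  // Passing an OK status is invalid: it DCHECK-fails in debug builds; in
  // optimized builds Helper::HandleInvalidStatusCtorArg() rewrites the held
  // status to an INTERNAL error rather than leaving the StatusOr "ok".
  xla::StatusOr<int> bad(tensorflow::Status::OK());
  CHECK(!bad.ok());
  CHECK_EQ(bad.status().code(), tensorflow::error::INTERNAL);
}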
diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h
index d8cd736238..92bcfa0f44 100644
--- a/tensorflow/compiler/xla/statusor.h
+++ b/tensorflow/compiler/xla/statusor.h
@@ -72,216 +72,233 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_STATUSOR_H_
#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor_internals.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
#if defined(__clang__)
// Only clang supports warn_unused_result as a type annotation.
-template <typename T, bool CopyConstructible>
+template <typename T>
class TF_MUST_USE_RESULT StatusOr;
#endif
-template <typename T,
- bool CopyConstructible = std::is_copy_constructible<T>::value>
-class StatusOr {
- template <typename U, bool UC>
+template <typename T>
+class StatusOr : private internal_statusor::StatusOrData<T>,
+ private internal_statusor::TraitsBase<
+ std::is_copy_constructible<T>::value,
+ std::is_move_constructible<T>::value> {
+ template <typename U>
friend class StatusOr;
+ typedef internal_statusor::StatusOrData<T> Base;
+
public:
typedef T element_type;
- // Construct a new StatusOr with Status::UNKNOWN status
- StatusOr();
+ // Constructs a new StatusOr with Status::UNKNOWN status. This is marked
+ // 'explicit' to try to catch cases like 'return {};', where people think
+ // StatusOr<std::vector<int>> will be initialized with an empty vector,
+ // instead of a Status::UNKNOWN status.
+ explicit StatusOr();
+
+  // StatusOr<T> will be copy constructible/assignable if T is copy
+ // constructible.
+ StatusOr(const StatusOr&) = default;
+ StatusOr& operator=(const StatusOr&) = default;
+
+  // StatusOr<T> will be move constructible/assignable if T is move
+ // constructible.
+ StatusOr(StatusOr&&) = default;
+ StatusOr& operator=(StatusOr&&) = default;
+
+ // Conversion copy/move constructor, T must be convertible from U.
+ // TODO(b/62186717): These should not participate in overload resolution if U
+ // is not convertible to T.
+ template <typename U>
+ StatusOr(const StatusOr<U>& other);
+ template <typename U>
+ StatusOr(StatusOr<U>&& other);
- // Construct a new StatusOr with the given non-ok status. After calling
- // this constructor, calls to ValueOrDie() will CHECK-fail.
- //
- // NOTE: Not explicit - we want to use StatusOr<T> as a return
- // value, so it is convenient and sensible to be able to do 'return
- // Status()' when the return type is StatusOr<T>.
- //
- // REQUIRES: status != Status::OK. This requirement is DCHECKed.
- // In optimized builds, passing Status::OK here will have the effect
- // of passing tensorflow::error::INTERNAL as a fallback.
- StatusOr(Status status); // NOLINT
+ // Conversion copy/move assignment operator, T must be convertible from U.
+ template <typename U>
+ StatusOr& operator=(const StatusOr<U>& other);
+ template <typename U>
+ StatusOr& operator=(StatusOr<U>&& other);
- // Construct a new StatusOr with the given value. If T is a plain pointer,
- // value must not be NULL. After calling this constructor, calls to
- // ValueOrDie() will succeed, and calls to status() will return OK.
+ // Constructs a new StatusOr with the given value. After calling this
+ // constructor, calls to ValueOrDie() will succeed, and calls to status() will
+ // return OK.
//
// NOTE: Not explicit - we want to use StatusOr<T> as a return type
// so it is convenient and sensible to be able to do 'return T()'
// when the return type is StatusOr<T>.
//
- // REQUIRES: if T is a plain pointer, value != NULL. This requirement is
- // DCHECKed. In optimized builds, passing a NULL pointer here will have
- // the effect of passing tensorflow::error::INTERNAL as a fallback.
- StatusOr(const T& value); // NOLINT
-
- // Copy constructor.
- StatusOr(const StatusOr& other) = default;
-
- // Conversion copy constructor, T must be copy constructible from U
- template <typename U>
- StatusOr(const StatusOr<U>& other);
-
- // Assignment operator.
- StatusOr& operator=(const StatusOr& other) = default;
+ // REQUIRES: T is copy constructible.
+ StatusOr(const T& value);
- // Conversion assignment operator, T must be assignable from U
- template <typename U>
- StatusOr& operator=(const StatusOr<U>& other);
+ // Constructs a new StatusOr with the given non-ok status. After calling
+ // this constructor, calls to ValueOrDie() will CHECK-fail.
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return
+ // value, so it is convenient and sensible to be able to do 'return
+ // Status()' when the return type is StatusOr<T>.
+ //
+ // REQUIRES: !status.ok(). This requirement is DCHECKed.
+ // In optimized builds, passing Status::OK() here will have the effect
+ // of passing tensorflow::error::INTERNAL as a fallback.
+ StatusOr(const Status& status);
+ StatusOr& operator=(const Status& status);
- // Move constructor and move-assignment operator.
- StatusOr(StatusOr&& other) = default;
- StatusOr& operator=(StatusOr&& other) = default;
+ // TODO(b/62186997): Add operator=(T) overloads.
- // Rvalue-reference overloads of the other constructors and assignment
- // operators, to support move-only types and avoid unnecessary copying.
+ // Similar to the `const T&` overload.
//
- // Implementation note: we could avoid all these rvalue-reference overloads
- // if the existing lvalue-reference overloads took their arguments by value
- // instead. I think this would also let us omit the conversion assignment
- // operator altogether, since we'd get the same functionality for free
- // from the implicit conversion constructor and ordinary assignment.
- // However, this could result in extra copy operations unless we use
- // std::move to avoid them, and we can't use std::move because this code
- // needs to be portable to C++03.
- StatusOr(T&& value); // NOLINT
- template <typename U>
- StatusOr(StatusOr<U>&& other);
+ // REQUIRES: T is move constructible.
+ StatusOr(T&& value);
- // Returns a reference to our status. If this contains a T, then
- // returns Status::OK.
- const Status& status() const { return status_; }
+ // RValue versions of the operations declared above.
+ StatusOr(Status&& status);
+ StatusOr& operator=(Status&& status);
// Returns this->status().ok()
- bool ok() const { return status_.ok(); }
+ bool ok() const { return this->status_.ok(); }
+
+ // Returns a reference to our status. If this contains a T, then
+ // returns Status::OK().
+ const Status& status() const &;
+ Status status() &&;
// Returns a reference to our current value, or CHECK-fails if !this->ok().
- const T& ValueOrDie() const;
- T& ValueOrDie();
+ //
+ // Note: for value types that are cheap to copy, prefer simple code:
+ //
+ // T value = statusor.ValueOrDie();
+ //
+ // Otherwise, if the value type is expensive to copy, but can be left
+ // in the StatusOr, simply assign to a reference:
+ //
+ // T& value = statusor.ValueOrDie(); // or `const T&`
+ //
+ // Otherwise, if the value type supports an efficient move, it can be
+ // used as follows:
+ //
+ // T value = std::move(statusor).ValueOrDie();
+ //
+ // The std::move on statusor instead of on the whole expression enables
+ // warnings about possible uses of the statusor object after the move.
+ // C++ style guide waiver for ref-qualified overloads granted in cl/143176389
+ // See go/ref-qualifiers for more details on such overloads.
+ const T& ValueOrDie() const &;
+ T& ValueOrDie() &;
+ const T&& ValueOrDie() const &&;
+ T&& ValueOrDie() &&;
- // Moves our current value out of this object and returns it, or CHECK-fails
- // if !this->ok().
- // Use of this method is discouraged; prefer std::move(statusor.ValueOrDie())
- // instead.
T ConsumeValueOrDie() { return std::move(ValueOrDie()); }
- private:
- Status status_;
- T value_;
-};
-
-// Partial specialization for when T is not copy-constructible. This uses all
-// methods from the core implementation, but removes copy assignment and copy
-// construction.
-template <typename T>
-class StatusOr<T, false> : public StatusOr<T, true> {
- public:
- // Remove copies.
- StatusOr(const StatusOr& other) = delete;
- StatusOr& operator=(const StatusOr& other) = delete;
- template <typename U>
- StatusOr(const StatusOr<U>& other) = delete;
- StatusOr(const T& value) = delete;
-
- // Use the superclass version for other constructors and operators.
- StatusOr() = default;
- StatusOr(StatusOr&& other) = default;
- StatusOr& operator=(StatusOr&& other) = default;
- StatusOr(T&& value) // NOLINT
- : StatusOr<T, true>::StatusOr(std::move(value)) {}
- StatusOr(Status status) // NOLINT
- : StatusOr<T, true>::StatusOr(std::move(status)) {}
- template <typename U>
- StatusOr(StatusOr<U>&& other) // NOLINT
- : StatusOr<T, true>::StatusOr(std::move(other)) {}
+ // Ignores any errors. This method does nothing except potentially suppress
+ // complaints from any tools that are checking that errors are not dropped on
+ // the floor.
+ void IgnoreError() const;
};
////////////////////////////////////////////////////////////////////////////////
// Implementation details for StatusOr<T>
-namespace internal {
+template <typename T>
+StatusOr<T>::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {}
-class StatusOrHelper {
- public:
- // Move type-agnostic error handling to the .cc.
- static Status HandleInvalidStatusCtorArg();
- static Status HandleNullObjectCtorArg();
- static void Crash(const Status& status);
-
- // Customized behavior for StatusOr<T> vs. StatusOr<T*>
- template <typename T>
- struct Specialize;
-};
+template <typename T>
+StatusOr<T>::StatusOr(const T& value) : Base(value) {}
template <typename T>
-struct StatusOrHelper::Specialize {
- // For non-pointer T, a reference can never be NULL.
- static inline bool IsValueNull(const T& t) { return false; }
-};
+StatusOr<T>::StatusOr(const Status& status) : Base(status) {}
template <typename T>
-struct StatusOrHelper::Specialize<T*> {
- static inline bool IsValueNull(const T* t) { return t == NULL; }
-};
+StatusOr<T>& StatusOr<T>::operator=(const Status& status) {
+ this->Assign(status);
+ return *this;
+}
-} // namespace internal
+template <typename T>
+StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {}
-template <typename T, bool CopyConstructible>
-inline StatusOr<T, CopyConstructible>::StatusOr()
- : status_(tensorflow::error::UNKNOWN, "") {}
+template <typename T>
+StatusOr<T>::StatusOr(Status&& status) : Base(std::move(status)) {}
-template <typename T, bool CopyConstructible>
-inline StatusOr<T, CopyConstructible>::StatusOr(Status status)
- : status_(std::move(status)) {
- if (status_.ok()) {
- status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg();
- }
+template <typename T>
+StatusOr<T>& StatusOr<T>::operator=(Status&& status) {
+ this->Assign(std::move(status));
+ return *this;
}
-template <typename T, bool CopyConstructible>
-inline StatusOr<T, CopyConstructible>::StatusOr(const T& value)
- : value_(value) {
- if (internal::StatusOrHelper::Specialize<T>::IsValueNull(value)) {
- status_ = internal::StatusOrHelper::HandleNullObjectCtorArg();
- }
-}
+template <typename T>
+template <typename U>
+inline StatusOr<T>::StatusOr(const StatusOr<U>& other)
+ : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
-template <typename T, bool CopyConstructible>
+template <typename T>
template <typename U>
-inline StatusOr<T, CopyConstructible>::StatusOr(const StatusOr<U>& other)
- : status_(other.status_), value_(other.value_) {}
-
-template <typename T, bool CopyConstructible>
-inline StatusOr<T, CopyConstructible>::StatusOr(T&& value)
- : value_(std::move(value)) {
- if (internal::StatusOrHelper::Specialize<T>::IsValueNull(value_)) {
- status_ = internal::StatusOrHelper::HandleNullObjectCtorArg();
- }
+inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
+ if (other.ok())
+ this->Assign(other.ValueOrDie());
+ else
+ this->Assign(other.status());
+ return *this;
}
-template <typename T, bool CopyConstructible>
+template <typename T>
template <typename U>
-inline StatusOr<T, CopyConstructible>::StatusOr(StatusOr<U>&& other)
- : status_(std::move(other.status_)), value_(std::move(other.value_)) {}
+inline StatusOr<T>::StatusOr(StatusOr<U>&& other)
+ : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
-template <typename T, bool CopyConstructible>
-inline const T& StatusOr<T, CopyConstructible>::ValueOrDie() const {
- if (!ok()) {
- internal::StatusOrHelper::Crash(status());
+template <typename T>
+template <typename U>
+inline StatusOr<T>& StatusOr<T>::operator=(StatusOr<U>&& other) {
+ if (other.ok()) {
+ this->Assign(std::move(other).ValueOrDie());
+ } else {
+ this->Assign(std::move(other).status());
}
- return value_;
+ return *this;
}
-template <typename T, bool CopyConstructible>
-inline T& StatusOr<T, CopyConstructible>::ValueOrDie() {
- if (!status_.ok()) {
- internal::StatusOrHelper::Crash(status());
- }
- return value_;
+template <typename T>
+const Status& StatusOr<T>::status() const & {
+ return this->status_;
+}
+template <typename T>
+Status StatusOr<T>::status() && {
+ return ok() ? Status::OK() : std::move(this->status_);
+}
+
+template <typename T>
+const T& StatusOr<T>::ValueOrDie() const & {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+T& StatusOr<T>::ValueOrDie() & {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+const T&& StatusOr<T>::ValueOrDie() const && {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+T&& StatusOr<T>::ValueOrDie() && {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+void StatusOr<T>::IgnoreError() const {
+ // no-op
}
} // namespace xla
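As a quick orientation to the reworked accessors above, here is a minimal usage sketch. It is illustration only: MakeWidget is a hypothetical factory, the include paths assume the usual TensorFlow layout, and it relies on StatusOr's implicit constructors from Status and from T&&. The point is that the &&-qualified overloads let a temporary StatusOr hand its payload out by move instead of by copy.

#include <memory>

#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical factory, used only to exercise the accessor overloads.
xla::StatusOr<std::unique_ptr<int>> MakeWidget(bool ok) {
  if (!ok) {
    return tensorflow::errors::InvalidArgument("no widget");  // Status ctor
  }
  return std::unique_ptr<int>(new int(42));  // T&& ctor
}

void Demo() {
  // ValueOrDie() on an rvalue StatusOr moves the payload out.
  std::unique_ptr<int> widget = MakeWidget(true).ValueOrDie();
  // status() on an rvalue StatusOr returns the Status by value.
  tensorflow::Status status = MakeWidget(false).status();
  (void)widget;
  (void)status;
}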
diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/compiler/xla/statusor_internals.h
new file mode 100644
index 0000000000..a2fda5bb3c
--- /dev/null
+++ b/tensorflow/compiler/xla/statusor_internals.h
@@ -0,0 +1,245 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace internal_statusor {
+
+class Helper {
+ public:
+ // Move type-agnostic error handling to the .cc.
+ static void HandleInvalidStatusCtorArg(Status*);
+ TF_ATTRIBUTE_NORETURN static void Crash(const Status& status);
+};
+
+// Construct an instance of T in `p` through placement new, passing Args... to
+// the constructor.
+// This abstraction is here mostly for the gcc performance fix.
+template <typename T, typename... Args>
+void PlacementNew(void* p, Args&&... args) {
+#if defined(__GNUC__) && !defined(__clang__)
+ // Teach gcc that 'p' cannot be null, fixing code size issues.
+ if (p == nullptr) __builtin_unreachable();
+#endif
+ new (p) T(std::forward<Args>(args)...);
+}
+
+// Helper base class to hold the data and all operations.
+// We move all this to a base class to allow mixing with the appropriate
+// TraitsBase specialization.
+template <typename T>
+class StatusOrData {
+ template <typename U>
+ friend class StatusOrData;
+
+ public:
+ StatusOrData() = delete;
+
+ StatusOrData(const StatusOrData& other) {
+ if (other.ok()) {
+ MakeValue(other.data_);
+ MakeStatus();
+ } else {
+ MakeStatus(other.status_);
+ }
+ }
+
+ StatusOrData(StatusOrData&& other) noexcept {
+ if (other.ok()) {
+ MakeValue(std::move(other.data_));
+ MakeStatus();
+ } else {
+ MakeStatus(std::move(other.status_));
+ }
+ }
+
+ template <typename U>
+ StatusOrData(const StatusOrData<U>& other) {
+ if (other.ok()) {
+ MakeValue(other.data_);
+ MakeStatus();
+ } else {
+ MakeStatus(other.status_);
+ }
+ }
+
+ template <typename U>
+ StatusOrData(StatusOrData<U>&& other) {
+ if (other.ok()) {
+ MakeValue(std::move(other.data_));
+ MakeStatus();
+ } else {
+ MakeStatus(std::move(other.status_));
+ }
+ }
+
+ explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); }
+ explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); }
+
+ explicit StatusOrData(const Status& status) : status_(status) {
+ EnsureNotOk();
+ }
+ explicit StatusOrData(Status&& status) : status_(std::move(status)) {
+ EnsureNotOk();
+ }
+
+ StatusOrData& operator=(const StatusOrData& other) {
+ if (this == &other) return *this;
+ if (other.ok())
+ Assign(other.data_);
+ else
+ Assign(other.status_);
+ return *this;
+ }
+
+ StatusOrData& operator=(StatusOrData&& other) {
+ if (this == &other) return *this;
+ if (other.ok())
+ Assign(std::move(other.data_));
+ else
+ Assign(std::move(other.status_));
+ return *this;
+ }
+
+ ~StatusOrData() {
+ if (ok()) {
+ status_.~Status();
+ data_.~T();
+ } else {
+ status_.~Status();
+ }
+ }
+
+ void Assign(const T& value) {
+ if (ok()) {
+ data_.~T();
+ MakeValue(value);
+ } else {
+ MakeValue(value);
+ status_ = Status::OK();
+ }
+ }
+
+ void Assign(T&& value) {
+ if (ok()) {
+ data_.~T();
+ MakeValue(std::move(value));
+ } else {
+ MakeValue(std::move(value));
+ status_ = Status::OK();
+ }
+ }
+
+ void Assign(const Status& status) {
+ Clear();
+ status_ = status;
+ EnsureNotOk();
+ }
+
+ void Assign(Status&& status) {
+ Clear();
+ status_ = std::move(status);
+ EnsureNotOk();
+ }
+
+ bool ok() const { return status_.ok(); }
+
+ protected:
+ // status_ will always be active after the constructor.
+ // We make it a union to be able to initialize exactly how we need without
+ // waste.
+ // E.g. in the copy constructor we use the default constructor of Status in
+ // the ok() path to avoid an extra Ref call.
+ union {
+ Status status_;
+ };
+
+ // data_ is active iff status_.ok()==true
+ struct Dummy {};
+ union {
+ // When T is const, we need some non-const object we can cast to void* for
+ // the placement new. dummy_ is that object.
+ Dummy dummy_;
+ T data_;
+ };
+
+ void Clear() {
+ if (ok()) data_.~T();
+ }
+
+ void EnsureOk() const {
+ if (!ok()) Helper::Crash(status_);
+ }
+
+ void EnsureNotOk() {
+ if (ok()) Helper::HandleInvalidStatusCtorArg(&status_);
+ }
+
+ // Construct the value (i.e. data_) through placement new with the passed
+ // argument.
+ template <typename Arg>
+ void MakeValue(Arg&& arg) {
+ internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg));
+ }
+
+ // Construct the status (i.e. status_) through placement new with the passed
+ // argument.
+ template <typename... Args>
+ void MakeStatus(Args&&... args) {
+ internal_statusor::PlacementNew<Status>(&status_,
+ std::forward<Args>(args)...);
+ }
+};
+
+// Helper base class to allow implicitly deleted constructors and assignment
+// operations in StatusOr.
+// TraitsBase will explicitly delete what it can't support and StatusOr will
+// inherit that behavior implicitly.
+template <bool Copy, bool Move>
+struct TraitsBase {
+ TraitsBase() = default;
+ TraitsBase(const TraitsBase&) = default;
+ TraitsBase(TraitsBase&&) = default;
+ TraitsBase& operator=(const TraitsBase&) = default;
+ TraitsBase& operator=(TraitsBase&&) = default;
+};
+
+template <>
+struct TraitsBase<false, true> {
+ TraitsBase() = default;
+ TraitsBase(const TraitsBase&) = delete;
+ TraitsBase(TraitsBase&&) = default;
+ TraitsBase& operator=(const TraitsBase&) = delete;
+ TraitsBase& operator=(TraitsBase&&) = default;
+};
+
+template <>
+struct TraitsBase<false, false> {
+ TraitsBase() = default;
+ TraitsBase(const TraitsBase&) = delete;
+ TraitsBase(TraitsBase&&) = delete;
+ TraitsBase& operator=(const TraitsBase&) = delete;
+ TraitsBase& operator=(TraitsBase&&) = delete;
+};
+
+} // namespace internal_statusor
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
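The TraitsBase mixin above is a standard trick for making a wrapper's copyability and movability track its payload: StatusOr derives from the specialization that matches T, and the deleted base operations implicitly delete the corresponding operations of the derived class. A self-contained sketch of the same idea, independent of the XLA headers (EnableCopyMove and Wrapper are made-up names, and only the move-only specialization is shown for brevity):

#include <memory>
#include <type_traits>
#include <utility>

// Primary template: copy and move both allowed.
template <bool Copy, bool Move>
struct EnableCopyMove {};

// Move-only specialization: deleting the base's copy operations implicitly
// deletes the copy operations of any class deriving from it.
template <>
struct EnableCopyMove<false, true> {
  EnableCopyMove() = default;
  EnableCopyMove(const EnableCopyMove&) = delete;
  EnableCopyMove(EnableCopyMove&&) = default;
  EnableCopyMove& operator=(const EnableCopyMove&) = delete;
  EnableCopyMove& operator=(EnableCopyMove&&) = default;
};

template <typename T>
class Wrapper : private EnableCopyMove<std::is_copy_constructible<T>::value,
                                       std::is_move_constructible<T>::value> {
 public:
  explicit Wrapper(T value) : value_(std::move(value)) {}

 private:
  T value_;
};

static_assert(std::is_copy_constructible<Wrapper<int>>::value,
              "copyable payload keeps the wrapper copyable");
static_assert(!std::is_copy_constructible<Wrapper<std::unique_ptr<int>>>::value,
              "move-only payload makes the wrapper move-only");

StatusOr needs this explicit traits base because its copy and move logic is hand-written in StatusOrData (the payload lives in a union), so the compiler cannot re-derive T's restrictions from the members; TraitsBase is what re-imposes them on the public type.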
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc
index f8555113f8..5fa2211ac6 100644
--- a/tensorflow/compiler/xla/statusor_test.cc
+++ b/tensorflow/compiler/xla/statusor_test.cc
@@ -29,8 +29,6 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::Status;
-
class Base1 {
public:
virtual ~Base1() {}
@@ -59,6 +57,14 @@ class CopyNoAssign {
const CopyNoAssign& operator=(const CopyNoAssign&);
};
+class NoDefaultConstructor {
+ public:
+ explicit NoDefaultConstructor(int foo);
+};
+
+static_assert(!std::is_default_constructible<NoDefaultConstructor>(),
+ "Should not be default-constructible.");
+
StatusOr<std::unique_ptr<int>> ReturnUniquePtr() {
// Uses implicit constructor from T&&
return std::unique_ptr<int>(new int(0));
@@ -69,6 +75,18 @@ TEST(StatusOr, ElementType) {
static_assert(std::is_same<StatusOr<char>::element_type, char>(), "");
}
+TEST(StatusOr, TestNoDefaultConstructorInitialization) {
+ // Explicitly initialize it with an error code.
+ StatusOr<NoDefaultConstructor> statusor(tensorflow::errors::Cancelled(""));
+ EXPECT_FALSE(statusor.ok());
+ EXPECT_EQ(statusor.status().code(), tensorflow::error::CANCELLED);
+
+ // Default construction of StatusOr initializes it with an UNKNOWN error code.
+ StatusOr<NoDefaultConstructor> statusor2;
+ EXPECT_FALSE(statusor2.ok());
+ EXPECT_EQ(statusor2.status().code(), tensorflow::error::UNKNOWN);
+}
+
TEST(StatusOr, TestMoveOnlyInitialization) {
StatusOr<std::unique_ptr<int>> thing(ReturnUniquePtr());
ASSERT_TRUE(thing.ok());
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 074e28cec7..d692a81032 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -308,6 +308,137 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
ErrorSpec(0.01, 1));
}
+// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
+XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
+ DISABLED_ON_GPU(RandomizedGradTests)))) {
+ float epsilon = 0.001;
+ ComputationBuilder builder(client_, TestName());
+ const std::vector<int64>& bounds = GetParam().bounds;
+ Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input_array.FillRandom(GetParam().random_value_var,
+ GetParam().random_value_mean);
+
+ Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
+ grad_output_array.FillRandom(GetParam().random_value_var,
+ GetParam().random_value_mean);
+
+ const int64 feature_index = GetParam().feature_index;
+ const int64 num_elements_per_feature =
+ Product(bounds) / bounds[feature_index];
+ const int64 feature_bound = bounds[feature_index];
+ std::vector<float> scale(feature_bound, 2);
+
+ auto input_squared =
+ ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
+ std::vector<int64> reduce_dims;
+ for (int64 i = 0; i < bounds.size(); ++i) {
+ if (i != feature_index) {
+ reduce_dims.push_back(i);
+ }
+ }
+
+ auto sum =
+ ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ auto sum_squared =
+ ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ std::vector<float> mean(feature_bound);
+
+ for (int64 i = 0; i < feature_bound; ++i) {
+ mean[i] = sum[i] / num_elements_per_feature;
+ }
+
+ std::vector<float> mean_square(feature_bound);
+ for (int64 i = 0; i < feature_bound; ++i) {
+ mean_square[i] = mean[i] * mean[i];
+ }
+
+ std::vector<float> square_mean(feature_bound);
+ for (int64 i = 0; i < feature_bound; ++i) {
+ square_mean[i] = sum_squared[i] / num_elements_per_feature;
+ }
+
+ std::vector<float> var(feature_bound);
+ for (int64 i = 0; i < feature_bound; ++i) {
+ var[i] = square_mean[i] - mean_square[i];
+ }
+
+ Array4D<float> mean_4D =
+ *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
+ auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+ auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+
+ auto var_add_epsilon = *ReferenceUtil::MapArray4D(
+ var_4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
+
+ auto grad_output_times_var =
+ *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
+ [](float a, float b) { return a * b; });
+
+ auto grad_activation = *ReferenceUtil::MapArray4D(
+ grad_output_times_var, scale_4D, [](float a, float b) { return a * b; });
+
+ auto activation_shifted = *ReferenceUtil::MapArray4D(
+ input_array, mean_4D, [](float a, float b) { return a - b; });
+
+ auto grad_scale_before_reduction =
+ *ReferenceUtil::MapArray4D(grad_output_times_var, activation_shifted,
+ [](float a, float b) { return a * b; });
+
+ auto grad_scale = ReferenceUtil::Reduce4DTo1D(
+ grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ auto grad_offset =
+ ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ auto expected_grad_activation =
+ Literal::CreateR4FromArray4D<float>(grad_activation);
+
+ auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto scale_literal = Literal::CreateR1<float>(scale);
+ auto mean_literal = Literal::CreateR1<float>(mean);
+ auto var_literal = Literal::CreateR1<float>(var);
+ auto grad_output_literal =
+ Literal::CreateR4FromArray4D<float>(grad_output_array);
+
+ auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
+ auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
+ auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
+ auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
+ auto grad_output_parameter =
+ builder.Parameter(4, grad_output_literal->shape(), "grad_output");
+
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> scale_data =
+ client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> mean_data =
+ client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> var_data =
+ client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> grad_output_data =
+ client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+
+ auto t = builder.BatchNormGrad(input_parameter, scale_parameter,
+ mean_parameter, var_parameter,
+ grad_output_parameter, epsilon, feature_index);
+
+ auto expected =
+ *Literal::MakeTuple({expected_grad_activation.get(),
+ Literal::CreateR1<float>(grad_scale).get(),
+ Literal::CreateR1<float>(grad_offset).get()});
+
+ ComputeAndCompareTuple(&builder, expected,
+ {input_data.get(), scale_data.get(), mean_data.get(),
+ var_data.get(), grad_output_data.get()},
+ ErrorSpec(0.01, 1));
+}
+
INSTANTIATE_TEST_CASE_P(
BatchNormTest_Instantiation, BatchNormTest,
::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f},
@@ -319,6 +450,7 @@ INSTANTIATE_TEST_CASE_P(
BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f},
BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f},
BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f},
+ BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f},
BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f},
BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000},
@@ -446,6 +578,37 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
ErrorSpec(0.1));
}
+// TODO(b/62764704): Implement on CPU and GPU. Disabled on 2017-07-11.
+XLA_TEST_F(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
+ DISABLED_ON_GPU(BatchNormGradBasic)))) {
+ const int kFeatureIndex = 2;
+ ComputationBuilder builder(client_, TestName());
+
+ auto operand =
+ builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));
+
+ auto scale = builder.ConstantR1<float>({1.0f, 1.0f});
+
+ auto mean = builder.ConstantR1<float>({0.0f, 0.0f});
+
+ auto var = builder.ConstantR1<float>({1.0f, 1.0f});
+
+ auto grad_output = builder.ConstantR4FromArray4D<float>(
+ {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
+
+ builder.BatchNormGrad(operand, scale, mean, var, grad_output,
+ /*epsilon=*/0.0, kFeatureIndex);
+
+ auto expected = *Literal::MakeTuple(
+ {Literal::CreateR4<float>(
+ {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}})
+ .get(),
+ Literal::CreateR1<float>({0, 0}).get(),
+ Literal::CreateR1<float>({16, 20}).get()});
+
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+}
+
} // namespace
} // namespace xla
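As a hand check of BatchNormGradBasic above, the expected literals follow from the same reference formulas used in RandomizedGradTests: with scale = 1, mean = 0, var = 1 and epsilon = 0, sqrt(var + epsilon) = 1, so the expected grad_activation equals grad_output; the shifted activation is 0 - 0 = 0, so grad_scale reduces to {0, 0}; and grad_offset sums grad_output over the non-feature dimensions at feature_index 2, giving 1 + 3 + 5 + 7 = 16 and 2 + 4 + 6 + 8 = 20 for the two features.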
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 576c1c703d..9e85e35707 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -44,295 +44,310 @@ namespace {
class DynamicSliceTest : public ClientLibraryTestBase {
protected:
- template <typename IndexT>
+ template <typename IndexT, typename DataT>
void TestR1() {
// Slice at dimension start.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {0}, {5},
- {0.0, 1.0, 2.0, 3.0, 4.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {0}, {5}, {0, 1, 2, 3, 4});
// Slice in the middle.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {3},
- {2.0, 3.0, 4.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4});
// Slice at dimension boundaries.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {5}, {3},
- {5.0, 6.0, 7.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7});
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4},
- {6.0, 7.0, 0.0, 1.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1});
// Zero element slice.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {0}, {});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {});
}
- template <typename IndexT>
+ template <typename IndexT, typename DataT>
void TestR2() {
// Slice at dimension start.
- RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {0, 0}, {2, 2}, {{1.0f, 2.0f}, {4.0f, 5.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 2},
+ {{1, 2}, {4, 5}});
// Slice in the middle.
- RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {1, 1}, {2, 1}, {{5.0f}, {8.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
+ {{5}, {8}});
// Slice at dimension boundaries.
- RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {1, 1}, {2, 1}, {{5.0f}, {8.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1},
+ {{5}, {8}});
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
- RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {1, 1}, {3, 3},
- {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3},
+ {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}});
// Zero element slice: 2x0.
- RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {0, 0}, {2, 0}, {{}, {}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0},
+ {{}, {}});
// Zero element slice: 0x2.
- RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {0, 0}, {0, 2}, Array2D<float>(0, 2));
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2},
+ Array2D<DataT>(0, 2));
}
- template <typename IndexT>
+ template <typename IndexT, typename DataT>
void TestR3() {
// R3 Shape: [2, 3, 2]
// clang-format off
// Slice at dimension start.
- RunR3<IndexT>(
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
- {0, 0, 0}, {2, 1, 2},
- {{{1.0f, 2.0f}}, {{7.0f, 8.0f}}});
+ RunR3<IndexT, DataT>(
+ {{{1, 2}, {3, 4}, {5, 6}},
+ {{7, 8}, {9, 10}, {11, 12}}},
+ {0, 0, 0}, {2, 1, 2},
+ {{{1, 2}}, {{7, 8}}});
// Slice in the middle.
- RunR3<IndexT>(
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
- {0, 1, 1}, {2, 2, 1},
- {{{4.0f}, {6.0f}}, {{10.0f}, {12.0f}}});
+ RunR3<IndexT, DataT>(
+ {{{1, 2}, {3, 4}, {5, 6}},
+ {{7, 8}, {9, 10}, {11, 12}}},
+ {0, 1, 1}, {2, 2, 1},
+ {{{4}, {6}}, {{10}, {12}}});
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
- RunR3<IndexT>(
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
- {0, 2, 1}, {2, 1, 2},
- {{{6.0f, 5.0f}}, {{12.0f, 11.0f}}});
+ RunR3<IndexT, DataT>(
+ {{{1, 2}, {3, 4}, {5, 6}},
+ {{7, 8}, {9, 10}, {11, 12}}},
+ {0, 2, 1}, {2, 1, 2},
+ {{{6, 5}}, {{12, 11}}});
// clang-format on
}
- template <typename IndexT>
- void RunR1(const std::vector<float>& input_values,
+ template <typename IndexT, typename DataT>
+ void RunR1(tensorflow::gtl::ArraySlice<DataT> input_values,
const std::vector<IndexT> slice_starts,
const std::vector<int64>& slice_sizes,
- const std::vector<float>& expected_values) {
+ tensorflow::gtl::ArraySlice<DataT> expected_values) {
ComputationBuilder builder(client_, TestName());
// Initialize and transfer dynamic slice start indices parameter.
ComputationDataHandle starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantR1<float>(input_values);
+ auto input = builder.ConstantR1<DataT>(input_values);
builder.DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
- ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
- ErrorSpec(0.000001));
+ ComputeAndCompareR1<DataT>(&builder, expected_values, {start_data.get()});
}
- template <typename IndexT>
- void RunR2(const Array2D<float>& input_values,
+ template <typename IndexT, typename DataT>
+ void RunR2(const Array2D<DataT>& input_values,
const std::vector<IndexT> slice_starts,
const std::vector<int64>& slice_sizes,
- const Array2D<float>& expected_values) {
+ const Array2D<DataT>& expected_values) {
ComputationBuilder builder(client_, TestName());
// Initialize and transfer dynamic slice start indices parameter.
ComputationDataHandle starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantR2FromArray2D<float>(input_values);
+ auto input = builder.ConstantR2FromArray2D<DataT>(input_values);
builder.DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
- ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
- ErrorSpec(0.000001));
+ ComputeAndCompareR2<DataT>(&builder, expected_values, {start_data.get()});
}
- template <typename IndexT>
- void RunR3(const Array3D<float>& input_values,
+ template <typename IndexT, typename DataT>
+ void RunR3(const Array3D<DataT>& input_values,
const std::vector<IndexT> slice_starts,
const std::vector<int64>& slice_sizes,
- const Array3D<float>& expected_values) {
+ const Array3D<DataT>& expected_values) {
ComputationBuilder builder(client_, TestName());
// Initialize and transfer dynamic slice start indices parameter.
ComputationDataHandle starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ auto input = builder.ConstantR3FromArray3D<DataT>(input_values);
builder.DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
- ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
- ErrorSpec(0.000001));
+ ComputeAndCompareR3<DataT>(&builder, expected_values, {start_data.get()});
}
};
-XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32, int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64, float>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64, double>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32, float>(); }
-XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64>(); }
+XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64, double>(); }
-XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64>(); }
+XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
-XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32, int32>(); }
-XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64>(); }
+XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64, float>(); }
-XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64>(); }
+XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64, double>(); }
-XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
+ // Slice at dimension start.
+ RunR1<int32, bool>({true, false, false, true, false, true, true, false}, {0},
+ {5}, {true, false, false, true, false});
+ // Slice in the middle.
+ RunR1<int32, bool>({true, false, false, true, false, true, true, false}, {2},
+ {3}, {false, true, false});
+ // Slice at dimension boundaries.
+ RunR1<int32, bool>({true, false, false, true, false, true, true, false}, {5},
+ {3}, {true, true, false});
+ // Zero element slice.
+ RunR1<int32, bool>({true, false, false, true, false, true, true, false}, {2},
+ {0}, {});
+}
-XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R2Pred) {
+ // Slice at dimension start.
+ RunR2<int32, bool>(
+ {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0},
+ {2, 2}, {{true, false}, {false, false}});
+ // Slice in the middle.
+ RunR2<int32, bool>(
+ {{true, false, true}, {false, false, true}, {true, true, false}}, {1, 1},
+ {2, 1}, {{false}, {true}});
+ // Slice at dimension boundaries.
+ RunR2<int32, bool>(
+ {{true, false, true}, {false, false, true}, {true, true, false}}, {1, 1},
+ {2, 1}, {{false}, {true}});
+ // Zero element slice: 2x0.
+ RunR2<int32, bool>(
+ {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0},
+ {2, 0}, {{}, {}});
+ // Zero element slice: 0x2.
+ RunR2<int32, bool>(
+ {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0},
+ {0, 2}, Array2D<bool>(0, 2));
+}
-XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64>(); }
+XLA_TEST_F(DynamicSliceTest, Int32R3Pred) {
+ // R3 Shape: [2, 3, 2]
+ // clang-format off
+
+ // Slice at dimension start.
+ RunR3<int32, bool>(
+ {{{true, false}, {false, true}, {true, true}},
+ {{false, true}, {true, false}, {false, false}}},
+ {0, 0, 0}, {2, 1, 2},
+ {{{true, false}}, {{false, true}}});
+
+ // Slice in the middle.
+ RunR3<int32, bool>(
+ {{{true, false}, {false, true}, {true, true}},
+ {{false, true}, {true, false}, {false, false}}},
+ {0, 1, 1}, {2, 2, 1},
+ {{{true}, {true}}, {{false}, {false}}});
+
+ // clang-format on
+}
class DynamicUpdateSliceTest : public ClientLibraryTestBase {
protected:
- template <typename IndexT>
+ template <typename IndexT, typename DataT>
void TestR1() {
- // clang-format off
// Slice at dimension start.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
- {8.0, 9.0, 10.0}, {0},
- {8.0, 9.0, 10.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0},
+ {8, 9, 10, 3, 4, 5, 6, 7});
// Slice in the middle.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
- {8.0, 9.0, 10.0}, {2},
- {0.0, 1.0, 8.0, 9.0, 10.0, 5.0, 6.0, 7.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {2},
+ {0, 1, 8, 9, 10, 5, 6, 7});
// Slice at dimension boundaries.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
- {8.0, 9.0, 10.0}, {5},
- {0.0, 1.0, 2.0, 3.0, 4.0, 8.0, 9.0, 10.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5},
+ {0, 1, 2, 3, 4, 8, 9, 10});
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
- {8.0, 9.0, 10.0}, {6},
- {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0});
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6},
+ {0, 1, 2, 3, 4, 5, 8, 9});
// Zero-sized update.
- RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
- {}, {2},
- {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
- // clang-format on
+ RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2},
+ {0, 1, 2, 3, 4, 5, 6, 7});
}
- template <typename IndexT>
+ template <typename IndexT, typename DataT>
void TestR2() {
- // clang-format off
// Slice at dimension start.
- RunR2<IndexT>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {{10.0f, 11.0f}}, {0, 0},
- {{10.0f, 11.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {0, 0},
+ {{10, 11, 3}, {4, 5, 6}, {7, 8, 9}});
// Slice in the middle.
- RunR2<IndexT>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {{10.0f, 11.0f}}, {1, 1},
- {{1.0f, 2.0f, 3.0f}, {4.0f, 10.0f, 11.0f}, {7.0f, 8.0f, 9.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {1, 1},
+ {{1, 2, 3}, {4, 10, 11}, {7, 8, 9}});
// Slice at dimension boundaries.
- RunR2<IndexT>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {{10.0f, 11.0f}}, {2, 1},
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 10.0f, 11.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1},
+ {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}});
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
- RunR2<IndexT>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {{10.0f, 11.0f}}, {2, 2},
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}});
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2},
+ {{1, 2, 3}, {4, 5, 6}, {7, 8, 10}});
// Zero-sized update.
- RunR2<IndexT>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
- {{}}, {2, 1},
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
- // clang-format on
+ RunR2<IndexT, DataT>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1},
+ {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
}
- template <typename IndexT>
+ template <typename IndexT, typename DataT>
void TestR3() {
// R3 Shape: [2, 3, 2]
- // clang-format off
// Slice at dimension start.
- RunR3<IndexT>(
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
- {{{13.0f, 14.0f}, {15.0f, 16.0f}},
- {{17.0f, 18.0f}, {19.0f, 20.0f}}},
- {0, 0, 0},
- {{{13.0f, 14.0f}, {15.0f, 16.0f}, {5.0f, 6.0f}},
- {{17.0f, 18.0f}, {19.0f, 20.0f}, {11.0f, 12.0f}}});
+ RunR3<IndexT, DataT>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}},
+ {{{13, 14}, {15, 16}}, {{17, 18}, {19, 20}}}, {0, 0, 0},
+ {{{13, 14}, {15, 16}, {5, 6}}, {{17, 18}, {19, 20}, {11, 12}}});
// Slice in the middle.
- RunR3<IndexT>(
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
- {{{13.0f}, {15.0f}}},
- {1, 1, 1},
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 13.0f}, {11.0f, 15.0f}}});
+ RunR3<IndexT, DataT>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
+ {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}});
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
- RunR3<IndexT>(
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
- {{{13.0f}, {15.0f}}},
- {1, 2, 1},
- {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
- {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 13.0f}}});
- // clang-format on
+ RunR3<IndexT, DataT>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}},
+ {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 13}}});
}
- template <typename IndexT>
- void RunR1(const std::vector<float>& input_values,
- const std::vector<float>& update_values,
+ template <typename IndexT, typename DataT>
+ void RunR1(tensorflow::gtl::ArraySlice<DataT> input_values,
+ tensorflow::gtl::ArraySlice<DataT> update_values,
const std::vector<IndexT> slice_starts,
- const std::vector<float>& expected_values) {
+ tensorflow::gtl::ArraySlice<DataT> expected_values) {
ComputationBuilder builder(client_, TestName());
// Initialize and transfer dynamic slice start indices parameter.
ComputationDataHandle starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantR1<float>(input_values);
- auto update = builder.ConstantR1<float>(update_values);
+ auto input = builder.ConstantR1<DataT>(input_values);
+ auto update = builder.ConstantR1<DataT>(update_values);
builder.DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
- ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
- ErrorSpec(0.000001));
+ ComputeAndCompareR1<DataT>(&builder, expected_values, {start_data.get()});
}
- template <typename IndexT>
- void RunR2(const Array2D<float>& input_values,
- const Array2D<float>& update_values,
+ template <typename IndexT, typename DataT>
+ void RunR2(const Array2D<DataT>& input_values,
+ const Array2D<DataT>& update_values,
const std::vector<IndexT> slice_starts,
- const Array2D<float>& expected_values) {
+ const Array2D<DataT>& expected_values) {
ComputationBuilder builder(client_, TestName());
// Initialize and transfer dynamic slice start indices parameter.
ComputationDataHandle starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantR2FromArray2D<float>(input_values);
- auto update = builder.ConstantR2FromArray2D<float>(update_values);
+ auto input = builder.ConstantR2FromArray2D<DataT>(input_values);
+ auto update = builder.ConstantR2FromArray2D<DataT>(update_values);
builder.DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
- ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
- ErrorSpec(0.000001));
+ ComputeAndCompareR2<DataT>(&builder, expected_values, {start_data.get()});
}
- template <typename IndexT>
- void RunR3(const Array3D<float>& input_values,
- const Array3D<float>& update_values,
+ template <typename IndexT, typename DataT>
+ void RunR3(const Array3D<DataT>& input_values,
+ const Array3D<DataT>& update_values,
const std::vector<IndexT> slice_starts,
- const Array3D<float>& expected_values) {
+ const Array3D<DataT>& expected_values) {
ComputationBuilder builder(client_, TestName());
// Initialize and transfer dynamic slice start indices parameter.
ComputationDataHandle starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantR3FromArray3D<float>(input_values);
- auto update = builder.ConstantR3FromArray3D<float>(update_values);
+ auto input = builder.ConstantR3FromArray3D<DataT>(input_values);
+ auto update = builder.ConstantR3FromArray3D<DataT>(update_values);
builder.DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
- ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
- ErrorSpec(0.000001));
+ ComputeAndCompareR3<DataT>(&builder, expected_values, {start_data.get()});
}
void RunR3Contiguous(std::vector<int32> operand_shape, int32 index,
@@ -393,23 +408,81 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
};
-XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32, float>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64, float>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64, double>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32, float>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64, int64>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) {
+ // Slice at dimension start.
+ RunR1<int32, bool>({false, false, true, true, false, true, true, false},
+ {true, true, false}, {0},
+ {true, true, false, true, false, true, true, false});
+ // Slice in the middle.
+ RunR1<int32, bool>({false, false, true, true, false, true, true, false},
+ {false, true, true}, {2},
+ {false, false, false, true, true, true, true, false});
+ // Slice at dimension boundaries.
+ RunR1<int32, bool>({false, false, true, true, false, true, true, false},
+ {false, true, true}, {5},
+ {false, false, true, true, false, false, true, true});
+ // Zero-sized update.
+ RunR1<int32, bool>({false, false, true, true, false, true, true, false}, {},
+ {2}, {false, false, true, true, false, true, true, false});
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R2Pred) {
+ // Slice at dimension start.
+ RunR2<int32, bool>(
+ {{false, true, false}, {true, false, true}, {false, true, true}},
+ {{true, false}}, {0, 0},
+ {{true, false, false}, {true, false, true}, {false, true, true}});
+ // Slice in the middle.
+ RunR2<int32, bool>(
+ {{false, true, false}, {true, false, true}, {false, true, true}},
+ {{true, false}}, {1, 1},
+ {{false, true, false}, {true, true, false}, {false, true, true}});
+ // Slice at dimension boundaries.
+ RunR2<int32, bool>(
+ {{false, true, false}, {true, false, true}, {false, true, true}},
+ {{true, false}}, {2, 1},
+ {{false, true, false}, {true, false, true}, {false, true, false}});
+ // Zero-sized update.
+ RunR2<int32, bool>(
+ {{false, true, false}, {true, false, true}, {false, true, true}}, {{}},
+ {2, 1}, {{false, true, false}, {true, false, true}, {false, true, true}});
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) {
+ // R3 Shape: [2, 3, 2]
+ // Slice at dimension start.
+ RunR3<int32, bool>(
+ {{{true, false}, {false, true}, {true, true}},
+ {{false, false}, {false, true}, {true, false}}},
+ {{{false, true}, {true, false}}, {{true, true}, {false, true}}},
+ {0, 0, 0},
+ {{{false, true}, {true, false}, {true, true}},
+ {{true, true}, {false, true}, {true, false}}});
+ // Slice in the middle.
+ RunR3<int32, bool>({{{true, false}, {false, true}, {true, true}},
+ {{false, false}, {false, true}, {true, false}}},
+ {{{false}, {true}}}, {1, 1, 1},
+ {{{true, false}, {false, true}, {true, true}},
+ {{false, false}, {false, false}, {true, true}}});
+}
// Tests for simple R3 case where the update is contiguous (i.e. the minor
// two dimensions are not sliced).
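Because the slice-test helpers above are now templated on both IndexT and DataT (with exact comparison replacing the old floating-point ErrorSpec), covering a further element type is a one-line instantiation. A hypothetical example, assuming the test base supports the chosen DataT:

XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Double) { TestR1<int32, double>(); }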
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 48212dc7d1..527205bbb0 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -39,8 +40,11 @@ limitations under the License.
namespace xla {
namespace {
-class ReducePrecisionTest : public ClientLibraryTestBase,
- public ::testing::WithParamInterface<int> {};
+// Tests to confirm that the ReducePrecision operation produces the expected
+// numerical values.
+class ReducePrecisionAccuracyTest : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<int> {
+};
// For reduction to IEEE-f16, we want to test the following cases, in both
// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10
@@ -201,7 +205,7 @@ static const uint32_t test_values[][4] = {
FPVAL(11111111, 1111111111, 1111111111111) // NaN
}};
-XLA_TEST_P(ReducePrecisionTest, ReducePrecisionF32) {
+XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
int index = GetParam();
int exponent_bits = exponent_sizes[index];
int mantissa_bits = mantissa_sizes[index];
@@ -238,9 +242,87 @@ XLA_TEST_P(ReducePrecisionTest, ReducePrecisionF32) {
ComputeAndCompareR1<float>(&builder, expected_values, {a_data.get()});
}
-INSTANTIATE_TEST_CASE_P(ReducePrecisionTest, ReducePrecisionTest,
+INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest,
+ ReducePrecisionAccuracyTest,
::testing::Values(0, 1, 2, 3), TestDataToString);
+// Tests to confirm that the compiler optimization functions add the expected
+// ReducePrecisionInsertion passes.
+class ReducePrecisionInsertionTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ auto a = builder.Parameter(0, a_literal->shape(), "a");
+
+ // Abs doesn't affect resolution.
+ auto abs = builder.Abs(a);
+
+ // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the
+ // reduce-precision operation showed up in the correct place in the
+ // graph.
+ auto log = builder.Log(abs);
+
+ // Insert precision-reduction after the Abs(x) operation, rounding that
+ // result to exactly 1.0f.
+ auto reduce_precision_pass = execution_options_.mutable_debug_options()
+ ->add_hlo_reduce_precision_options();
+ *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
+ HloReducePrecisionOptions::BEFORE_OP_FUSION, 5, 10,
+ [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });
+
+ ComputeAndCompareR1<float>(&builder, {0.0f}, {a_data.get()});
+}
+
+XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ auto a = builder.Parameter(0, a_literal->shape(), "a");
+
+ // These two operations should be fused by any reasonable backend.
+ auto abs = builder.Abs(a);
+ auto neg = builder.Neg(abs);
+
+ // Add a pass after operation fusion, suffixing kAbs operations. This
+ // should not see into the fusion nodes and thus should not affect the
+ // result.
+ auto reduce_precision_pass = execution_options_.mutable_debug_options()
+ ->add_hlo_reduce_precision_options();
+ *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
+ HloReducePrecisionOptions::AFTER_OP_FUSION, 5, 10,
+ [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; });
+
+ ComputeAndCompareR1<float>(&builder, {-1.00001f}, {a_data.get()});
+}
+
+XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ auto a = builder.Parameter(0, a_literal->shape(), "a");
+
+ // These two operations should be fused by any reasonable backend.
+ auto abs = builder.Abs(a);
+ auto neg = builder.Neg(abs);
+
+ // Add a pass after operation fusion, suffixing kFusion operations.
+ auto reduce_precision_pass = execution_options_.mutable_debug_options()
+ ->add_hlo_reduce_precision_options();
+ *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto(
+ HloReducePrecisionOptions::AFTER_OP_FUSION, 5, 10,
+ [](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; });
+
+ ComputeAndCompareR1<float>(&builder, {-1.0f}, {a_data.get()});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 00fb7f12b8..be4e00f63c 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -20,6 +20,24 @@ import "tensorflow/compiler/xla/service/session.proto";
package xla;
+// Options for the HLO insert-reduce-precision-operations pass.
+message HloReducePrecisionOptions {
+ // When to run the pass.
+ enum PassTiming {
+ BEFORE_OP_FUSION = 0;
+ AFTER_OP_FUSION = 1;
+ }
+ PassTiming pass_timing = 1;
+
+ // Exponent and mantissa bit counts for the reduced precision.
+ uint32 exponent_bits = 2;
+ uint32 mantissa_bits = 3;
+
+ // Opcodes for operations that should be suffixed with reduced-precision
+ // operations.
+ repeated uint32 opcodes_to_suffix = 4;
+}
+
// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
@@ -112,6 +130,11 @@ message DebugOptions {
// the generated IR.
bool xla_llvm_enable_invariant_load_metadata = 72;
+ // Options for inserting reduce-precision operations for numerical
+ // experimentation. This is a repeated field, as we may want to have
+ // multiple passes with different parameters.
+ repeated HloReducePrecisionOptions hlo_reduce_precision_options = 80;
+
// This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
// computation will run n! times with all permutations of layouts for the
// output shape in rank n. For example, with a 3D shape, all permutations of
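As a rough sketch of how a compiler pipeline might consume the new repeated field, assuming the usual protobuf-generated C++ accessors (the header path and function name below are illustrative, not part of this change):

#include "tensorflow/compiler/xla/xla.pb.h"

// Illustration only: count the reduce-precision passes configured to run
// before op fusion.
int CountPreFusionReducePrecisionPasses(const xla::DebugOptions& debug_options) {
  int count = 0;
  for (const xla::HloReducePrecisionOptions& opts :
       debug_options.hlo_reduce_precision_options()) {
    if (opts.pass_timing() ==
        xla::HloReducePrecisionOptions::BEFORE_OP_FUSION) {
      // exponent_bits()/mantissa_bits() give the reduced format;
      // opcodes_to_suffix() lists the opcodes to be suffixed.
      ++count;
    }
  }
  return count;
}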
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index d3cc61ce29..f1f6144acd 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -199,6 +199,7 @@ tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/core/kernels/dequantize_op.cc
tensorflow/core/kernels/meta_support.cc
+tensorflow/core/kernels/population_count_op.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
tensorflow/core/kernels/quantize_op.cc
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 06954f51d8..c14463bdad 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -210,7 +210,7 @@ class RNNCellTest(test.TestCase):
sess.run([variables_lib.global_variables_initializer()])
sess.run([g, out_m],
{x.name: 1 * np.ones([batch_size, input_size]),
- m.name: 0.1 * np.ones([batch_size - 1, state_size])})
+ m.name: 0.1 * np.ones([batch_size - 1, state_size])})
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
@@ -218,7 +218,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
- state_size = num_units * 3 # state_size must be num_units * 2
+ state_size = num_units * 3 # state_size must be num_units * 2
batch_size = 3
input_size = 4
x = array_ops.zeros([batch_size, input_size])
@@ -406,6 +406,31 @@ class RNNCellTest(test.TestCase):
# States are left untouched
self.assertAllClose(res[2], res[3])
+ def testResidualWrapperWithSlice(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 5])
+ m = array_ops.zeros([1, 3])
+ base_cell = rnn_cell_impl.GRUCell(3)
+ g, m_new = base_cell(x, m)
+ variable_scope.get_variable_scope().reuse_variables()
+ def residual_with_slice_fn(inp, out):
+ inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
+ return inp_sliced + out
+ g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
+ base_cell, residual_with_slice_fn)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res_g, res_g_res, res_m_new, res_m_new_res = sess.run(
+ [g, g_res, m_new, m_new_res], {
+ x: np.array([[1., 1., 1., 1., 1.]]),
+ m: np.array([[0.1, 0.1, 0.1]])
+ })
+ # Residual connections
+ self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
+ # States are left untouched
+ self.assertAllClose(res_m_new, res_m_new_res)
+
def testDeviceWrapper(self):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
index be28184ae6..61b7107a17 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
@@ -220,15 +220,14 @@ class SpectralOpsTest(test.TestCase):
# stft_bound, inverse_stft_bound).
# TODO(rjryan): Investigate why STFT gradient error is so high.
test_configs = [
- (512, 64, 32, 64, 2e-3, 3e-5),
- (512, 64, 64, 64, 2e-3, 3e-5),
- (512, 64, 25, 64, 2e-3, 3e-5),
- (512, 25, 15, 36, 2e-3, 3e-5),
- (123, 23, 5, 42, 2e-3, 4e-5),
+ (64, 16, 8, 16),
+ (64, 16, 16, 16),
+ (64, 16, 7, 16),
+ (64, 7, 4, 9),
+ (29, 5, 1, 10),
]
- for (signal_length, frame_length, frame_step, fft_length,
- stft_bound, inverse_stft_bound) in test_configs:
+ for (signal_length, frame_length, frame_step, fft_length) in test_configs:
signal_shape = [signal_length]
signal = random_ops.random_uniform(signal_shape)
stft_shape = [max(0, 1 + (signal_length - frame_length) // frame_step),
@@ -242,8 +241,8 @@ class SpectralOpsTest(test.TestCase):
stft, stft_shape)
inverse_stft_error = test.compute_gradient_error(
stft, stft_shape, inverse_stft, inverse_stft_shape)
- self.assertLess(stft_error, stft_bound)
- self.assertLess(inverse_stft_error, inverse_stft_bound)
+ self.assertLess(stft_error, 2e-3)
+ self.assertLess(inverse_stft_error, 4e-5)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 712871cc04..b6fa185709 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -54,9 +54,12 @@ def _tpu_job(run_config):
return None if run_config.master in ['', 'local'] else 'tpu_worker'
-def _per_shard_batch_size(global_batch_size, run_config):
+def _per_shard_batch_size(global_batch_size, run_config, use_tpu):
"""Returns the batch size for each shard."""
- return global_batch_size // run_config.tpu_config.num_shards
+ if use_tpu:
+ return global_batch_size // run_config.tpu_config.num_shards
+ else:
+ return global_batch_size
class _SIGNAL(object):
@@ -470,7 +473,7 @@ class _ModelFnWrapper(object):
self._train_batch_size = train_batch_size
def call_without_tpu(self, features, labels):
- return self._call_model_fn(features, labels)
+ return self._call_model_fn(features, labels, False)
def convert_to_single_tpu_train_step(self, dequeue_fn):
"""Converts the `model_fn` as a single train step on TPU."""
@@ -481,8 +484,8 @@ class _ModelFnWrapper(object):
features, labels = dequeue_fn()
# Makes deep copy with `config` and params` in case user mutates them.
- estimator_spec = self._verify_estimator_spec(self._call_model_fn(
- features, labels, add_batch_size_in_params=True))
+ estimator_spec = self._verify_estimator_spec(
+ self._call_model_fn(features, labels, True))
loss, train_op = estimator_spec.loss, estimator_spec.train_op
with ops.control_dependencies([train_op]):
return array_ops.identity(loss)
@@ -492,7 +495,7 @@ class _ModelFnWrapper(object):
def config(self):
return self._config
- def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
+ def _call_model_fn(self, features, labels, use_tpu):
"""Calls the model_fn with required parameters."""
model_fn_args = util.fn_args(self._model_fn)
kwargs = {}
@@ -513,16 +516,15 @@ class _ModelFnWrapper(object):
if 'params' in model_fn_args:
kwargs['params'] = params
- if add_batch_size_in_params:
- if 'params' not in model_fn_args:
- raise ValueError(
- 'model_fn ({}) does not include params argument, '
- 'required by TPUEstimator to pass batch size as '
- 'params[\'batch_size\']'.format(self._model_fn))
- if self._mode == model_fn_lib.ModeKeys.TRAIN:
- # For TPU training. `params` is never `None`.
- params[_BATCH_SIZE_KEY] = _per_shard_batch_size(self._train_batch_size,
- config)
+ if 'params' not in model_fn_args:
+ raise ValueError(
+ 'model_fn ({}) does not include params argument, '
+ 'required by TPUEstimator to pass batch size as '
+ 'params[\'batch_size\']'.format(self._model_fn))
+ if self._mode == model_fn_lib.ModeKeys.TRAIN:
+ # For TPU training. `params` is never `None`.
+ params[_BATCH_SIZE_KEY] = _per_shard_batch_size(
+ self._train_batch_size, config, use_tpu)
return self._model_fn(features=features, **kwargs)
@@ -609,16 +611,12 @@ class TPUEstimator(estimator_lib.Estimator):
'batch size {} must be divisible by number of shards {}'
.format(train_batch_size, config.tpu_config.num_shards))
- if use_tpu:
- # Verifies the model_fn signature according to Estimator framework.
- estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access
- # We cannot store config and params in this constructor as parent
- # constructor might change them, such as assigning a temp dir for
- # config.model_dir.
- model_function = augment_model_fn_with_tpu_support(
- model_fn, train_batch_size)
- else:
- model_function = model_fn
+ # Verifies the model_fn signature according to Estimator framework.
+ estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access
+ # We cannot store config and params in this constructor as parent
+ # constructor might change them, such as assigning a temp dir for
+ # config.model_dir.
+ model_function = _augment_model_fn(model_fn, train_batch_size, use_tpu)
super(TPUEstimator, self).__init__(
model_fn=model_function,
@@ -670,9 +668,6 @@ class TPUEstimator(estimator_lib.Estimator):
Raises:
ValueError: if input_fn takes invalid arguments or does not have `params`.
"""
- if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
- return super(TPUEstimator, self)._call_input_fn(input_fn, mode)
-
input_fn_args = util.fn_args(input_fn)
config = self.config # a deep copy.
kwargs = {}
@@ -686,8 +681,13 @@ class TPUEstimator(estimator_lib.Estimator):
kwargs['config'] = config
# Now for TPU training.
- per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
- kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ kwargs['params'][_BATCH_SIZE_KEY] = (
+ _per_shard_batch_size(self._train_batch_size, config, self._use_tpu))
+
+ if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
+ with ops.device('/cpu:0'):
+ return input_fn(**kwargs)
job = _tpu_job(config)
def placement_function(index):
@@ -746,7 +746,7 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder):
return (dequeue_fn, enqueue_fn)
-def augment_model_fn_with_tpu_support(model_fn, train_batch_size):
+def _augment_model_fn(model_fn, train_batch_size, use_tpu):
"""Returns a new model_fn, which wraps the TPU support."""
def _model_fn(features, labels, mode, config, params):
@@ -755,7 +755,7 @@ def augment_model_fn_with_tpu_support(model_fn, train_batch_size):
train_batch_size)
    # TODO(jhseu): Move EVAL and PREDICT to TPU.
- if mode != model_fn_lib.ModeKeys.TRAIN:
+ if not use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
return model_fn_wrapper.call_without_tpu(features, labels)
inputs = _InputsHolder(sharded_features=features, sharded_labels=labels)
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 9fdb3da6a0..95cafa24b1 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
+#include <string>
#include <vector>
+
#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -27,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
-
namespace {
string SummarizeString(const string& str) {
@@ -460,7 +461,8 @@ bool HasPlaceHolder(const AttrValue& val) {
return false;
}
-bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value) {
+bool SubstitutePlaceholders(const SubstituteFunc& substitute,
+ AttrValue* value) {
switch (value->value_case()) {
case AttrValue::kList: {
for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
index 08cc3b7158..08d813bb6f 100644
--- a/tensorflow/core/framework/attr_value_util.h
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -16,8 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#include <functional>
#include <string>
#include <vector>
+
#include "tensorflow/core/framework/attr_value.pb.h" // TODO(62899350): Remove
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@@ -100,8 +102,8 @@ bool HasPlaceHolder(const AttrValue& val);
// SubstituteFunc is given a placeholder string. If the placeholder is
// unknown, SubstituteFunc returns false. Otherwise, overwrites the
// attr value and returns true.
-typedef std::function<bool(const string&, AttrValue*)> SubstituteFunc;
-bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value);
+using SubstituteFunc = std::function<bool(const string&, AttrValue*)>;
+bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value);
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index db279ae67f..7f845bb9e2 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -104,7 +104,8 @@ ConstantFolding::ConstantFolding() {
ops_to_preserve_ = std::regex(
"Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|"
"Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|"
- ".*Quantized.*");
+ ".*Quantized.*",
+ std::regex_constants::optimize);
}
string ConstantFolding::AddControlDependency(const string& input_name) {
@@ -240,13 +241,18 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
}
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
+ // Folding not applicable to ops with no inputs.
+ if (node.input().empty()) {
+ return false;
+ }
+
// Skips nodes that must be preserved, and op_types that don't benefit from
// folding
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
}
- std::cmatch match;
- if (std::regex_match(node.op().c_str(), match, ops_to_preserve_)) {
+ if (std::regex_match(node.op().c_str(), ops_to_preserve_,
+ std::regex_constants::match_any)) {
return false;
}
@@ -264,23 +270,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false;
}
- DeviceTypeVector device_types;
- status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
- &device_types);
- if (!status.ok()) {
- return false;
- }
- // Only fold ops with a CPU implementation available.
- if (device_types.empty()) {
- return false;
- }
- DCHECK_EQ(DeviceType(DEVICE_CPU), device_types[0]);
-
- // Folding not applicable to ops with no inputs.
- if (node.input().empty()) {
- return false;
- }
-
// No need to (and don't) fold nodes that have no outgoing edges. Such nodes
// could be introduced by an earlier constant folding pass and are preserved
// in case users want to fetch their values; re-processing them would
@@ -391,12 +380,15 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
// Control dependency
break;
}
- // There should be a single output since the input node should be a constant
- // node.
- TensorVector output;
- TF_RETURN_IF_ERROR(
- EvaluateNode(*node_map_->GetNode(input), TensorVector(), &output));
- inputs.push_back(output[position]);
+ const NodeDef* input_node = node_map_->GetNode(input);
+ if (!IsConstant(*input_node)) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Can't fold ", node.name(), ", its ", input,
+ " isn't constant"));
+ }
+ Tensor* value = new Tensor(input_node->attr().at("dtype").type());
+ CHECK(value->FromProto(input_node->attr().at("value").tensor()));
+ inputs.emplace_back(value);
}
TensorVector output_tensors;
@@ -583,24 +575,31 @@ Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
Status ConstantFolding::FoldGraph(GraphDef* output) {
std::unordered_set<string> processed_nodes;
- int previously_processed = 0;
- do {
- previously_processed = processed_nodes.size();
- for (const auto& node : graph_.node()) {
- if (IsFoldable(node) &&
- processed_nodes.find(node.name()) == processed_nodes.end()) {
- Status s = FoldNode(node, output);
- if (!s.ok()) {
- VLOG(1) << "Failed to fold node " << node.name() << ": " << s;
+ std::deque<const NodeDef*> queue;
+ for (const auto& node : graph_.node()) {
+ if (IsFoldable(node)) {
+ queue.push_back(&node);
+ }
+ }
+ while (!queue.empty()) {
+ const NodeDef* node = queue.front();
+ queue.pop_front();
+ if (processed_nodes.count(node->name())) {
+ continue;
+ }
+ Status s = FoldNode(*node, output);
+ processed_nodes.insert(node->name());
+ if (!s.ok()) {
+ VLOG(1) << "Failed to fold node " << node->name() << ": " << s;
+ } else {
+ auto outputs = node_map_->GetOutputs(node->name());
+ for (auto& output : outputs) {
+ if (IsFoldable(*output)) {
+ queue.push_back(output);
}
- processed_nodes.insert(node.name());
}
}
- // Try again as long as we find new constants. In most cases, this loop will
- // only run once since the graph is already in topological order.
- VLOG(1) << "Folded " << processed_nodes.size() - previously_processed
- << " nodes in this pass";
- } while (previously_processed != processed_nodes.size());
+ }
// Build the graph after constant folding. Note that we keep all processed
// nodes in the graph in case users need to fetch their values.
@@ -740,7 +739,6 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
graph_ = item.graph;
- LOG(INFO) << "Initial graph size: " << item.graph.node_size();
node_map_.reset(new NodeMap(&graph_));
for (const auto& node : item.fetch) {
nodes_to_preserve_.insert(NodeName(node));
@@ -761,7 +759,6 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(FoldGraph(output));
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
- LOG(INFO) << "Optimized graph size: " << output->node_size();
*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
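The rewritten `FoldGraph` above replaces the repeated full-graph passes with a worklist: every initially foldable node is queued once, and whenever a node folds successfully its consumers are re-examined, since their inputs may just have become constant. A minimal sketch of that pattern over a generic dependency graph (names here are illustrative, not the Grappler API):

```python
from collections import deque

def fold_graph(nodes, consumers, is_foldable, fold_node):
    """Worklist-style constant folding, mirroring the queue in FoldGraph."""
    processed = set()
    queue = deque(n for n in nodes if is_foldable(n))
    while queue:
        node = queue.popleft()
        if node in processed:
            continue
        processed.add(node)
        if fold_node(node):
            # A successful fold can make downstream nodes foldable too.
            queue.extend(c for c in consumers.get(node, ()) if is_foldable(c))
    return processed
```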
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index fffcb980db..a493452777 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -702,6 +702,39 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "compare_and_bitpack_op",
+ srcs = ["compare_and_bitpack_op.cc"],
+ hdrs = ["compare_and_bitpack_op.h"],
+ gpu_srcs = [
+ "compare_and_bitpack_op.h",
+ "compare_and_bitpack_op_gpu.cu.cc",
+ ],
+ deps = ARRAY_DEPS,
+)
+
+# TODO(ebrevdo): Add benchmarks once the op is in the autogen array namespace.
+# tf_cuda_cc_test(
+# name = "compare_and_bitpack_op_test",
+# srcs = ["compare_and_bitpack_op_test.cc"],
+# deps = [
+# ":array",
+# ":ops_testutil",
+# ":ops_util",
+# "//third_party/eigen3",
+# "//tensorflow/cc:cc_ops",
+# "//tensorflow/cc:cc_ops_internal",
+# "//tensorflow/core:core_cpu",
+# "//tensorflow/core:core_cpu_internal",
+# "//tensorflow/core:framework",
+# "//tensorflow/core:lib",
+# "//tensorflow/core:protos_all_cc",
+# "//tensorflow/core:test",
+# "//tensorflow/core:test_main",
+# "//tensorflow/core:testlib",
+# ],
+# )
+
+tf_kernel_library(
name = "reshape_op",
prefix = "reshape_op",
deps = ARRAY_DEPS,
@@ -2344,10 +2377,12 @@ cc_library(
":bucketize_op",
":cast_op",
":check_numerics_op",
+ ":compare_and_bitpack_op",
":cross_op",
":cwise_op",
":fft_ops",
":matmul_op",
+ ":population_count_op",
":reduction_ops",
":scan_ops",
":segment_reduction_ops",
@@ -2410,6 +2445,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "population_count_op",
+ prefix = "population_count_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
name = "fft_ops",
prefix = "fft_ops",
deps = MATH_DEPS + [
@@ -4292,6 +4333,8 @@ filegroup(
"fake_quant_ops.cc",
"fifo_queue.cc",
"fused_batch_norm_op.cc",
+ "population_count_op.cc",
+ "population_count_op.h",
"winograd_transform.h",
":android_extended_ops_headers",
] + select({
diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.cc b/tensorflow/core/kernels/compare_and_bitpack_op.cc
new file mode 100644
index 0000000000..9f626a274a
--- /dev/null
+++ b/tensorflow/core/kernels/compare_and_bitpack_op.cc
@@ -0,0 +1,185 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/compare_and_bitpack_op.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class CompareAndBitpackOp : public OpKernel {
+ public:
+ explicit CompareAndBitpackOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& input_t = c->input(0);
+ const Tensor& threshold_t = c->input(1);
+ OP_REQUIRES(
+ c, TensorShapeUtils::IsScalar(threshold_t.shape()),
+        errors::InvalidArgument("threshold must be a scalar, but saw shape: ",
+ threshold_t.shape().DebugString()));
+ const TensorShape& input_shape = input_t.shape();
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_shape),
+ errors::InvalidArgument(
+ "Input should be at least a vector, but saw a scalar."));
+ OP_REQUIRES(c, input_shape.dim_size(input_shape.dims() - 1) % 8 == 0,
+ errors::InvalidArgument(
+ "Inner dimension of input should be "
+ "divisible by ",
+ 8, ", but saw shape: ", input_shape.DebugString()));
+
+ TensorShape output_shape = input_shape;
+ int rank = input_shape.dims();
+ output_shape.set_dim(rank - 1, input_shape.dim_size(rank - 1) / 8);
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output_t));
+
+ auto input = input_t.flat_inner_dims<T>();
+ auto threshold = threshold_t.scalar<T>();
+ auto output = output_t->flat_inner_dims<uint8>();
+
+ functor::CompareAndBitpack<Device, T> func;
+ func(c, input, threshold, output);
+ }
+};
+
+#define REGISTER_COMPARE_AND_BITPACK(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("CompareAndBitpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ CompareAndBitpackOp<CPUDevice, type>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_COMPARE_AND_BITPACK);
+TF_CALL_bool(REGISTER_COMPARE_AND_BITPACK);
+
+#undef REGISTER_COMPARE_AND_BITPACK
+
+namespace functor {
+
+template <typename T, class = void, class = void>
+struct ComputeShard {
+ static EIGEN_STRONG_INLINE void Compute(typename TTypes<T>::ConstMatrix input,
+ typename TTypes<uint8>::Matrix output,
+ const T& thresh, int64 start,
+ int64 limit) {
+ for (int64 i = start; i < limit; ++i) {
+ uint8* out = output.data() + i;
+ const T* block = input.data() + 8 * i;
+ *out = ((((block[0] > thresh) << 7)) | (((block[1] > thresh) << 6)) |
+ (((block[2] > thresh) << 5)) | (((block[3] > thresh) << 4)) |
+ (((block[4] > thresh) << 3)) | (((block[5] > thresh) << 2)) |
+ (((block[6] > thresh) << 1)) | (((block[7] > thresh))));
+ }
+ }
+};
+
+// Specialization for bool on systems where sizeof(bool) == 1.
+template <typename T>
+struct ComputeShard<T,
+ typename std::enable_if<std::is_same<T, bool>::value>::type,
+ typename std::enable_if<sizeof(T) == 1>::type> {
+ static EIGEN_STRONG_INLINE void Compute(
+ typename TTypes<bool>::ConstMatrix input,
+ typename TTypes<uint8>::Matrix output, bool /*thresh*/, int64 start,
+ int64 limit) {
+ // NOTE(ebrevdo): This assumes memory is little-endian.
+ for (int64 i = start; i < limit; ++i) {
+ uint8* out = output.data() + i;
+ const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i);
+ *out =
+ ((((block & (1LL << (7 * 8))) >> (7 * 8 - 0))) |
+ (((block & (1LL << (6 * 8))) >> (6 * 8 - 1))) |
+ (((block & (1LL << (5 * 8))) >> (5 * 8 - 2))) |
+ (((block & (1LL << (4 * 8))) >> (4 * 8 - 3))) |
+ (((block & (1LL << (3 * 8))) >> (3 * 8 - 4))) |
+ (((block & (1LL << (2 * 8))) >> (2 * 8 - 5))) |
+ (((block & (1LL << 8)) >> (1 * 8 - 6))) | (((block & (1LL)) << 7)));
+ }
+ }
+};
+
+template <typename T>
+struct CompareAndBitpack<CPUDevice, T> {
+ void operator()(OpKernelContext* c, typename TTypes<T>::ConstMatrix input,
+ typename TTypes<T>::ConstScalar threshold,
+ TTypes<uint8>::Matrix output) {
+ const T thresh = threshold();
+ auto shard = [&, thresh](int64 start, int64 limit) {
+ ComputeShard<T>::Compute(input, output, thresh, start, limit);
+ };
+    int64 total_shards = output.size();
+    // Approximate cmp as an add and bitwise-or + shift as an add.
+ const double total_cost = 8 * (Eigen::TensorOpCost::AddCost<T>() +
+ Eigen::TensorOpCost::AddCost<uint8>());
+ const int64 shard_cost = (total_cost >= static_cast<double>(kint64max))
+ ? kint64max
+ : static_cast<int64>(total_cost);
+
+ auto worker_threads = *(c->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
+ shard_cost, shard);
+ }
+};
+
+} // namespace functor
+
+#if GOOGLE_CUDA
+
+#define REGISTER_COMPARE_AND_BITPACK(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("CompareAndBitpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ CompareAndBitpackOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_COMPARE_AND_BITPACK);
+TF_CALL_bool(REGISTER_COMPARE_AND_BITPACK);
+
+#undef REGISTER_COMPARE_AND_BITPACK
+
+namespace functor {
+
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void CompareAndBitpack<GPUDevice, T>::operator()( \
+ OpKernelContext* c, typename TTypes<T>::ConstMatrix input, \
+ typename TTypes<T>::ConstScalar threshold, \
+ TTypes<uint8>::Matrix output); \
+ extern template struct CompareAndBitpack<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC)
+TF_CALL_bool(DECLARE_GPU_SPEC)
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
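To make the bit order of the CPU kernel above explicit: each group of eight comparisons is packed most-significant-bit first, so `block[0] > thresh` lands in bit 7 of the output byte and `block[7] > thresh` in bit 0. A NumPy sketch of the same semantics (an illustrative reference, not the kernel itself):

```python
import numpy as np

def compare_and_bitpack(x, threshold):
    x = np.asarray(x)
    assert x.shape[-1] % 8 == 0, "innermost dimension must be divisible by 8"
    bits = (x > threshold).astype(np.uint8)
    blocks = bits.reshape(x.shape[:-1] + (x.shape[-1] // 8, 8))
    # Element j of each 8-wide block is weighted into bit (7 - j).
    weights = np.array([128, 64, 32, 16, 8, 4, 2, 1], dtype=np.uint8)
    return (blocks * weights).sum(axis=-1).astype(np.uint8)

print(compare_and_bitpack([1, 0, 0, 0, 0, 0, 0, 1], threshold=0))  # -> [129]
```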
diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.h b/tensorflow/core/kernels/compare_and_bitpack_op.h
new file mode 100644
index 0000000000..8e020249c1
--- /dev/null
+++ b/tensorflow/core/kernels/compare_and_bitpack_op.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+struct CompareAndBitpack {
+ void operator()(OpKernelContext* c, typename TTypes<T>::ConstMatrix input,
+ typename TTypes<T>::ConstScalar threshold,
+ TTypes<uint8>::Matrix output);
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
diff --git a/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc b/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc
new file mode 100644
index 0000000000..345405e3fe
--- /dev/null
+++ b/tensorflow/core/kernels/compare_and_bitpack_op_gpu.cu.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/compare_and_bitpack_op.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename T>
+__global__ void CompareAndBitpackKernel(const int size, const T* threshold,
+ const T* input, uint8* output) {
+ // TODO(ebrevdo): Erich said: to get a better memory access pattern
+ // you could have 8 threads load this data and do a comparison, then
+ // use the ballot instruction to combine the values from each thread
+ // in the warp in one instruction (so each thread will have the
+ // result for 4 blocks) followed by an appropriate shift and mask to
+ // get the 8-bits of interest.
+ const T thresh = ldg(threshold);
+ CUDA_1D_KERNEL_LOOP(i, size) {
+ const T* block = input + 8 * i;
+ output[i] =
+ ((((ldg(block) > thresh) << 7)) | (((ldg(block + 1) > thresh) << 6)) |
+ (((ldg(block + 2) > thresh) << 5)) |
+ (((ldg(block + 3) > thresh) << 4)) |
+ (((ldg(block + 4) > thresh) << 3)) |
+ (((ldg(block + 5) > thresh) << 2)) |
+ (((ldg(block + 6) > thresh) << 1)) | (((ldg(block + 7) > thresh))));
+ }
+}
+
+template <>
+__global__ void CompareAndBitpackKernel<bool>(const int size,
+ const bool* threshold,
+ const bool* input,
+ uint8* output) {
+ // TODO(ebrevdo): Erich said: I think you could again have multiple
+  // threads work on one block and use the ballot instruction to do the
+  // bit packing in one instruction.
+ CUDA_1D_KERNEL_LOOP(i, size) {
+ const int64 block = ldg(reinterpret_cast<const int64*>(input + 8 * i));
+ // NOTE(ebrevdo): This assumes memory is little-endian.
+ output[i] =
+ ((((block & (1LL << (7 * 8))) >> (7 * 8 - 0))) |
+ (((block & (1LL << (6 * 8))) >> (6 * 8 - 1))) |
+ (((block & (1LL << (5 * 8))) >> (5 * 8 - 2))) |
+ (((block & (1LL << (4 * 8))) >> (4 * 8 - 3))) |
+ (((block & (1LL << (3 * 8))) >> (3 * 8 - 4))) |
+ (((block & (1LL << (2 * 8))) >> (2 * 8 - 5))) |
+ (((block & (1LL << 8)) >> (1 * 8 - 6))) | (((block & (1LL)) << 7)));
+ }
+}
+
+template <>
+__global__ void CompareAndBitpackKernel<float>(const int size,
+ const float* threshold,
+ const float* input,
+ uint8* output) {
+ const float thresh = ldg(threshold);
+ CUDA_1D_KERNEL_LOOP(i, size) {
+ const float4 block0 = ldg(reinterpret_cast<const float4*>(input + 8 * i));
+ const float4 block1 =
+ ldg(reinterpret_cast<const float4*>(input + 8 * i + 4));
+ output[i] = ((((block0.x > thresh) << 7)) | (((block0.y > thresh) << 6)) |
+ (((block0.z > thresh) << 5)) | (((block0.w > thresh) << 4)) |
+ (((block1.x > thresh) << 3)) | (((block1.y > thresh) << 2)) |
+ (((block1.z > thresh) << 1)) | (((block1.w > thresh))));
+ }
+}
+
+template <>
+__global__ void CompareAndBitpackKernel<double>(const int size,
+ const double* threshold,
+ const double* input,
+ uint8* output) {
+ const double thresh = ldg(threshold);
+ CUDA_1D_KERNEL_LOOP(i, size) {
+ const double2 block0 = ldg(reinterpret_cast<const double2*>(input + 8 * i));
+ const double2 block1 =
+ ldg(reinterpret_cast<const double2*>(input + 8 * i + 2));
+ const double2 block2 =
+ ldg(reinterpret_cast<const double2*>(input + 8 * i + 4));
+ const double2 block3 =
+ ldg(reinterpret_cast<const double2*>(input + 8 * i + 6));
+ output[i] = ((((block0.x > thresh) << 7)) | (((block0.y > thresh) << 6)) |
+ (((block1.x > thresh) << 5)) | (((block1.y > thresh) << 4)) |
+ (((block2.x > thresh) << 3)) | (((block2.y > thresh) << 2)) |
+ (((block3.x > thresh) << 1)) | (((block3.y > thresh))));
+ }
+}
+
+#define DEFINE_GPU_SPECS(T) \
+ template <> \
+ void CompareAndBitpack<GPUDevice, T>::operator()( \
+ OpKernelContext* c, typename TTypes<T>::ConstMatrix input, \
+ typename TTypes<T>::ConstScalar threshold, \
+ TTypes<uint8>::Matrix output) { \
+ const GPUDevice& d = c->eigen_device<GPUDevice>(); \
+ int64 total_count = output.size(); \
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); \
+ \
+ CompareAndBitpackKernel<T> \
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
+ total_count, threshold.data(), input.data(), output.data()); \
+ }
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS)
+TF_CALL_bool(DEFINE_GPU_SPECS)
+
+#undef DEFINE_GPU_SPECS
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
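Both bool paths (the CPU `ComputeShard` specialization and the GPU kernel above) rely on the same trick: read eight adjacent one-byte bools as a single 64-bit word and take the low bit of each byte, which on little-endian hardware with `sizeof(bool) == 1` matches packing the comparisons element by element. A quick sketch of that equivalence, under exactly the assumptions the kernels' NOTE states:

```python
import struct

def pack_direct(bools):
    out = 0
    for j, b in enumerate(bools):       # bools[0] -> bit 7, bools[7] -> bit 0
        out |= int(b) << (7 - j)
    return out

def pack_via_word(bools):
    # Eight one-byte bools reinterpreted as a little-endian 64-bit word.
    (word,) = struct.unpack('<Q', bytes(int(b) for b in bools))
    out = 0
    for k in range(8):                  # byte k carries bools[k] in its low bit
        out |= ((word >> (8 * k)) & 1) << (7 - k)
    return out

sample = [True, False, True, True, False, False, False, True]
assert pack_direct(sample) == pack_via_word(sample) == 0b10110001
```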
diff --git a/tensorflow/core/kernels/population_count_op.cc b/tensorflow/core/kernels/population_count_op.cc
new file mode 100644
index 0000000000..12ff6b69f8
--- /dev/null
+++ b/tensorflow/core/kernels/population_count_op.cc
@@ -0,0 +1,163 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc
+
+#define EIGEN_USE_THREADS
+
+#include <bitset>
+
+#include "tensorflow/core/kernels/population_count_op.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class PopulationCountOp : public OpKernel {
+ public:
+ explicit PopulationCountOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& input_t = c->input(0);
+ Tensor* output_t;
+ OP_REQUIRES_OK(c, c->allocate_output(0, input_t.shape(), &output_t));
+
+ auto input = input_t.flat<T>();
+ auto output = output_t->flat<uint8>();
+
+ functor::PopulationCount<Device, T> popcnt;
+ popcnt(c, input, output);
+ }
+};
+
+#define REGISTER_POPULATION_COUNT(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("PopulationCount").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ PopulationCountOp<CPUDevice, type>);
+
+TF_CALL_uint8(REGISTER_POPULATION_COUNT);
+TF_CALL_int8(REGISTER_POPULATION_COUNT);
+TF_CALL_uint16(REGISTER_POPULATION_COUNT);
+TF_CALL_int16(REGISTER_POPULATION_COUNT);
+TF_CALL_int32(REGISTER_POPULATION_COUNT);
+TF_CALL_int64(REGISTER_POPULATION_COUNT);
+
+#undef REGISTER_POPULATION_COUNT
+
+namespace functor {
+
+namespace {
+
+template <typename T>
+inline uint8 PopCnt(const T v);
+
+#define POPCNT(T, N) \
+ template <> \
+ uint8 PopCnt<T>(const T v) { \
+ return std::bitset<N>(v).count(); \
+ }
+
+POPCNT(int8, 8);
+POPCNT(uint8, 8);
+POPCNT(int16, 16);
+POPCNT(uint16, 16);
+POPCNT(int32, 32);
+POPCNT(int64, 64);
+
+#undef POPCNT
+
+} // namespace
+
+template <typename T>
+struct PopulationCount<CPUDevice, T> {
+ void operator()(OpKernelContext* c, typename TTypes<T>::ConstFlat input,
+ TTypes<uint8>::Flat output) {
+ const T* input_ptr = input.data();
+ uint8* output_ptr = output.data();
+ auto shard = [input_ptr, output_ptr](int64 start, int64 limit) {
+ for (int64 i = start; i < limit; ++i) {
+ output_ptr[i] = PopCnt<T>(input_ptr[i]);
+ }
+ };
+ int64 total_shards = input.size();
+ // Approximating cost of popcnt: convert T to int64
+ // (std::bitset constructor) and convert int64 to uint8
+ // (bitset.count() -> output). The .count() itself is relatively cheap.
+ const double total_cost = (Eigen::TensorOpCost::CastCost<T, uint8>() +
+ Eigen::TensorOpCost::CastCost<int64, uint8>());
+ const int64 shard_cost = (total_cost >= static_cast<double>(kint64max))
+ ? kint64max
+ : static_cast<int64>(total_cost);
+
+ auto worker_threads = *(c->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
+ shard_cost, shard);
+ }
+};
+
+} // namespace functor
+
+#if GOOGLE_CUDA
+
+#define REGISTER_POPULATION_COUNT(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("PopulationCount").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ PopulationCountOp<GPUDevice, type>)
+
+TF_CALL_uint8(REGISTER_POPULATION_COUNT);
+TF_CALL_int8(REGISTER_POPULATION_COUNT);
+TF_CALL_uint16(REGISTER_POPULATION_COUNT);
+TF_CALL_int16(REGISTER_POPULATION_COUNT);
+TF_CALL_int32(REGISTER_POPULATION_COUNT);
+TF_CALL_int64(REGISTER_POPULATION_COUNT);
+
+#undef REGISTER_POPULATION_COUNT
+
+namespace functor {
+
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void PopulationCount<GPUDevice, T>::operator()( \
+ OpKernelContext* c, typename TTypes<T>::ConstFlat input, \
+ TTypes<uint8>::Flat output); \
+ extern template struct PopulationCount<GPUDevice, T>
+
+TF_CALL_uint8(DECLARE_GPU_SPEC);
+TF_CALL_int8(DECLARE_GPU_SPEC);
+TF_CALL_uint16(DECLARE_GPU_SPEC);
+TF_CALL_int16(DECLARE_GPU_SPEC);
+TF_CALL_int32(DECLARE_GPU_SPEC);
+TF_CALL_int64(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
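The CPU path above defines popcount for signed types via `std::bitset<N>(v).count()`, i.e. it counts the set bits of the N-bit two's-complement pattern rather than of a sign-extended wider value. The same semantics in a few lines of Python, for reference:

```python
def population_count(v, bits):
    """Set bits of the `bits`-wide two's-complement representation of v."""
    return bin(v & ((1 << bits) - 1)).count('1')

assert population_count(-1, 8) == 8     # int8  0xFF
assert population_count(-1, 16) == 16   # int16 0xFFFF
assert population_count(5, 32) == 2     # 0b101
```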
diff --git a/tensorflow/core/kernels/population_count_op.h b/tensorflow/core/kernels/population_count_op.h
new file mode 100644
index 0000000000..de89582e13
--- /dev/null
+++ b/tensorflow/core/kernels/population_count_op.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct PopulationCount {
+ void operator()(OpKernelContext* c, typename TTypes<T>::ConstFlat input,
+ TTypes<uint8>::Flat output);
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
diff --git a/tensorflow/core/kernels/population_count_op_gpu.cu.cc b/tensorflow/core/kernels/population_count_op_gpu.cu.cc
new file mode 100644
index 0000000000..27a687ba40
--- /dev/null
+++ b/tensorflow/core/kernels/population_count_op_gpu.cu.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/population_count_op.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename T>
+__global__ void PopulationCountKernel(const int size, const T* input,
+ uint8* output) {
+ CUDA_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
+}
+
+template <>
+__global__ void PopulationCountKernel(const int size, const int8* input,
+ uint8* output) {
+  // A negative int8 would be sign-extended before __popc, inflating the count.
+ CUDA_1D_KERNEL_LOOP(i, size) {
+ output[i] = __popc(ldg(reinterpret_cast<const uint8*>(input + i)));
+ }
+}
+
+template <>
+__global__ void PopulationCountKernel(const int size, const int16* input,
+ uint8* output) {
+  // A negative int16 would be sign-extended before __popc, inflating the count.
+ CUDA_1D_KERNEL_LOOP(i, size) {
+ output[i] = __popc(ldg(reinterpret_cast<const uint16*>(input + i)));
+ }
+}
+
+template <>
+__global__ void PopulationCountKernel<int64>(const int size, const int64* input,
+ uint8* output) {
+ CUDA_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
+}
+
+#define DEFINE_GPU_SPECS(T) \
+ template <> \
+ void PopulationCount<GPUDevice, T>::operator()( \
+ OpKernelContext* c, typename TTypes<T>::ConstFlat input, \
+ TTypes<uint8>::Flat output) { \
+ const GPUDevice& d = c->eigen_device<GPUDevice>(); \
+ int64 total_count = input.size(); \
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); \
+ PopulationCountKernel<T> \
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
+ total_count, input.data(), output.data()); \
+ }
+
+TF_CALL_uint8(DEFINE_GPU_SPECS);
+TF_CALL_int8(DEFINE_GPU_SPECS);
+TF_CALL_uint16(DEFINE_GPU_SPECS);
+TF_CALL_int16(DEFINE_GPU_SPECS);
+TF_CALL_int32(DEFINE_GPU_SPECS);
+TF_CALL_int64(DEFINE_GPU_SPECS);
+
+#undef DEFINE_GPU_SPECS
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
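The int8/int16 GPU specializations above reinterpret the input pointer as unsigned before handing the value to `__popc`, which counts bits of a 32-bit operand: a negative int8 or int16 would otherwise be sign-extended and the extension bits counted too. A small check of that failure mode:

```python
def popc32(v):
    # Popcount of the 32-bit value __popc would actually see.
    return bin(v & 0xFFFFFFFF).count('1')

x = -1                        # the int8 bit pattern 0xFF
assert popc32(x) == 32        # sign-extended to 32 bits: wrong answer for int8
assert popc32(x & 0xFF) == 8  # reinterpret as uint8 first: the intended count
```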
diff --git a/tensorflow/core/ops/bitwise_ops.cc b/tensorflow/core/ops/bitwise_ops.cc
index 2005d5e102..3ffc4ab74a 100644
--- a/tensorflow/core/ops/bitwise_ops.cc
+++ b/tensorflow/core/ops/bitwise_ops.cc
@@ -40,6 +40,22 @@ computation is performed on the underlying representation of x.
.Attr("T: {int8, int16, int32, int64, uint8, uint16}") \
.SetShapeFn(shape_inference::UnchangedShape)
+REGISTER_OP("PopulationCount")
+ .Input("x: T")
+ .Output("y: uint8")
+ .Attr("T: {int8, int16, int32, int64, uint8, uint16}")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Computes element-wise population count (a.k.a. popcount, bitsum, bitcount).
+
+For each entry in `x`, calculates the number of `1` (on) bits in the binary
+representation of that entry.
+
+**NOTE**: It is more efficient to first `tf.bitcast` your tensors into
+`int32` or `int64` and perform the bitcount on the result, than to feed in
+8- or 16-bit inputs and then aggregate the resulting counts.
+)doc");
+
REGISTER_OP("BitwiseAnd").BINARY_BITWISE().Doc(R"doc(
Elementwise computes the bitwise AND of `x` and `y`.
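On the NOTE in the `PopulationCount` doc above: bitcasting four `uint8` (or two `uint16`) values into one `int32` merely regroups the same bytes, so the total number of set bits is unchanged while the op touches a quarter as many elements. A NumPy stand-in for the bitcast makes the claim easy to check:

```python
import numpy as np

x = np.random.randint(0, 256, size=64, dtype=np.uint8)
per_byte = sum(bin(int(v)).count('1') for v in x)
per_word = sum(bin(int(v)).count('1') for v in x.view(np.uint32))
assert per_byte == per_word   # same bytes, same total popcount, 16 vs 64 calls
```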
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 94224f22b9..b82035bfc3 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -4598,6 +4598,37 @@ op {
}
}
op {
+ name: "CompareAndBitpack"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "threshold"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_UINT8
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BOOL
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "Complex"
input_arg {
name: "real"
@@ -16268,6 +16299,31 @@ op {
}
}
op {
+ name: "PopulationCount"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type: DT_UINT8
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_UINT8
+ type: DT_UINT16
+ }
+ }
+ }
+}
+op {
name: "Pow"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index f4e0625c66..36f999ff60 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -2458,6 +2458,64 @@ out_type: The type of the output. Should be a lower bit depth than Tinput.
)doc");
+REGISTER_OP("CompareAndBitpack")
+ .Input("input: T")
+ .Input("threshold: T")
+ .Output("output: uint8")
+ .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ ShapeHandle output = input;
+ if (c->RankKnown(input)) {
+ int rank = c->Rank(input);
+ auto inner_dim = c->Dim(input, rank - 1);
+ DimensionHandle inferred_dim;
+ TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8,
+ /* evenly_divisible */ true,
+ &inferred_dim));
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(output, rank - 1, inferred_dim, &output));
+ }
+ c->set_output(0, output);
+
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Compare values of `input` to `threshold` and pack resulting bits into a `uint8`.
+
+Each comparison returns a boolean `true` (if `input_value > threshold`)
+or `false` otherwise.
+
+This operation is useful for Locality-Sensitive-Hashing (LSH) and other
+algorithms that use hashing approximations of cosine and `L2` distances;
+codes can be generated from an input via:
+
+```python
+codebook_size = 50
+codebook_bits = codebook_size * 32
+codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits],
+ dtype=x.dtype,
+ initializer=tf.orthogonal_initializer())
+codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.)
+codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32
+# now codes has shape x.shape[:-1] + [codebook_size]
+```
+
+**NOTE**: Currently, the innermost dimension of the tensor must be divisible
+by 8.
+
+Given an `input` shaped `[s0, s1, ..., s_n]`, the output is
+a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`.
+
+input: Values to compare against `threshold` and bitpack.
+threshold: Threshold to compare against.
+T: The type of the input and threshold.
+output: The bitpacked comparisons.
+)doc");
+
REGISTER_OP("RequantizationRange")
.Input("input: Tinput")
.Input("input_min: float")
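The shape function registered for `CompareAndBitpack` above keeps every leading dimension and divides the innermost one by 8, insisting (via `evenly_divisible`) that the division is exact; when the input rank is unknown, the output shape is passed through unknown as well. The rule in plain Python, for reference:

```python
def compare_and_bitpack_shape(input_shape):
    *outer, inner = input_shape
    if inner % 8 != 0:
        raise ValueError('innermost dimension %d not divisible by 8' % inner)
    return tuple(outer) + (inner // 8,)

assert compare_and_bitpack_shape((4, 16)) == (4, 2)
assert compare_and_bitpack_shape((3, 5, 64)) == (3, 5, 8)
```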
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 1822fc1133..468434bd28 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -499,7 +499,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
@@ -601,7 +601,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
@@ -1691,7 +1691,7 @@ op {
}
input_arg {
name: "dimension"
- description: "int32 or int64, 0 <= dimension < rank(input). Describes\nwhich dimension of the input Tensor to reduce across. For vectors,\nuse dimension = 0."
+ description: "int32 or int64, must be in the range `[-rank(input), rank(input))`.\nDescribes which dimension of the input Tensor to reduce across. For vectors,\nuse dimension = 0."
type_attr: "Tidx"
}
output_arg {
@@ -1757,7 +1757,7 @@ op {
}
input_arg {
name: "dimension"
- description: "int32 or int64, 0 <= dimension < rank(input). Describes\nwhich dimension of the input Tensor to reduce across. For vectors,\nuse dimension = 0."
+ description: "int32 or int64, must be in the range `[-rank(input), rank(input))`.\nDescribes which dimension of the input Tensor to reduce across. For vectors,\nuse dimension = 0."
type_attr: "Tidx"
}
output_arg {
@@ -4407,6 +4407,43 @@ op {
description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527."
}
op {
+ name: "CompareAndBitpack"
+ input_arg {
+ name: "input"
+ description: "Values to compare against `threshold` and bitpack."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "threshold"
+ description: "Threshold to compare against."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ description: "The bitpacked comparisons."
+ type: DT_UINT8
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "The type of the input and threshold."
+ allowed_values {
+ list {
+ type: DT_BOOL
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Compare values of `input` to `threshold` and pack resulting bits into a `uint8`."
+  description: "Each comparison returns a boolean `true` (if `input_value > threshold`)\nor `false` otherwise.\n\nThis operation is useful for Locality-Sensitive-Hashing (LSH) and other\nalgorithms that use hashing approximations of cosine and `L2` distances;\ncodes can be generated from an input via:\n\n```python\ncodebook_size = 50\ncodebook_bits = codebook_size * 32\ncodebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits],\n                           dtype=x.dtype,\n                           initializer=tf.orthogonal_initializer())\ncodes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.)\ncodes = tf.bitcast(codes, tf.int32)  # go from uint8 to int32\n# now codes has shape x.shape[:-1] + [codebook_size]\n```\n\n**NOTE**: Currently, the innermost dimension of the tensor must be divisible\nby 8.\n\nGiven an `input` shaped `[s0, s1, ..., s_n]`, the output is\na `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`."
+}
+op {
name: "Complex"
input_arg {
name: "real"
@@ -5656,10 +5693,12 @@ op {
name: "Cumprod"
input_arg {
name: "x"
+ description: "A `Tensor`. Must be one of the following types: `float32`, `float64`,\n`int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,\n`complex128`, `qint8`, `quint8`, `qint32`, `half`."
type_attr: "T"
}
input_arg {
name: "axis"
+ description: "A `Tensor` of type `int32` (default: 0). Must be in the range\n`[-rank(x), rank(x))`."
type_attr: "Tidx"
}
output_arg {
@@ -5672,6 +5711,7 @@ op {
default_value {
b: false
}
+ description: "If `True`, perform exclusive cumprod."
}
attr {
name: "reverse"
@@ -5679,6 +5719,7 @@ op {
default_value {
b: false
}
+ description: "A `bool` (default: False)."
}
attr {
name: "T"
@@ -5722,10 +5763,12 @@ op {
name: "Cumsum"
input_arg {
name: "x"
+ description: "A `Tensor`. Must be one of the following types: `float32`, `float64`,\n`int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,\n`complex128`, `qint8`, `quint8`, `qint32`, `half`."
type_attr: "T"
}
input_arg {
name: "axis"
+ description: "A `Tensor` of type `int32` (default: 0). Must be in the range\n`[-rank(x), rank(x))`."
type_attr: "Tidx"
}
output_arg {
@@ -5738,6 +5781,7 @@ op {
default_value {
b: false
}
+ description: "If `True`, perform exclusive cumsum."
}
attr {
name: "reverse"
@@ -5745,6 +5789,7 @@ op {
default_value {
b: false
}
+ description: "A `bool` (default: False)."
}
attr {
name: "T"
@@ -7615,7 +7660,7 @@ op {
}
input_arg {
name: "dim"
- description: "0-D (scalar). Specifies the dimension index at which to\nexpand the shape of `input`."
+ description: "0-D (scalar). Specifies the dimension index at which to\nexpand the shape of `input`. Must be in the range\n`[-rank(input) - 1, rank(input)]`."
type_attr: "Tdim"
}
output_arg {
@@ -12325,7 +12370,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
@@ -13102,7 +13147,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
@@ -13293,7 +13338,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
@@ -15437,6 +15482,33 @@ op {
description: "The polygamma function is defined as:\n\n\n\\\\(\\psi^{(n)}(x) = \\frac{d^n}{dx^n} \\psi(x)\\\\)\n\nwhere \\\\(\\psi(x)\\\\) is the digamma function."
}
op {
+ name: "PopulationCount"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type: DT_UINT8
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_UINT8
+ type: DT_UINT16
+ }
+ }
+ }
+ summary: "Computes element-wise population count (a.k.a. popcount, bitsum, bitcount)."
+ description: "For each entry in `x`, calculates the number of `1` (on) bits in the binary\nrepresentation of that entry.\n\n**NOTE**: It is more efficient to first `tf.bitcast` your tensors into\n`int32` or `int64` and perform the bitcount on the result, than to feed in\n8- or 16-bit inputs and then aggregate the resulting counts."
+}
+op {
name: "Pow"
input_arg {
name: "x"
@@ -15662,7 +15734,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
@@ -21708,7 +21780,7 @@ op {
}
input_arg {
name: "axis"
- description: "1-D. The indices of the dimensions to reverse."
+ description: "1-D. The indices of the dimensions to reverse. Must be in the range\n`[-rank(tensor), rank(tensor))`."
type_attr: "Tidx"
}
output_arg {
@@ -27260,7 +27332,7 @@ op {
list {
}
}
- description: "If specified, only squeezes the dimensions listed. The dimension\nindex starts at 0. It is an error to squeeze a dimension that is not 1."
+ description: "If specified, only squeezes the dimensions listed. The dimension\nindex starts at 0. It is an error to squeeze a dimension that is not 1. Must\nbe in the range `[-rank(input), rank(input))`."
has_minimum: true
}
summary: "Removes dimensions of size 1 from the shape of a tensor."
@@ -28250,7 +28322,7 @@ op {
}
input_arg {
name: "reduction_indices"
- description: "The dimensions to reduce."
+ description: "The dimensions to reduce. Must be in the range\n`[-rank(input), rank(input))`."
type_attr: "Tidx"
}
output_arg {
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index e1f123a0a0..04bf2aeca6 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -28,7 +28,7 @@ def tf_additional_verbs_deps():
"//tensorflow:with_verbs_support": [
"//tensorflow/contrib/verbs:verbs_server_lib",
"//tensorflow/contrib/verbs:grpc_verbs_client",
- ],
+ ],
"//conditions:default": [],
})
diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h
index eaf0171e72..47523c7d2b 100644
--- a/tensorflow/core/platform/macros.h
+++ b/tensorflow/core/platform/macros.h
@@ -20,6 +20,7 @@ limitations under the License.
#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG)
// Compiler supports GCC-style attributes
#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn))
+#define TF_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))
#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline))
#define TF_ATTRIBUTE_UNUSED __attribute__((unused))
#define TF_ATTRIBUTE_COLD __attribute__((cold))
@@ -33,6 +34,7 @@ limitations under the License.
#elif defined(COMPILER_MSVC)
// Non-GCC equivalents
#define TF_ATTRIBUTE_NORETURN __declspec(noreturn)
+#define TF_ATTRIBUTE_ALWAYS_INLINE
#define TF_ATTRIBUTE_NOINLINE
#define TF_ATTRIBUTE_UNUSED
#define TF_ATTRIBUTE_COLD
@@ -43,6 +45,7 @@ limitations under the License.
#else
// Non-GCC equivalents
#define TF_ATTRIBUTE_NORETURN
+#define TF_ATTRIBUTE_ALWAYS_INLINE
#define TF_ATTRIBUTE_NOINLINE
#define TF_ATTRIBUTE_UNUSED
#define TF_ATTRIBUTE_COLD
diff --git a/tensorflow/core/platform/posix/error.cc b/tensorflow/core/platform/posix/error.cc
index df5c800879..e9baad5422 100644
--- a/tensorflow/core/platform/posix/error.cc
+++ b/tensorflow/core/platform/posix/error.cc
@@ -171,11 +171,7 @@ error::Code ErrnoToCode(int err_number) {
Status IOError(const string& context, int err_number) {
auto code = ErrnoToCode(err_number);
- if (code == error::UNKNOWN) {
- return Status(code, strings::StrCat(context, "; ", strerror(err_number)));
- } else {
- return Status(code, context);
- }
+ return Status(code, strings::StrCat(context, "; ", strerror(err_number)));
}
} // namespace tensorflow
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 10388509a9..095cbbe637 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -1130,7 +1130,8 @@ type SqueezeAttr func(optionalAttr)
// SqueezeSqueezeDims sets the optional squeeze_dims attribute to value.
//
// value: If specified, only squeezes the dimensions listed. The dimension
-// index starts at 0. It is an error to squeeze a dimension that is not 1.
+// index starts at 0. It is an error to squeeze a dimension that is not 1. Must
+// be in the range `[-rank(input), rank(input))`.
// If not specified, defaults to <>
//
// REQUIRES: len(value) >= 0
@@ -7069,6 +7070,61 @@ func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_ha
return op.Output(0)
}
+// TextLineReaderV2Attr is an optional argument to TextLineReaderV2.
+type TextLineReaderV2Attr func(optionalAttr)
+
+// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value.
+//
+// value: Number of lines to skip from the beginning of every file.
+// If not specified, defaults to 0
+func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr {
+ return func(m optionalAttr) {
+ m["skip_header_lines"] = value
+ }
+}
+
+// TextLineReaderV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func TextLineReaderV2Container(value string) TextLineReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// TextLineReaderV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the lines of a file delimited by '\n'.
+//
+// Returns The handle to reference the Reader.
+func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TextLineReaderV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes rectified linear 6: `min(max(features, 0), 6)`.
func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
@@ -12819,7 +12875,8 @@ func ReciprocalGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
//
// Arguments:
// tensor: Up to 8-D.
-// axis: 1-D. The indices of the dimensions to reverse.
+// axis: 1-D. The indices of the dimensions to reverse. Must be in the range
+// `[-rank(tensor), rank(tensor))`.
//
// Returns The same shape as `tensor`.
func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) {
@@ -14493,61 +14550,6 @@ func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// TextLineReaderV2Attr is an optional argument to TextLineReaderV2.
-type TextLineReaderV2Attr func(optionalAttr)
-
-// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value.
-//
-// value: Number of lines to skip from the beginning of every file.
-// If not specified, defaults to 0
-func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr {
- return func(m optionalAttr) {
- m["skip_header_lines"] = value
- }
-}
-
-// TextLineReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func TextLineReaderV2Container(value string) TextLineReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// TextLineReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the lines of a file delimited by '\n'.
-//
-// Returns The handle to reference the Reader.
-func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TextLineReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Component-wise multiplies a SparseTensor by a dense Tensor.
//
// The output locations corresponding to the implicitly zero elements in the sparse
@@ -16147,6 +16149,8 @@ func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf
type CumprodAttr func(optionalAttr)
// CumprodExclusive sets the optional exclusive attribute to value.
+//
+// value: If `True`, perform exclusive cumprod.
// If not specified, defaults to false
func CumprodExclusive(value bool) CumprodAttr {
return func(m optionalAttr) {
@@ -16155,6 +16159,8 @@ func CumprodExclusive(value bool) CumprodAttr {
}
// CumprodReverse sets the optional reverse attribute to value.
+//
+// value: A `bool` (default: False).
// If not specified, defaults to false
func CumprodReverse(value bool) CumprodAttr {
return func(m optionalAttr) {
@@ -16192,6 +16198,13 @@ func CumprodReverse(value bool) CumprodAttr {
// ```python
// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
// ```
+//
+// Arguments:
+// x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
+// `complex128`, `qint8`, `quint8`, `qint32`, `half`.
+// axis: A `Tensor` of type `int32` (default: 0). Must be in the range
+// `[-rank(x), rank(x))`.
func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) {
if scope.Err() != nil {
return
@@ -16420,6 +16433,8 @@ func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, ma
type CumsumAttr func(optionalAttr)
// CumsumExclusive sets the optional exclusive attribute to value.
+//
+// value: If `True`, perform exclusive cumsum.
// If not specified, defaults to false
func CumsumExclusive(value bool) CumsumAttr {
return func(m optionalAttr) {
@@ -16428,6 +16443,8 @@ func CumsumExclusive(value bool) CumsumAttr {
}
// CumsumReverse sets the optional reverse attribute to value.
+//
+// value: A `bool` (default: False).
// If not specified, defaults to false
func CumsumReverse(value bool) CumsumAttr {
return func(m optionalAttr) {
@@ -16465,6 +16482,13 @@ func CumsumReverse(value bool) CumsumAttr {
// ```python
// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
// ```
+//
+// Arguments:
+// x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
+// `complex128`, `qint8`, `quint8`, `qint32`, `half`.
+// axis: A `Tensor` of type `int32` (default: 0). Must be in the range
+// `[-rank(x), rank(x))`.
func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) {
if scope.Err() != nil {
return
@@ -17894,6 +17918,28 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.
return op.Output(0), op.Output(1), op.Output(2)
}
+// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount).
+//
+// For each entry in `x`, calculates the number of `1` (on) bits in the binary
+// representation of that entry.
+//
+// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into
+// `int32` or `int64` and perform the bitcount on the result, than to feed in
+// 8- or 16-bit inputs and then aggregate the resulting counts.
+func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "PopulationCount",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// AssertAttr is an optional argument to Assert.
type AssertAttr func(optionalAttr)
@@ -18063,7 +18109,8 @@ func AnyKeepDims(value bool) AnyAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func Any(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...AnyAttr) (output tf.Output) {
@@ -19213,7 +19260,8 @@ func ProdKeepDims(value bool) ProdAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) {
@@ -20258,7 +20306,8 @@ func MaxKeepDims(value bool) MaxAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func Max(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...MaxAttr) (output tf.Output) {
@@ -20583,7 +20632,8 @@ func Sqrt(scope *Scope, x tf.Output) (y tf.Output) {
// Arguments:
//
// dim: 0-D (scalar). Specifies the dimension index at which to
-// expand the shape of `input`.
+// expand the shape of `input`. Must be in the range
+// `[-rank(input) - 1, rank(input)]`.
//
// Returns Contains the same data as `input`, but its shape has an additional
// dimension of size 1 added.
@@ -20623,7 +20673,8 @@ func AllKeepDims(value bool) AllAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func All(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...AllAttr) (output tf.Output) {
@@ -21665,8 +21716,8 @@ func ArgMinOutputType(value tf.DataType) ArgMinAttr {
//
// Arguments:
//
-// dimension: int32 or int64, 0 <= dimension < rank(input). Describes
-// which dimension of the input Tensor to reduce across. For vectors,
+// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
+// Describes which dimension of the input Tensor to reduce across. For vectors,
// use dimension = 0.
func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) {
if scope.Err() != nil {
@@ -22716,7 +22767,8 @@ func MeanKeepDims(value bool) MeanAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func Mean(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...MeanAttr) (output tf.Output) {
@@ -22856,7 +22908,8 @@ func MinKeepDims(value bool) MinAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func Min(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...MinAttr) (output tf.Output) {
@@ -22914,8 +22967,8 @@ func ArgMaxOutputType(value tf.DataType) ArgMaxAttr {
//
// Arguments:
//
-// dimension: int32 or int64, 0 <= dimension < rank(input). Describes
-// which dimension of the input Tensor to reduce across. For vectors,
+// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
+// Describes which dimension of the input Tensor to reduce across. For vectors,
// use dimension = 0.
func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) {
if scope.Err() != nil {
@@ -23888,6 +23941,51 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
+// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`.
+//
+// Each comparison returns a boolean `true` (if `input_value > threshold`)
+// or `false` otherwise.
+//
+// This operation is useful for Locality-Sensitive-Hashing (LSH) and other
+// algorithms that use hashing approximations of cosine and `L2` distances;
+// codes can be generated from an input via:
+//
+// ```python
+// codebook_size = 50
+// codebook_bits = codebook_size * 32
+// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits],
+// dtype=x.dtype,
+// initializer=tf.orthogonal_initializer())
+// codes = compare_and_bitpack(tf.matmul(x, codebook), threshold=0.)
+// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32
+// # now codes has shape x.shape[:-1] + [codebook_size]
+// ```
+//
+// **NOTE**: Currently, the innermost dimension of the tensor must be divisible
+// by 8.
+//
+// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is
+// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`.
+//
+// Arguments:
+// input: Values to compare against `threshold` and bitpack.
+// threshold: Threshold to compare against.
+//
+// Returns The bitpacked comparisons.
+func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "CompareAndBitpack",
+ Input: []tf.Input{
+ input, threshold,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
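For readers following along in Python, the shape contract and bit order described above can be checked against `np.packbits`. A minimal sketch (not part of this patch), using the `math_ops.compare_and_bitpack` wrapper exercised by the test added later in this diff:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import math_ops

with tf.Session() as sess:
  x = np.random.randn(3, 16).astype(np.float32)  # innermost dim divisible by 8
  packed = sess.run(math_ops.compare_and_bitpack(x, threshold=0.0))
  assert packed.shape == (3, 2)  # [..., 16] becomes [..., 16 / 8]
  # np.packbits flattens its input, so reshape before comparing.
  assert np.array_equal(packed, np.packbits(x > 0.0).reshape(3, 2))
```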
// Outputs a `Summary` protocol buffer with a tensor and per-plugin data.
//
// Arguments:
@@ -24724,7 +24822,8 @@ func SumKeepDims(value bool) SumAttr {
//
// Arguments:
// input: The tensor to reduce.
-// reduction_indices: The dimensions to reduce.
+// reduction_indices: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
// Returns The reduced tensor.
func Sum(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...SumAttr) (output tf.Output) {
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index cac05c372a..896d466c25 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1672,6 +1672,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "compare_and_bitpack_op_test",
+ size = "small",
+ srcs = ["compare_and_bitpack_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+cuda_py_test(
name = "scalar_test",
size = "small",
srcs = ["scalar_test.py"],
diff --git a/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py b/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py
new file mode 100644
index 0000000000..56ddd6e428
--- /dev/null
+++ b/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.compare_and_bitpack_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class CompareAndBitpackTest(test.TestCase):
+
+ def _testCompareAndBitpack(self,
+ x, threshold,
+ truth,
+ expected_err_re=None):
+ with self.test_session(use_gpu=True):
+ ans = math_ops.compare_and_bitpack(x, threshold)
+ if expected_err_re is None:
+ tf_ans = ans.eval()
+ self.assertShapeEqual(truth, ans)
+ self.assertAllEqual(tf_ans, truth)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ ans.eval()
+
+ def _testBasic(self, dtype):
+ rows = 371
+ cols = 294
+ x = np.random.randn(rows, cols * 8)
+ if dtype == np.bool:
+ x = x > 0
+ else:
+ x = x.astype(dtype)
+ threshold = dtype(0)
+ # np.packbits flattens the tensor, so we reshape it back to the
+ # expected dimensions.
+ truth = np.packbits(x > threshold).reshape(rows, cols)
+ self._testCompareAndBitpack(x, threshold, truth)
+
+ def testBasicFloat32(self):
+ self._testBasic(np.float32)
+
+ def testBasicFloat64(self):
+ self._testBasic(np.float64)
+
+ def testBasicFloat16(self):
+ self._testBasic(np.float16)
+
+ def testBasicBool(self):
+ self._testBasic(np.bool)
+
+ def testBasicInt8(self):
+ self._testBasic(np.int8)
+
+ def testBasicInt16(self):
+ self._testBasic(np.int16)
+
+ def testBasicInt32(self):
+ self._testBasic(np.int32)
+
+ def testBasicInt64(self):
+ self._testBasic(np.int64)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 04d65b88a1..9a94692569 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -88,58 +88,58 @@ class GatherTest(test.TestCase):
def testHigherRank(self):
# We check that scalar and empty indices shapes work as well
- for shape in (4, 3, 2), (2, 1, 3, 2):
- for indices_shape in (), (0,), (3, 0), (3, 5), (5, 2, 3):
- for dtype in _TEST_TYPES:
- for axis in range(len(shape)):
- params = self._buildParams(np.random.randn(*shape), dtype)
- indices = np.random.randint(shape[axis], size=indices_shape)
- with self.test_session(use_gpu=True) as sess:
- tf_params = constant_op.constant(params)
- tf_indices = constant_op.constant(indices)
- # Check that both positive and negative indices for axis work.
- tf_axis = constant_op.constant(axis)
- tf_negative_axis = constant_op.constant(-len(shape) + axis)
- gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
- gather_negative_axis = array_ops.gather(
- tf_params, tf_indices, axis=tf_negative_axis)
- gather_value, gather_negative_axis_value = sess.run(
- [gather, gather_negative_axis])
- gather_np = np.take(params, indices, axis)
- self.assertAllEqual(gather_np, gather_value)
- self.assertAllEqual(gather_np, gather_negative_axis_value)
- expected_shape = (params.shape[:axis] + indices.shape +
- params.shape[axis + 1:])
- self.assertEqual(expected_shape, gather.shape)
- self.assertEqual(expected_shape, gather_negative_axis.shape)
-
- # Test gradients
- gather_grad = np.random.randn(
- *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
- if dtype.is_complex:
- gather_grad -= 1j * gather_grad
- params_grad, indices_grad, axis_grad = gradients_impl.gradients(
- gather, [tf_params, tf_indices, tf_axis], gather_grad)
- self.assertEqual(indices_grad, None)
- self.assertEqual(axis_grad, None)
- # For axis 0, we are able to create an efficient IndexedSlices for
- # the gradient.
- if axis == 0:
- self.assertEqual(type(params_grad), ops.IndexedSlices)
- params_grad = ops.convert_to_tensor(params_grad)
- correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
- outer_dims = axis
- inner_dims = len(shape) - axis - 1
- gather_grad = gather_grad.reshape(
- shape[:axis] + (indices.size,) + shape[axis + 1:])
- for source_index, dest_index in enumerate(indices.flat):
- dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
+ shape = (2, 1, 3, 2)
+ for indices_shape in (), (0,), (2, 0), (2, 3):
+ for dtype in _TEST_TYPES:
+ for axis in range(len(shape)):
+ params = self._buildParams(np.random.randn(*shape), dtype)
+ indices = np.random.randint(shape[axis], size=indices_shape)
+ with self.test_session(use_gpu=True) as sess:
+ tf_params = constant_op.constant(params)
+ tf_indices = constant_op.constant(indices)
+ # Check that both positive and negative indices for axis work.
+ tf_axis = constant_op.constant(axis)
+ tf_negative_axis = constant_op.constant(-len(shape) + axis)
+ gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
+ gather_negative_axis = array_ops.gather(
+ tf_params, tf_indices, axis=tf_negative_axis)
+ gather_value, gather_negative_axis_value = sess.run(
+ [gather, gather_negative_axis])
+ gather_np = np.take(params, indices, axis)
+ self.assertAllEqual(gather_np, gather_value)
+ self.assertAllEqual(gather_np, gather_negative_axis_value)
+ expected_shape = (params.shape[:axis] + indices.shape +
+ params.shape[axis + 1:])
+ self.assertEqual(expected_shape, gather.shape)
+ self.assertEqual(expected_shape, gather_negative_axis.shape)
+
+ # Test gradients
+ gather_grad = np.random.randn(
+ *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
+ if dtype.is_complex:
+ gather_grad -= 1j * gather_grad
+ params_grad, indices_grad, axis_grad = gradients_impl.gradients(
+ gather, [tf_params, tf_indices, tf_axis], gather_grad)
+ self.assertEqual(indices_grad, None)
+ self.assertEqual(axis_grad, None)
+ # For axis 0, we are able to create an efficient IndexedSlices for
+ # the gradient.
+ if axis == 0:
+ self.assertEqual(type(params_grad), ops.IndexedSlices)
+ params_grad = ops.convert_to_tensor(params_grad)
+ correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
+ outer_dims = axis
+ inner_dims = len(shape) - axis - 1
+ gather_grad = gather_grad.reshape(
+ shape[:axis] + (indices.size,) + shape[axis + 1:])
+ for source_index, dest_index in enumerate(indices.flat):
+ dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
+ (slice(None),) * inner_dims)
+ source_slice = ((slice(None),) * outer_dims + (source_index,) +
(slice(None),) * inner_dims)
- source_slice = ((slice(None),) * outer_dims + (source_index,) +
- (slice(None),) * inner_dims)
- correct_params_grad[dest_slice] += gather_grad[source_slice]
- self.assertAllClose(correct_params_grad, params_grad.eval(),
- atol=2e-6, rtol=2e-6)
+ correct_params_grad[dest_slice] += gather_grad[source_slice]
+ self.assertAllClose(correct_params_grad, params_grad.eval(),
+ atol=2e-6, rtol=2e-6)
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index bb86640eab..f64c89ac5d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -330,7 +330,7 @@ def rank(input, name=None):
# pylint: disable=redefined-builtin
"""Returns the rank of a tensor.
- This operation returns an integer representing the rank of `input`.
+ Returns a 0-D `int32` `Tensor` representing the rank of `input`.
For example:
diff --git a/tensorflow/python/ops/bitwise_ops.py b/tensorflow/python/ops/bitwise_ops.py
index cbabc3ed9b..44daf13537 100644
--- a/tensorflow/python/ops/bitwise_ops.py
+++ b/tensorflow/python/ops/bitwise_ops.py
@@ -36,5 +36,6 @@ ops.NotDifferentiable("BitwiseAnd")
ops.NotDifferentiable("BitwiseOr")
ops.NotDifferentiable("BitwiseXor")
ops.NotDifferentiable("Invert")
+ops.NotDifferentiable("PopulationCount")
remove_undocumented(__name__)
diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py
index 904cf99a5a..1d08c8f82d 100644
--- a/tensorflow/python/ops/bitwise_ops_test.py
+++ b/tensorflow/python/ops/bitwise_ops_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+import six
+
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.platform import googletest
@@ -46,6 +50,25 @@ class BitwiseOpTest(test_util.TensorFlowTestCase):
self.assertAllEqual(or_result, [5, 5, 7, 15])
self.assertAllEqual(xor_result, [5, 5, 4, 5])
+ def testPopulationCountOp(self):
+ dtype_list = [dtypes.int8, dtypes.int16,
+ dtypes.int32, dtypes.int64,
+ dtypes.uint8, dtypes.uint16]
+ raw_inputs = [0, 1, -1, 3, -3, 5, -5, 14, -14,
+ 127, 128, 255, 256, 65535, 65536,
+ 2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32,
+ -2**63 + 1, 2**63 - 1]
+ def count_bits(x):
+ return sum([bin(z).count("1") for z in six.iterbytes(x.tobytes())])
+ for dtype in dtype_list:
+ with self.test_session(use_gpu=True) as sess:
+ print("PopulationCount test: ", dtype)
+ inputs = np.array(raw_inputs, dtype=dtype.as_numpy_dtype)
+ truth = [count_bits(x) for x in inputs]
+ input_tensor = constant_op.constant(inputs, dtype=dtype)
+ popcnt_result = sess.run(gen_bitwise_ops.population_count(input_tensor))
+ self.assertAllEqual(truth, popcnt_result)
+
def testInvertOp(self):
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16]
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 44d6c7e275..4ba812eaf5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -61,6 +61,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
@@ -983,9 +984,16 @@ class GradLoopState(object):
# the right control flow context.
real_value = self._grad_context.AddValue(cur_value)
break
+ elif constant_op.is_constant(cur_value):
+ # If the value to be forwarded is a constant, clone the constant in
+ # the gradient loop rather than using a stack.
+ # TODO(phawkins): consider hoisting the constant out of the loop
+ # instead.
+ real_value = constant_op.constant(
+ tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
+ break
else:
# Record the history of this value in forward_ctxt.
- # TODO(yuanbyu): Avoid recording constants.
self._grad_context.Exit()
history_value = cur_grad_state.AddForwardAccumulator(cur_value)
self._grad_context.Enter()
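The branch added above changes how a constant consumed inside a `while_loop` body is made available to the backprop loop: it is cloned into the gradient context instead of being pushed through a forward accumulator stack. A minimal sketch (not part of this patch) of a graph whose gradient construction goes through `GradLoopState` value forwarding; whether a given constant hits the new cloning branch is an internal detail, so treat this purely as illustration:

```python
import tensorflow as tf

def cond(i, acc):
  return i < 5

def body(i, acc):
  scale = tf.constant(2.0)  # constant created inside the loop body
  return i + 1, acc * scale

x = tf.placeholder(tf.float32, [])
_, y = tf.while_loop(cond, body, [tf.constant(0), x])
grad = tf.gradients(y, x)[0]  # builds the backprop loop

with tf.Session() as sess:
  print(sess.run(grad, feed_dict={x: 3.0}))  # d(x * 2**5)/dx = 32.0
```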
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 1e2f999995..42b4f952bb 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -41,7 +41,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import math_ops
class Initializer(object):
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 87f6f92a8a..cc8c623947 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -830,7 +830,8 @@ class ReluTest(test_lib.TestCase):
class MomentsTest(test_lib.TestCase):
- def doOutputTest(self, input_shape, moments_axes, tol=1e-4):
+ def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
+ check_gradients=False):
for mu in [0.0, 1.0, 1e3]:
for sigma in [1.0, 0.1]:
for keep_dims in [True, False]:
@@ -846,6 +847,15 @@ class MomentsTest(test_lib.TestCase):
mean, variance = nn_impl.moments(
inputs, moments_axes, keep_dims=keep_dims)
+ if check_gradients:
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, mean, mean.shape.as_list())
+ self.assertLess(err, 1e-3)
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, variance, variance.shape.as_list())
+ self.assertLess(err, 1e-3)
+
+ # Evaluate.
[mean, variance] = sess.run([mean, variance])
# Make sure that there are no NaNs
self.assertFalse(np.isnan(mean).any())
@@ -853,6 +863,12 @@ class MomentsTest(test_lib.TestCase):
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
+ def testOutputAndGradient2DInput0(self):
+ self.doOutputTest((10, 10), (0,), check_gradients=True)
+
+ def testOutputAndGradient2DInput01(self):
+ self.doOutputTest((10, 10), (0, 1), check_gradients=True)
+
def testOutput2DInput0(self):
self.doOutputTest((10, 300), (0,))
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index f7854e86c0..304b6ae665 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -786,13 +786,18 @@ class DropoutWrapper(RNNCell):
class ResidualWrapper(RNNCell):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
- def __init__(self, cell):
+ def __init__(self, cell, residual_fn=None):
"""Constructs a `ResidualWrapper` for `cell`.
Args:
cell: An instance of `RNNCell`.
+ residual_fn: (Optional) The function to map raw cell inputs and raw cell
+ outputs to the actual cell outputs of the residual network.
+ Defaults to calling nest.map_structure(lambda i, o: i + o, inputs, outputs).
"""
self._cell = cell
+ self._residual_fn = residual_fn
@property
def state_size(self):
@@ -807,7 +812,7 @@ class ResidualWrapper(RNNCell):
return self._cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
- """Run the cell and add its inputs to its outputs.
+ """Run the cell and then apply the residual_fn on its inputs to its outputs.
Args:
inputs: cell inputs.
@@ -822,13 +827,14 @@ class ResidualWrapper(RNNCell):
ValueError: If cell inputs and outputs have different structure (value).
"""
outputs, new_state = self._cell(inputs, state, scope=scope)
- nest.assert_same_structure(inputs, outputs)
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
- nest.map_structure(assert_shape_match, inputs, outputs)
- res_outputs = nest.map_structure(
- lambda inp, out: inp + out, inputs, outputs)
+ def default_residual_fn(inputs, outputs):
+ nest.assert_same_structure(inputs, outputs)
+ nest.map_structure(assert_shape_match, inputs, outputs)
+ return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
+ res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
return (res_outputs, new_state)
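A hypothetical usage sketch (not part of this patch) of the new `residual_fn` argument; the names below are illustrative, and the default behaviour when `residual_fn` is omitted remains `inputs + outputs`:

```python
import tensorflow as tf

num_units = 32
cell = tf.nn.rnn_cell.GRUCell(num_units)

def scaled_residual_fn(inputs, outputs):
  # Scale the cell output before adding the skip connection; the default
  # residual_fn simply returns inputs + outputs.
  return inputs + 0.5 * outputs

wrapped = tf.nn.rnn_cell.ResidualWrapper(cell, residual_fn=scaled_residual_fn)

# For this residual_fn the input feature size must still match num_units.
inputs = tf.placeholder(tf.float32, [None, 10, num_units])  # batch x time x depth
outputs, state = tf.nn.dynamic_rnn(wrapped, inputs, dtype=tf.float32)
```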
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index b21d9a8ee3..a75e9e8080 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'cell\', \'residual_fn\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/ci_build/builds/builds_common.sh b/tensorflow/tools/ci_build/builds/builds_common.sh
index 9323a96e74..e3b58d038a 100644
--- a/tensorflow/tools/ci_build/builds/builds_common.sh
+++ b/tensorflow/tools/ci_build/builds/builds_common.sh
@@ -17,6 +17,7 @@
# Common Bash functions used by build scripts
COLOR_NC='\033[0m'
+COLOR_LIGHT_GRAY='\033[0;37m'
COLOR_GREEN='\033[0;32m'
COLOR_RED='\033[0;31m'
diff --git a/tensorflow/tools/ci_build/builds/configured b/tensorflow/tools/ci_build/builds/configured
index 25cb51ea7c..563e07e3af 100755
--- a/tensorflow/tools/ci_build/builds/configured
+++ b/tensorflow/tools/ci_build/builds/configured
@@ -56,7 +56,7 @@ else
fi
pushd "${CI_TENSORFLOW_SUBMODULE_PATH:-.}"
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
popd
# Gather and print build information
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index db011a6bad..112dab3a73 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -73,6 +73,9 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/builds_common.sh"
+SKIP_RETURN_CODE=112
+
+
# Get the command line arguments
CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' )
shift
@@ -310,6 +313,13 @@ create_activate_virtualenv_and_install_tensorflow() {
# Smoke test of tensorflow install in clean virtualenv
################################################################################
do_clean_virtualenv_smoke_test() {
+ if [[ -n "${NO_TEST_ON_INSTALL}" ]] &&
+ [[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then
+ echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:"
+ echo " Skipping smoke test of tensorflow install in clean virtualenv"
+ return ${SKIP_RETURN_CODE}
+ fi
+
CLEAN_VENV_DIR="${PIP_TEST_ROOT}/venv_clean"
create_activate_virtualenv_and_install_tensorflow --clean \
"${CLEAN_VENV_DIR}" "${WHL_PATH}"
@@ -361,6 +371,7 @@ do_virtualenv_pip_test() {
[[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then
echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:"
echo " Skipping ALL Python unit tests on install"
+ return ${SKIP_RETURN_CODE}
else
# Call run_pip_tests.sh to perform test-on-install
"${SCRIPT_DIR}/run_pip_tests.sh" --virtualenv ${GPU_FLAG} ${MAC_FLAG}
@@ -379,6 +390,7 @@ do_virtualenv_oss_serial_pip_test() {
[[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then
echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:"
echo " Skipping Python unit tests on install tagged with oss_serial"
+ return ${SKIP_RETURN_CODE}
else
# Call run_pip_tests.sh to perform test-on-install
"${SCRIPT_DIR}/run_pip_tests.sh" \
@@ -402,6 +414,7 @@ do_test_user_ops() {
fi
else
echo "Skipping user-op test-on-install due to DO_TEST_USER_OPS = ${DO_TEST_USER_OPS}"
+ return ${SKIP_RETURN_CODE}
fi
}
@@ -424,6 +437,7 @@ do_test_tfdbg_binaries() {
popd
else
echo "Skipping test of tfdbg binaries due to DO_TEST_TFDBG_BINARIES = ${DO_TEST_TFDBG_BINARIES}"
+ return ${SKIP_RETURN_CODE}
fi
}
@@ -439,6 +453,7 @@ do_test_tutorials() {
fi
else
echo "Skipping tutorial tests-on-install due to DO_TEST_TUTORIALS = ${DO_TEST_TUTORIALS}"
+ return ${SKIP_RETURN_CODE}
fi
}
@@ -455,6 +470,7 @@ do_ffmpeg_integration_test() {
fi
else
echo "Skipping ffmpeg integration due to DO_INTEGRATION_TESTS = ${DO_INTEGRATION_TESTS}"
+ return ${SKIP_RETURN_CODE}
fi
}
@@ -468,6 +484,7 @@ PIP_TASKS_DESC=("Smoke test of pip install in clean virtualenv" "PIP tests in vi
COUNTER=0
FAIL_COUNTER=0
PASS_COUNTER=0
+SKIP_COUNTER=0
while [[ ${COUNTER} -lt "${#PIP_TASKS[@]}" ]]; do
INDEX=COUNTER
((INDEX++))
@@ -480,7 +497,9 @@ while [[ ${COUNTER} -lt "${#PIP_TASKS[@]}" ]]; do
${PIP_TASKS[COUNTER]}
RESULT=$?
- if [[ ${RESULT} != "0" ]]; then
+ if [[ ${RESULT} == ${SKIP_RETURN_CODE} ]]; then
+ ((SKIP_COUNTER++))
+ elif [[ ${RESULT} != "0" ]]; then
((FAIL_COUNTER++))
else
((PASS_COUNTER++))
@@ -503,7 +522,9 @@ while [[ ${COUNTER} -lt "${#PIP_TASKS[@]}" ]]; do
((INDEX++))
echo "${INDEX}. ${PIP_TASKS[COUNTER]}: ${PIP_TASKS_DESC[COUNTER]}"
- if [[ ${STEP_EXIT_CODES[COUNTER]} == "0" ]]; then
+ if [[ ${STEP_EXIT_CODES[COUNTER]} == ${SKIP_RETURN_CODE} ]]; then
+ printf " ${COLOR_LIGHT_GRAY}SKIP${COLOR_NC}\n"
+ elif [[ ${STEP_EXIT_CODES[COUNTER]} == "0" ]]; then
printf " ${COLOR_GREEN}PASS${COLOR_NC}\n"
else
printf " ${COLOR_RED}FAIL${COLOR_NC}\n"
@@ -513,7 +534,7 @@ while [[ ${COUNTER} -lt "${#PIP_TASKS[@]}" ]]; do
done
echo
-echo "${FAIL_COUNTER} failed; ${PASS_COUNTER} passed."
+echo "${SKIP_COUNTER} skipped; ${FAIL_COUNTER} failed; ${PASS_COUNTER} passed."
echo
if [[ ${FAIL_COUNTER} == "0" ]]; then
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index f66846654d..9a6890401b 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -120,7 +120,7 @@ else
fi
export TF_NEED_CUDA=$IS_GPU
-yes "" | ./configure
+${PYTHON_BIN_PATH} configure.py
# Figure out how many concurrent tests we can run and do run the tests.
BAZEL_PARALLEL_TEST_FLAGS=""
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
index 118e85fee0..ca84079654 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
@@ -30,7 +30,7 @@ export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
# Only running cc tests, python version does not matter.
export PYTHON_BIN_PATH=`which python`
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc -k \
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
index fa3d27fa41..5c82c9efaf 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
@@ -29,7 +29,7 @@ export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
export PYTHON_BIN_PATH=`which python2`
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index 258dec4fec..7155636a53 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -29,7 +29,7 @@ export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
export PYTHON_BIN_PATH=`which python3`
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
index 9c450ab4da..218d2a8991 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
@@ -29,7 +29,7 @@ export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
export PYTHON_BIN_PATH=`which python3`
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
index f2ea8d3c77..dff72c25bf 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
@@ -32,7 +32,7 @@ export PYTHON_BIN_PATH=`which python3`
export TF_NEED_CUDA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
index 4e0c3d1d33..a36a8445af 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
@@ -32,7 +32,7 @@ export PYTHON_BIN_PATH=`which python3`
export TF_NEED_CUDA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
index 0b8c73993f..0ee894e2c4 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
@@ -30,7 +30,7 @@ export TF_NEED_GCP=0
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
export PYTHON_BIN_PATH=$(which python2)
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
which bazel
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \
--test_timeout 300,450,1200,3600 \
diff --git a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh
index 1106413071..f548adc5ca 100755
--- a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh
+++ b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh
@@ -33,7 +33,7 @@ export TF_NEED_CUDA=1
export TF_ENABLE_XLA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7
-yes "" | ./configure
+$PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 12bacf3c27..2d96406d27 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -899,6 +899,7 @@ cc_library(
"include/llvm/Target/ARM/InstPrinter/*.h",
"include/llvm/Target/ARM/InstPrinter/*.def",
"include/llvm/Target/ARM/InstPrinter/*.inc",
+ "lib/Target/ARM/*.h",
"lib/Target/ARM/InstPrinter/*.h",
]),
copts = ["-Iexternal/llvm/lib/Target/ARM"],
@@ -1206,6 +1207,7 @@ cc_library(
"lib/IR/*.h",
]),
hdrs = glob([
+ "include/llvm/Analysis/*.def",
"include/llvm/IR/*.h",
"include/llvm/IR/*.def",
"include/llvm/IR/*.inc",
@@ -2022,6 +2024,8 @@ cc_library(
"lib/Target/*.h",
]),
hdrs = glob([
+ "include/llvm/CodeGen/*.h",
+ "include/llvm/CodeGen/*.def",
"include/llvm/Target/*.h",
"include/llvm/Target/*.def",
"include/llvm/Target/*.inc",