aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md10
-rw-r--r--RELEASE.md7
-rw-r--r--configure.py46
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/cc/gradients/math_grad.cc16
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc20
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake8
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake7
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py6
-rw-r--r--tensorflow/contrib/gdr/gdr_server_lib.cc2
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io.py10
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py8
-rw-r--r--tensorflow/contrib/reduce_slice_ops/BUILD110
-rw-r--r--tensorflow/contrib/reduce_slice_ops/__init__.py26
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc239
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h84
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc100
-rw-r--r--tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc282
-rw-r--r--tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc41
-rw-r--r--tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py158
-rw-r--r--tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py30
-rw-r--r--tensorflow/contrib/session_bundle/example/export_half_plus_two.py5
-rw-r--r--tensorflow/contrib/slim/python/slim/data/parallel_reader.py32
-rw-r--r--tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py66
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/framework/types.cc8
-rw-r--r--tensorflow/core/framework/types.h22
-rw-r--r--tensorflow/core/kernels/cwise_op_arg.cc37
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc28
-rw-r--r--tensorflow/core/kernels/cwise_ops.h4
-rw-r--r--tensorflow/core/kernels/string_split_op.cc23
-rw-r--r--tensorflow/core/ops/math_grad.cc14
-rw-r--r--tensorflow/core/ops/math_grad_test.cc1
-rw-r--r--tensorflow/core/ops/math_ops.cc28
-rw-r--r--tensorflow/core/ops/ops.pbtxt39
-rw-r--r--tensorflow/core/ops/string_ops.cc2
-rw-r--r--tensorflow/core/public/version.h2
-rw-r--r--tensorflow/docs_src/api_guides/python/math_ops.md1
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md18
-rw-r--r--tensorflow/docs_src/install/install_linux.md22
-rw-r--r--tensorflow/docs_src/install/install_mac.md10
-rw-r--r--tensorflow/docs_src/install/install_sources.md4
-rw-r--r--tensorflow/docs_src/install/install_windows.md4
-rw-r--r--tensorflow/go/op/wrappers.go46
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/client/session_test.py37
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py123
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions_test.py96
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py83
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py18
-rw-r--r--tensorflow/python/ops/math_grad.py13
-rw-r--r--tensorflow/python/ops/math_ops.py44
-rw-r--r--tensorflow/python/ops/math_ops_test.py15
-rw-r--r--tensorflow/python/ops/string_ops.py5
-rw-r--r--tensorflow/python/training/monitored_session.py2
-rw-r--r--tensorflow/python/training/monitored_session_test.py7
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt6
-rw-r--r--tensorflow/tools/dist_test/Dockerfile1
-rw-r--r--tensorflow/tools/dist_test/Dockerfile.local1
-rw-r--r--tensorflow/tools/gcs_test/Dockerfile1
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh5
-rw-r--r--tensorflow/tools/pip_package/setup.py15
68 files changed, 1969 insertions, 140 deletions
diff --git a/README.md b/README.md
index d265949194..63bde4235e 100644
--- a/README.md
+++ b/README.md
@@ -35,11 +35,11 @@ and discussion, and please direct specific questions to [Stack Overflow](https:/
People who are a little more adventurous can also try our nightly binaries:
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0rc2-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.3.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
diff --git a/RELEASE.md b/RELEASE.md
index ffe38004a2..3203f0aec1 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,5 +1,7 @@
# Release 1.3.0
+See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/tag/0.1.4) release notes.
+
## Major Features and Improvements
* Added canned estimators to Tensorflow library. List of added estimators:
* `DNNClassifier`
@@ -8,7 +10,7 @@
* `LinearRegressor`
* `DNNLinearCombinedClassifier`
* `DNNLinearCombinedRegressor`.
-* All our prebuilt binaries have been built with cuDNN 6.
+* All our prebuilt binaries have been built with cuDNN 6. We anticipate releasing TensorFlow 1.4 with cuDNN 7.
* `import tensorflow` now goes much faster.
* Adds a file cache to the GCS filesystem with configurable max staleness for file contents. This permits caching of file contents across close/open boundaries.
* Added an axis parameter to `tf.gather`.
@@ -44,6 +46,9 @@
* Adds time series models to contrib. See contrib/timeseries/README.md for details.
* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs
+## Known Issues
+* Tensorflow_gpu compilation fails with Bazel 0.5.3.
+
## Bug Fixes and Other Changes
* Fixes `strides` and `begin` dtype mismatch when slicing using int64 Tensor index in python.
* Improved convolution padding documentation.
diff --git a/configure.py b/configure.py
index 3646670263..5a024fb0e4 100644
--- a/configure.py
+++ b/configure.py
@@ -22,7 +22,6 @@ import errno
import os
import platform
import re
-import site
import subprocess
import sys
@@ -131,16 +130,27 @@ def cygpath(path):
return run_shell('cygpath -m "%s"' % path)
-def get_python_path(environ_cp):
+def get_python_path(environ_cp, python_bin_path):
"""Get the python site package paths."""
python_paths = []
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
- library_paths = site.getsitepackages()
- except AttributeError:
- from distutils.sysconfig import get_python_lib # pylint: disable=g-import-not-at-top
- library_paths = [get_python_lib()]
+ check_input = [
+ python_bin_path, '-c',
+ 'import site; print("\\n".join(site.getsitepackages()))'
+ ]
+ library_paths = subprocess.check_output(check_input).decode(
+ 'UTF-8').strip().split('\n')
+ except subprocess.CalledProcessError:
+ check_input = [
+ python_bin_path, '-c', 'from distutils.sysconfig import get_python_lib;'
+ + 'print(get_python_lib())'
+ ]
+ library_paths = [
+ subprocess.check_output(check_input).decode('UTF-8').strip()
+ ]
+
all_paths = set(python_paths + library_paths)
paths = []
@@ -150,6 +160,12 @@ def get_python_path(environ_cp):
return paths
+def get_python_major_version(python_bin_path):
+ """Get the python major version."""
+ check_input = [python_bin_path, '-c', 'import sys; print(sys.version[0])']
+ return subprocess.check_output(check_input).decode('UTF-8').strip()
+
+
def setup_python(environ_cp, bazel_version):
"""Setup python related env variables."""
# Get PYTHON_BIN_PATH, default is the current running python.
@@ -170,10 +186,14 @@ def setup_python(environ_cp, bazel_version):
print('%s is not executable. Is it the python binary?' % python_bin_path)
environ_cp['PYTHON_BIN_PATH'] = ''
+ # Convert python path to Windows style before checking lib and version
+ if is_windows():
+ python_bin_path = cygpath(python_bin_path)
+
# Get PYTHON_LIB_PATH
python_lib_path = environ_cp.get('PYTHON_LIB_PATH')
if not python_lib_path:
- python_lib_paths = get_python_path(environ_cp)
+ python_lib_paths = get_python_path(environ_cp, python_bin_path)
if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1':
python_lib_path = python_lib_paths[0]
else:
@@ -187,10 +207,10 @@ def setup_python(environ_cp, bazel_version):
python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
- python_major_version = sys.version_info[0]
+ python_major_version = get_python_major_version(python_bin_path)
+
# Convert python path to Windows style before writing into bazel.rc
if is_windows():
- python_bin_path = cygpath(python_bin_path)
python_lib_path = cygpath(python_lib_path)
# Set-up env variables used by python_configure.bzl
@@ -726,9 +746,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
# Check whether all capabilities from the input is valid
all_valid = True
for compute_capability in tf_cuda_compute_capabilities.split(','):
- if not re.match('[0-9]+.[0-9]+', compute_capability):
+ m = re.match('[0-9]+.[0-9]+', compute_capability)
+ if not m:
print('Invalid compute capability: ' % compute_capability)
all_valid = False
+ else:
+ ver = int(m.group(0).split('.')[0])
+ if ver < 3:
+ print('Only compute capabilities 3.0 or higher are supported.')
+ all_valid = False
if all_valid:
break
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 585240ec89..80646fb602 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -312,6 +312,7 @@ filegroup(
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/predictor:all_files",
+ "//tensorflow/contrib/reduce_slice_ops:all_files",
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
"//tensorflow/contrib/resampler:all_files",
"//tensorflow/contrib/rnn:all_files",
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 4ef2dc5bfa..f0426c998b 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -537,6 +537,22 @@ Status ImagGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);
+Status AngleGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = Angle(x)
+ // dx = -dy / (Im(x) + iRe(x)) = -dy * z
+ auto re = Real(scope, op.input(0));
+ auto im = Imag(scope, op.input(0));
+ auto z_inv = Reciprocal(scope, Complex(scope, im, re));
+ auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
+ auto grad = Complex(scope, grad_inputs[0], zero);
+ auto dx = Neg(scope, Mul(scope, grad, z_inv));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Angle", AngleGrad);
+
Status ConjGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 011e119353..bd78331309 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -684,7 +684,7 @@ class CWiseUnaryComplexGradTest : public ::testing::Test {
CWiseUnaryComplexGradTest()
: scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
- enum UnaryOpType { REAL, IMAG, CONJ };
+ enum UnaryOpType { REAL, IMAG, ANGLE, CONJ };
void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
const Tensor& dy, const Tensor& dx_expected) {
@@ -696,6 +696,9 @@ class CWiseUnaryComplexGradTest : public ::testing::Test {
case IMAG:
y = Imag(scope_, x);
break;
+ case ANGLE:
+ y = Angle(scope_, x);
+ break;
case CONJ:
y = Conj(scope_, x);
break;
@@ -730,6 +733,21 @@ TEST_F(CWiseUnaryComplexGradTest, Imag) {
TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}
+TEST_F(CWiseUnaryComplexGradTest, Angle) {
+ Tensor x = test::AsTensor<complex64>(
+ {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
+ Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
+ Tensor dx_expected =
+ test::AsTensor<complex64>({{5.5, 5.5},
+ {3, 3},
+ {2.1666666666666665, 2.1666666666666665},
+ {1.75, 1.75},
+ {0.9375, 0.9375},
+ {0.8888888888888888, 0.8888888888888888}},
+ {2, 3});
+ TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
+}
+
TEST_F(CWiseUnaryComplexGradTest, Conj) {
Tensor x = test::AsTensor<complex64>(
{{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 167821a51a..ee703e040e 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -57,6 +57,7 @@ py_library(
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/predictor",
"//tensorflow/contrib/quantization:quantization_py",
+ "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py",
"//tensorflow/contrib/rnn:rnn_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index a0f8896f98..315ea943cf 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -53,6 +53,7 @@ from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import predictor
from tensorflow.contrib import quantization
+from tensorflow.contrib import reduce_slice_ops
from tensorflow.contrib import resampler
from tensorflow.contrib import rnn
from tensorflow.contrib import saved_model
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index d2a817390a..f7470d3bce 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -87,6 +87,7 @@ endfunction()
file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
+ "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto"
)
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
${tensorflow_source_dir} ${tf_protos_cc_srcs}
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index a86eb44d3f..1c7cb9476d 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -137,6 +137,7 @@ file(GLOB_RECURSE tf_core_kernels_exclude_srcs
"${tensorflow_source_dir}/tensorflow/core/kernels/*test_utils.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/*main.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc"
+ "${tensorflow_source_dir}/tensorflow/core/kernels/fuzzing/*"
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/*"
"${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform*.cc"
)
@@ -169,9 +170,10 @@ if(WIN32)
endif(WIN32)
file(GLOB_RECURSE tf_core_gpu_kernels_srcs
- "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc"
+ "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op_gpu.cu.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc"
)
if(WIN32 AND tensorflow_ENABLE_GPU)
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index a9770bc731..f9b5d31088 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -100,6 +100,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor
GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")
+GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc")
########################################################
# tf_user_ops library
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 170feb582b..704f0d09f3 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -142,7 +142,6 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/profiler/*.proto"
"${tensorflow_source_dir}/tensorflow/python/*.proto"
- "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
@@ -567,6 +566,12 @@ add_python_module("tensorflow/contrib/training")
add_python_module("tensorflow/contrib/training/python")
add_python_module("tensorflow/contrib/training/python/training")
add_python_module("tensorflow/contrib/util")
+add_python_module("tensorflow/contrib/reduce_slice_ops")
+add_python_module("tensorflow/contrib/reduce_slice_ops/kernels")
+add_python_module("tensorflow/contrib/reduce_slice_ops/ops")
+add_python_module("tensorflow/contrib/reduce_slice_ops/python")
+add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests")
+add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops")
########################################################
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 6852dee979..470fb15fd6 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util.deprecation import deprecated
__all__ = [
@@ -77,6 +78,11 @@ def reduce_sum_n(tensors, name=None):
return math_ops.add_n(tensors, name=name_scope)
+@deprecated(
+ None,
+ 'Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note '
+ 'that order of the inputs and ouputs of labels and predictions have also '
+ 'been switched.')
def remove_squeezable_dimensions(predictions, labels, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc
index ea4d3a10ee..1f9dd0decb 100644
--- a/tensorflow/contrib/gdr/gdr_server_lib.cc
+++ b/tensorflow/contrib/gdr/gdr_server_lib.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/gdr/gdr_server_lib.h"
+
+#include "grpc/support/alloc.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
#include "tensorflow/contrib/gdr/gdr_worker.h"
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
index 8623768d33..884faf8335 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py
@@ -31,8 +31,10 @@ def generator_input_fn(x,
num_epochs=1,
shuffle=True,
queue_capacity=1000,
- num_threads=1):
- """Returns input function that returns dicts yielded by a generator.
+ num_threads=1,
+ pad_value=None):
+ """Returns input function that returns dicts of numpy arrays
+ yielded from a generator.
It is assumed that every dict of numpy arrays yielded from the dictionary
represents a single sample. The generator should consume a single epoch of the
@@ -68,6 +70,7 @@ def generator_input_fn(x,
time.
queue_capacity: Integer, size of queue to accumulate.
num_threads: Integer, number of threads used for reading and enqueueing.
+ pad_value: default value for dynamic padding of data samples, if provided.
Returns:
Function, that returns a feature `dict` with `Tensors` and an optional
@@ -117,7 +120,8 @@ def generator_input_fn(x,
shuffle=shuffle,
num_threads=num_threads,
enqueue_size=batch_size,
- num_epochs=num_epochs)
+ num_epochs=num_epochs,
+ pad_value=pad_value)
features = (queue.dequeue_many(batch_size)
if num_epochs is None else queue.dequeue_up_to(batch_size))
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index b5d8c95678..463bd60300 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -22,12 +22,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.framework import deprecated
-from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@@ -35,6 +34,7 @@ from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.util.deprecation import deprecated
def _safe_div(numerator, denominator, name):
@@ -2445,8 +2445,8 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
Tuple of `predictions`, `labels` and `weights`, possibly with the last
dimension squeezed.
"""
- predictions, labels = tensor_util.remove_squeezable_dimensions(
- predictions, labels)
+ labels, predictions = confusion_matrix.remove_squeezable_dimensions(
+ labels, predictions)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
if weights is not None:
diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD
new file mode 100644
index 0000000000..5340a51e00
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/BUILD
@@ -0,0 +1,110 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_kernel_tests_linkstatic")
+
+tf_custom_op_library(
+ name = "python/ops/_reduce_slice_ops.so",
+ srcs = [
+ "kernels/reduce_slice_ops.cc",
+ "kernels/reduce_slice_ops.h",
+ "ops/reduce_slice_ops.cc",
+ ],
+ gpu_srcs = [
+ "kernels/reduce_slice_ops.h",
+ "kernels/reduce_slice_ops_gpu.cu.cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "reduce_slice_ops_kernels",
+ srcs = [
+ "kernels/reduce_slice_ops.cc",
+ ],
+ hdrs = [
+ "kernels/reduce_slice_ops.h",
+ ],
+ gpu_srcs = [
+ "kernels/reduce_slice_ops.h",
+ "kernels/reduce_slice_ops_gpu.cu.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["reduce_slice_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "reduce_slice_ops",
+ deps = [":reduce_slice_ops_op_lib"],
+)
+
+tf_custom_op_py_library(
+ name = "reduce_slice_ops_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/reduce_slice_ops.py",
+ ],
+ dso = [
+ ":python/ops/_reduce_slice_ops.so",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":reduce_slice_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework",
+ ],
+)
+
+cuda_py_test(
+ name = "reduce_slice_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/reduce_slice_ops_test.py"],
+ additional_deps = [
+ ":reduce_slice_ops_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_cc_test(
+ name = "reduce_slice_ops_test_cc",
+ size = "small",
+ srcs = [
+ "ops/reduce_slice_ops_test.cc",
+ ],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":reduce_slice_ops_op_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/reduce_slice_ops/__init__.py b/tensorflow/contrib/reduce_slice_ops/__init__.py
new file mode 100644
index 0000000000..d0364587b5
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""reduce by slice
+
+@@reduce_slice_sum
+@@reduce_slice_prod
+@@reduce_slice_min
+@@reduce_slice_max
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.reduce_slice_ops.python.ops import *
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
new file mode 100644
index 0000000000..2def4f3f17
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
@@ -0,0 +1,239 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
+#include <algorithm>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using thread::ThreadPool;
+
+namespace functor {
+
+#define CPUReduceSliceFunctorReduceop(reduceop, beginning) \
+ template <typename T, typename Index> \
+ struct ReduceSliceFunctor##reduceop<CPUDevice, T, Index> { \
+ private: \
+ struct XYZ { \
+ Index x, y, z; \
+ XYZ() = default; \
+ XYZ(Index x, Index y, Index z) : x(x), y(y), z(z) {} \
+ }; \
+ inline static XYZ global_index_to_xyz(Index global, XYZ size) { \
+ XYZ ret; \
+ ret.x = global / (size.y * size.z); \
+ ret.y = global % (size.y * size.z) / size.z; \
+ ret.z = global % size.z; \
+ return ret; \
+ } \
+ \
+ public: \
+ virtual ~ReduceSliceFunctor##reduceop() {} \
+ virtual void operator()(OpKernelContext* ctx, const CPUDevice& d, \
+ Index indices_width, \
+ typename TTypes<Index, 1>::ConstTensor indices, \
+ typename TTypes<T, 3>::ConstTensor data, \
+ typename TTypes<T, 3>::Tensor output) { \
+ Index bound = data.dimension(1); \
+ Index dim1 = output.dimension(0); \
+ Index dim2 = output.dimension(1); \
+ Index dim3 = output.dimension(2); \
+ Index size = dim1 * dim2 * dim3; \
+ if (size == 0) { \
+ return; \
+ } \
+ T zero = beginning<T>(); \
+ ThreadPool* thread_pool = \
+ ctx->device()->tensorflow_cpu_worker_threads()->workers; \
+ /* shard the work */ \
+ auto work = [&](Index start, Index end) { \
+ for (Index global = start; global < end; ++global) { \
+ XYZ xyz = global_index_to_xyz(global, XYZ(dim1, dim2, dim3)); \
+ Index x = xyz.x; \
+ Index y = xyz.y; \
+ Index z = xyz.z; \
+ output(x, y, z) = zero; \
+ Index slice_head = indices(y * indices_width); \
+ Index slice_end = std::min(indices(y * indices_width + 1), bound); \
+ for (Index i = slice_head; i < slice_end; ++i) { \
+ output(x, y, z) = reduceop(output(x, y, z), data(x, i, z)); \
+ } \
+ } \
+ }; \
+ /* Here assumes the number of average CPU cycles for each slice equals \
+ * the average length of each slice */ \
+ thread_pool->ParallelFor(size, std::max(bound / dim2, (Index)1), work); \
+ } \
+ };
+
+CALL_ALL_REDUCEOPS(CPUReduceSliceFunctorReduceop)
+#undef CPUReduceSliceFunctorReduceop
+
+#define DEFINE_CPU_SUMPROD_SPECS_INDEX(T, Index) \
+ template struct ReduceSliceFunctorSum<CPUDevice, T, Index>; \
+ template struct ReduceSliceFunctorProd<CPUDevice, T, Index>;
+
+#define DEFINE_CPU_MINMAX_SPECS_INDEX(T, Index) \
+ template struct ReduceSliceFunctorMax<CPUDevice, T, Index>; \
+ template struct ReduceSliceFunctorMin<CPUDevice, T, Index>;
+
+#define DEFINE_CPU_SUMPROD_SPECS(T) \
+ DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int32); \
+ DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int64);
+
+#define DEFINE_CPU_MINMAX_SPECS(T) \
+ DEFINE_CPU_MINMAX_SPECS_INDEX(T, int32); \
+ DEFINE_CPU_MINMAX_SPECS_INDEX(T, int64);
+
+TF_CALL_NUMBER_TYPES(DEFINE_CPU_SUMPROD_SPECS)
+TF_CALL_REAL_NUMBER_TYPES(DEFINE_CPU_MINMAX_SPECS)
+
+#undef DEFINE_CPU_SUMPROD_SPECS_INDEX
+#undef DEFINE_CPU_MINMAX_SPECS_INDEX
+#undef DEFINE_CPU_SUMPROD_SPECS
+#undef DEFINE_CPU_MINMAX_SPECS
+
+} // namespace functor
+
+template <typename Device, typename T, typename Index,
+ template <typename Device2, typename T2, typename Index2>
+ class Functor>
+class ReduceSliceKernel : public OpKernel {
+ public:
+ explicit ReduceSliceKernel(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& data = context->input(0);
+ const Tensor& indices = context->input(1);
+ const Tensor& _axis = context->input(2);
+ int64 axis = _axis.scalar<int64>()();
+
+ int indices_width = 2;
+ int out_axis_dim_size = indices.shape().dim_size(0);
+ if (indices.dims() == 1 || indices.shape().dim_size(1) == 1) {
+ indices_width = 1;
+ if (out_axis_dim_size > 0) {
+ out_axis_dim_size--;
+ }
+ }
+
+ TensorShape output_shape = data.shape();
+ output_shape.set_dim(axis, out_axis_dim_size);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ auto functor = Functor<Device, T, Index>();
+ functor(context, context->eigen_device<Device>(), indices_width,
+ indices.flat<Index>(), data.flat_inner_outer_dims<T, 3>(axis - 1),
+ output->flat_inner_outer_dims<T, 3>(axis - 1));
+ }
+};
+
+#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorSum>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorProd>);
+
+#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMax>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMin>);
+
+#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL(type) \
+ REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int32); \
+ REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int64);
+
+#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL(type) \
+ REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int32); \
+ REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL)
+TF_CALL_NUMBER_TYPES(REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL)
+
+#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS
+#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS
+#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL
+#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_REDUCE_SLICE_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorSum>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorProd>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMax>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMin>);
+
+#define REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL(type) \
+ REGISTER_GPU_REDUCE_SLICE_KERNELS(type, int32); \
+ REGISTER_GPU_REDUCE_SLICE_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL);
+
+#undef REGISTER_GPU_REDUCE_SLICE_KERNELS
+#undef REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
new file mode 100644
index 0000000000..c62a7b20d6
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#define Sum(a, b) ((a) + (b))
+#define Prod(a, b) ((a) * (b))
+#define Max(a, b) ((a) > (b) ? (a) : (b))
+#define Min(a, b) ((a) < (b) ? (a) : (b))
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+namespace reduce_functions {
+
+template <typename T>
+inline T zero() {
+ return T(0);
+}
+
+template <typename T>
+inline T one() {
+ return T(1);
+}
+
+template <typename T>
+inline T infinity() {
+ return std::max<T>(std::numeric_limits<T>::max(),
+ std::numeric_limits<T>::infinity());
+}
+
+template <typename T>
+inline T negative_infinity() {
+ return std::min<T>(-std::numeric_limits<T>::infinity(),
+ std::numeric_limits<T>::min());
+}
+
+} // namespace reduce_functions
+
+#define CALL_ALL_REDUCEOPS(func, ...) \
+ func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \
+ func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) func( \
+ Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \
+ func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__)
+
+#define ReduceSliceFunctorReduceop(reduceop, dummy) \
+ template <typename Device, typename T, typename Index> \
+ struct ReduceSliceFunctor##reduceop { \
+ virtual ~ReduceSliceFunctor##reduceop() {} \
+ virtual void operator()(OpKernelContext* ctx, const Device& d, \
+ Index indices_width, \
+ typename TTypes<Index, 1>::ConstTensor indices, \
+ typename TTypes<T, 3>::ConstTensor data, \
+ typename TTypes<T, 3>::Tensor output); \
+ };
+
+CALL_ALL_REDUCEOPS(ReduceSliceFunctorReduceop)
+#undef ReduceSliceFunctorReduceop
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
new file mode 100644
index 0000000000..8b205f7dd5
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+
+#define GPUReduceSliceFunctorReduceop(reduceop, beginning) \
+ template <typename T, typename Index> \
+ __global__ void ReduceSliceDeviceKernel##reduceop( \
+ Cuda3DLaunchConfig config, Index indices_width, Index bound, \
+ const T begin, const Index *indices, const T *input, T *out) { \
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
+ Index outidx = x * config.virtual_thread_count.y * \
+ config.virtual_thread_count.z + \
+ y * config.virtual_thread_count.z + z; \
+ out[outidx] = begin; \
+ Index start = indices[y * indices_width]; \
+ Index end = Min(bound, indices[y * indices_width + 1]); \
+ for (Index yin = start; yin < end; yin++) { \
+ Index inidx = x * bound * config.virtual_thread_count.z + \
+ yin * config.virtual_thread_count.z + z; \
+ out[outidx] = reduceop(out[outidx], input[inidx]); \
+ } \
+ } \
+ } \
+ } \
+ } \
+ \
+ template <typename T, typename Index> \
+ struct ReduceSliceFunctor##reduceop<GPUDevice, T, Index> { \
+ virtual ~ReduceSliceFunctor##reduceop() {} \
+ virtual void operator()(OpKernelContext *ctx, const GPUDevice &d, \
+ Index indices_width, \
+ typename TTypes<Index, 1>::ConstTensor indices, \
+ typename TTypes<T, 3>::ConstTensor data, \
+ typename TTypes<T, 3>::Tensor output) { \
+ Index bound = data.dimension(1); \
+ int sizex = output.dimension(0); \
+ int sizey = output.dimension(1); \
+ int sizez = output.dimension(2); \
+ if (sizex * sizey * sizez == 0) { \
+ return; \
+ } \
+ Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
+ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
+ 0, 0); \
+ \
+ ReduceSliceDeviceKernel##reduceop<T, Index> \
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
+ config, indices_width, bound, beginning<T>(), indices.data(), \
+ data.data(), output.data()); \
+ } \
+ };
+
+CALL_ALL_REDUCEOPS(GPUReduceSliceFunctorReduceop)
+#undef GPUReduceSliceFunctorReduceop
+
+#define DEFINE_GPU_REDUCEOP_SPECS_INDEX(reduceop, dummy, T) \
+ template struct ReduceSliceFunctor##reduceop<GPUDevice, T, int32>; \
+ template struct ReduceSliceFunctor##reduceop<GPUDevice, T, int64>;
+
+#define DEFINE_GPU_SPECS(T) \
+ CALL_ALL_REDUCEOPS(DEFINE_GPU_REDUCEOP_SPECS_INDEX, T)
+
+TF_CALL_REAL_NUMBER_TYPES(DEFINE_GPU_SPECS)
+
+#undef DEFINE_GPU_REDUCEOP_SPECS_INDEX
+#undef DEFINE_GPU_SPECS
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc
new file mode 100644
index 0000000000..b8b56c0e22
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc
@@ -0,0 +1,282 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+namespace {
+
+Status ReduceSliceShapeFn(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle dimhandle;
+ DimensionHandle dim_axis = c->UnknownDim();
+ // "axis" must be a scala
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &handle));
+ // "data" must have rank at least 1
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &handle));
+ // "indices" must have have rank 1 or rank 2 with the number of columns must
+ // be 2
+ if (c->RankKnown(c->input(1))) {
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &handle));
+ if (c->Rank(c->input(1)) == 1) {
+ // if "indices" is a vector of 0 elements, then the axis dimension of
+ // output tensor should be of dimension 0.
+ DimensionHandle raw_dim_axis;
+ TF_RETURN_IF_ERROR(c->Max(c->Dim(c->input(1), 0), 1, &raw_dim_axis));
+ TF_RETURN_IF_ERROR(c->Subtract(raw_dim_axis, 1, &dim_axis));
+ } else { // c->Rank(c->input(1)) == 2
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(c->input(1), 1), c->MakeDim(2), &dimhandle));
+ dim_axis = c->Dim(c->input(1), 0);
+ }
+ }
+ // shape of output tensor
+ const Tensor* _axis = c->input_tensor(2);
+ if (nullptr == _axis) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
+ } else {
+ int64 axis = _axis->scalar<int64>()();
+ TF_RETURN_IF_ERROR(c->ReplaceDim(handle, axis, dim_axis, &handle));
+ c->set_output(0, handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ReduceSliceSum")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically sum over the first dimension of a tensor according to start and end
+indices specified at 'index'.
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 2, 3]
+ [ 40, 50, 60]
+ [ 700, 800, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 2, 3]
+ [ 0, 0, 0]
+ [41,52,63]].
+```
+
+The data must be at least rank 1. The indices must be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be zero. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound.
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed sum values.
+)doc");
+
+REGISTER_OP("ReduceSliceProd")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically compute the product over the first dimension of a tensor according
+to start and end indices specified at 'indices'.
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 2, 3]
+ [ 40, 50, 60]
+ [ 700, 800, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 2, 3]
+ [ 1, 1, 1]
+ [40,100,180]].
+```
+
+The data must be at least rank 1. The indices can be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be 1. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound. The indices can also be of shape (?), in this case, the
+start index of i will be the element at i, then end index of i will be the element
+at i+1. That is:
+
+```prettyprint
+indices = [0,5,11,115]
+
+is equivalent to
+
+indices = [ [0,5],
+ [5,11],
+ [11,115]]
+```
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed product values.
+)doc");
+
+REGISTER_OP("ReduceSliceMax")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically compute the maximum over the first dimension of a tensor according
+to start and end indices specified at "indices".
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 20, 3]
+ [ 400, 5, 60]
+ [ 70, 8, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 20, 3]
+ [ -BIG_VALUE, -BIG_VALUE, -BIG_VALUE]
+ [ 400, 20, 60]].
+```
+
+The data must be at least rank 1. The indices can be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be 1. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound. The indices can also be of shape (?), in this case, the
+start index of i will be the element at i, then end index of i will be the element
+at i+1. That is:
+
+```prettyprint
+indices = [0,5,11,115]
+
+is equivalent to
+
+indices = [ [0,5],
+ [5,11],
+ [11,115]]
+```
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed product values.
+)doc");
+
+REGISTER_OP("ReduceSliceMin")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically compute the minimum over the first dimension of a tensor according
+to start and end indices specified at 'indices'.
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 20, 3]
+ [ 400, 5, 60]
+ [ 70, 8, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 20, 3]
+ [ +BIG_VALUE, +BIG_VALUE, +BIG_VALUE]
+ [ 1, 5, 3]].
+```
+
+The data must be at least rank 1. The indices can be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be 1. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound. The indices can also be of shape (?), in this case, the
+start index of i will be the element at i, then end index of i will be the element
+at i+1. That is:
+
+```prettyprint
+indices = [0,5,11,115]
+
+is equivalent to
+
+indices = [ [0,5],
+ [5,11],
+ [11,115]]
+```
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed product values.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc
new file mode 100644
index 0000000000..777ad9bf15
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ReduceSliceOpsTest, ReduceSliceSum_ShapeFn) {
+ ShapeInferenceTestOp op("ReduceSliceSum");
+ INFER_OK(op, "?;?;?", "?");
+ INFER_OK(op, "[10,20];[100,2];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[?,2];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[0];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[1];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[?];[]", "[?,?]");
+ INFER_OK(op, "[?,?];[?,2];[]", "[?,?]");
+ INFER_OK(op, "[?,?];[25,2];[]", "[?,?]");
+ INFER_OK(op, "[?];[123,2];[]", "[?]");
+ INFER_OK(op, "[1,2,3,4];[100,2];[]", "[?,?,?,?]");
+
+ INFER_ERROR("must be rank 0", op, "?;[?,2];[?]");
+ INFER_ERROR("must be at least rank 1", op, "?;[];[]");
+ INFER_ERROR("must be at most rank 2", op, "?;[1,2,3];[]");
+ INFER_ERROR("must be equal, but are 1 and 2", op, "?;[?,1];[]");
+ INFER_ERROR("must be at least rank 1", op, "[];?;[]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
new file mode 100644
index 0000000000..8c8db295ff
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
@@ -0,0 +1,158 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.contrib.reduce_slice_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import unittest
+
+from tensorflow.contrib.reduce_slice_ops.python.ops import reduce_slice_ops
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.platform import googletest
+
+
+class ReduceSliceTest(TensorFlowTestCase):
+
+ def testReduceSliceSum1D(self):
+ x = np.array([1, 40, 700], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array([1, 741, 40, 740, 41], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSum2D(self):
+ x = np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [741, 852, 963], [40, 50, 60], [740, 850, 960],
+ [41, 52, 63]],
+ dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSum3D(self):
+ x = np.array(
+ [[[1, 2], [3, 4]], [[50, 60], [70, 80]], [[600, 700], [800, 900]]],
+ dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[[1, 2], [3, 4]], [[651, 762], [873, 984]], [[50, 60], [70, 80]],
+ [[650, 760], [870, 980]], [[51, 62], [73, 84]]],
+ dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSumAxis1(self):
+ x = np.transpose(
+ np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900]], dtype=np.int32))
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.transpose(
+ np.array(
+ [[1, 2, 3], [741, 852, 963], [40, 50, 60], [740, 850, 960],
+ [41, 52, 63]],
+ dtype=np.int32))
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 1).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSum1DIndices(self):
+ x = np.array(
+ [[1, 2, 3], [40, 50, 60], [700, 800, 900], [1000, 2000, 3000],
+ [40000, 50000, 60000]],
+ dtype=np.int32)
+ indices = np.array([0, 0, 2, 5], dtype=np.int32)
+ result = np.array(
+ [[0, 0, 0], [41, 52, 63], [41700, 52800, 63900]], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceProd(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [28, 80, 162], [4, 5, 6], [28, 40, 54], [4, 10, 18]],
+ dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_prod(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceMax(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9], [4, 5, 6]], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_max(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceMin(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6], [1, 2, 3]], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_min(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmptyDataRows(self):
+ x = np.empty((0, 1, 2, 3, 4, 5, 6), dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.zeros((5, 1, 2, 3, 4, 5, 6), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmptyDataCols(self):
+ x = np.empty((100, 0, 2, 3, 4, 5, 6), dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.empty((5, 0, 2, 3, 4, 5, 6), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmptyIndicesRows(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.empty((0, 2), dtype=np.int32)
+ result = np.empty((0, 3), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmpty0Indices1D(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.empty((0,), dtype=np.int32)
+ result = np.empty((0, 3), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmpty1Indices1D(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([0], dtype=np.int32)
+ result = np.empty((0, 3), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py b/tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py
new file mode 100644
index 0000000000..d0f02489bd
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for the reduce slice operators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_reduce_slice_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_reduce_slice_ops.so"))
+
+reduce_slice_sum = _reduce_slice_ops.reduce_slice_sum
+reduce_slice_prod = _reduce_slice_ops.reduce_slice_prod
+reduce_slice_max = _reduce_slice_ops.reduce_slice_max
+reduce_slice_min = _reduce_slice_ops.reduce_slice_min
diff --git a/tensorflow/contrib/session_bundle/example/export_half_plus_two.py b/tensorflow/contrib/session_bundle/example/export_half_plus_two.py
index 4a56509e59..83e91d390f 100644
--- a/tensorflow/contrib/session_bundle/example/export_half_plus_two.py
+++ b/tensorflow/contrib/session_bundle/example/export_half_plus_two.py
@@ -152,11 +152,10 @@ if __name__ == "__main__":
)
parser.add_argument(
"--use_checkpoint_v2",
- "bool",
+ type="bool",
nargs="?",
const=True,
default=False,
- help="If true, write v2 checkpoint files."
- )
+ help="If true, write v2 checkpoint files.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
index e97f500572..ad5e985487 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
@@ -127,6 +127,36 @@ class ParallelReader(io_ops.ReaderBase):
The next record (i.e. (key, value pair)) from the common_queue.
"""
+ self._configure_readers_by(queue)
+ return self._common_queue.dequeue(name=name)
+
+ def read_up_to(self, queue, num_records, name=None):
+ """Returns up to num_records (key, value pairs) produced by a reader.
+
+ Will dequeue a work unit from queue if necessary (e.g., when the
+ Reader needs to start reading from a new file since it has
+ finished with the previous file).
+ It may return less than num_records even before the last batch.
+
+ **Note** This operation is not supported by all types of `common_queue`s.
+ If a `common_queue` does not support `dequeue_up_to()`, then a
+ `tf.errors.UnimplementedError` is raised.
+
+ Args:
+ queue: A Queue or a mutable string Tensor representing a handle
+ to a Queue, with string work items.
+ num_records: Number of records to read.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tuple of Tensors (keys, values) from common_queue.
+ keys: A 1-D string Tensor.
+ values: A 1-D string Tensor.
+ """
+ self._configure_readers_by(queue)
+ return self._common_queue.dequeue_up_to(num_records, name)
+
+ def _configure_readers_by(self, queue):
enqueue_ops = []
for reader in self._readers:
enqueue_ops.append(self._common_queue.enqueue(reader.read(queue)))
@@ -134,8 +164,6 @@ class ParallelReader(io_ops.ReaderBase):
queue_runner.add_queue_runner(
queue_runner.QueueRunner(self._common_queue, enqueue_ops))
- return self._common_queue.dequeue(name=name)
-
def num_records_produced(self, name=None):
"""Returns the number of records this reader has produced.
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
index a46e4b00f9..10ea883e1f 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.slim.python.slim.data import test_utils
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
@@ -74,6 +75,54 @@ class ParallelReaderTest(test.TestCase):
self.assertGreater(count2, 0)
self.assertEquals(count0 + count1 + count2, num_reads)
+ def _verify_read_up_to_out(self, shared_queue):
+ with self.test_session():
+ num_files = 3
+ num_records_per_file = 7
+ tfrecord_paths = test_utils.create_tfrecord_files(
+ self.get_temp_dir(),
+ num_files=num_files,
+ num_records_per_file=num_records_per_file)
+
+ p_reader = parallel_reader.ParallelReader(
+ io_ops.TFRecordReader, shared_queue, num_readers=5)
+
+ data_files = parallel_reader.get_data_files(tfrecord_paths)
+ filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
+ key, value = p_reader.read_up_to(filename_queue, 4)
+
+ count0 = 0
+ count1 = 0
+ count2 = 0
+ all_keys_count = 0
+ all_values_count = 0
+
+ sv = supervisor.Supervisor(logdir=self.get_temp_dir())
+ with sv.prepare_or_wait_for_session() as sess:
+ sv.start_queue_runners(sess)
+ while True:
+ try:
+ current_keys, current_values = sess.run([key, value])
+ self.assertEquals(len(current_keys), len(current_values))
+ all_keys_count += len(current_keys)
+ all_values_count += len(current_values)
+ for current_key in current_keys:
+ if '0-of-3' in str(current_key):
+ count0 += 1
+ if '1-of-3' in str(current_key):
+ count1 += 1
+ if '2-of-3' in str(current_key):
+ count2 += 1
+ except errors_impl.OutOfRangeError:
+ break
+
+ self.assertEquals(count0, num_records_per_file)
+ self.assertEquals(count1, num_records_per_file)
+ self.assertEquals(count2, num_records_per_file)
+ self.assertEquals(all_keys_count, num_files * num_records_per_file)
+ self.assertEquals(all_values_count, all_keys_count)
+ self.assertEquals(count0 + count1 + count2, all_keys_count)
+
def testRandomShuffleQueue(self):
shared_queue = data_flow_ops.RandomShuffleQueue(
capacity=256,
@@ -86,6 +135,23 @@ class ParallelReaderTest(test.TestCase):
capacity=256, dtypes=[dtypes_lib.string, dtypes_lib.string])
self._verify_all_data_sources_read(shared_queue)
+ def testReadUpToFromRandomShuffleQueue(self):
+ shared_queue = data_flow_ops.RandomShuffleQueue(
+ capacity=55,
+ min_after_dequeue=28,
+ dtypes=[dtypes_lib.string, dtypes_lib.string],
+ shapes=[tensor_shape.scalar(),
+ tensor_shape.scalar()])
+ self._verify_read_up_to_out(shared_queue)
+
+ def testReadUpToFromFIFOQueue(self):
+ shared_queue = data_flow_ops.FIFOQueue(
+ capacity=99,
+ dtypes=[dtypes_lib.string, dtypes_lib.string],
+ shapes=[tensor_shape.scalar(),
+ tensor_shape.scalar()])
+ self._verify_read_up_to_out(shared_queue)
+
class ParallelReadTest(test.TestCase):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1f7eb87f18..6b51d3bc80 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2192,6 +2192,7 @@ tf_cc_test(
srcs = ["platform/setround_test.cc"],
tags = [
"noasan",
+ "noclang",
"nomsan",
"notsan",
],
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index 39dd5b435e..1a5fd10f52 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -39,6 +39,14 @@ const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";
+const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
+#if GOOGLE_CUDA
+const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
+#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
+#endif // TENSORFLOW_USE_SYCL
+
string DataTypeString(DataType dtype) {
if (IsRefType(dtype)) {
DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 9127750d68..3b4362bcc9 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -74,6 +74,28 @@ TF_EXPORT extern const char* const DEVICE_CPU; // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU; // "GPU"
TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL"
+template <typename Device>
+struct DeviceName {};
+
+template <>
+struct DeviceName<Eigen::ThreadPoolDevice> {
+ static const std::string value;
+};
+
+#if GOOGLE_CUDA
+template <>
+struct DeviceName<Eigen::GpuDevice> {
+ static const std::string value;
+};
+#endif // GOOGLE_CUDA
+
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+struct DeviceName<Eigen::SyclDevice> {
+ static const std::string value;
+};
+#endif // TENSORFLOW_USE_SYCL
+
typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;
diff --git a/tensorflow/core/kernels/cwise_op_arg.cc b/tensorflow/core/kernels/cwise_op_arg.cc
new file mode 100644
index 0000000000..62ffa0718f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_arg.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+#define REGISTER_COMPLEX(D, R, C) \
+ REGISTER_KERNEL_BUILDER(Name("Angle") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<C>("T") \
+ .TypeConstraint<R>("Tout"), \
+ UnaryOp<D##Device, functor::get_angle<C>>);
+
+REGISTER_COMPLEX(CPU, float, complex64);
+REGISTER_COMPLEX(CPU, double, complex128);
+
+// TODO: Enable GPU support for angle op after resolving
+// build failures on GPU (See #10643 for context).
+#if 0 && GOOGLE_CUDA
+REGISTER_COMPLEX(GPU, float, complex64);
+REGISTER_COMPLEX(GPU, double, complex128);
+#endif
+
+#undef REGISTER_COMPLEX
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc
new file mode 100644
index 0000000000..9b3f8200bd
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc
@@ -0,0 +1,28 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO: Enable GPU support for angle op after resolving
+// build failures on GPU (See #10643 for context).
+#if 0 && GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(get_angle, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 65a60720dd..d935331904 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -831,6 +831,10 @@ struct get_imag
: base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
template <typename T>
+struct get_angle
+ : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
+
+template <typename T>
struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index d7b804daeb..9efbd66ef7 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -28,9 +28,13 @@ namespace tensorflow {
namespace {
-std::vector<string> Split(const string& str, const string& delimiter) {
+std::vector<string> Split(const string& str, const string& delimiter,
+ const bool skipEmpty) {
if (!delimiter.empty()) {
- return str_util::Split(str, delimiter, str_util::SkipEmpty());
+ if (skipEmpty) {
+ return str_util::Split(str, delimiter, str_util::SkipEmpty());
+ }
+ return str_util::Split(str, delimiter);
}
std::vector<string> char_vector(str.size());
for (size_t i = 0; i < str.size(); ++i) {
@@ -43,7 +47,15 @@ std::vector<string> Split(const string& str, const string& delimiter) {
class StringSplitOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit StringSplitOp(OpKernelConstruction* context)
+ : OpKernel(context), skip_empty_(true) {
+ bool skip_empty;
+ // By default skip_empty_ is true. We only get the value from attr if it is
+ // available, so that it is backward compatible.
+ if (context->GetAttr("skip_empty", &skip_empty).ok()) {
+ skip_empty_ = skip_empty;
+ }
+ }
void Compute(OpKernelContext* ctx) override {
const Tensor* input_tensor;
@@ -73,7 +85,7 @@ class StringSplitOp : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = Split(input_vec(i), delimiter);
+ std::vector<string> parts = Split(input_vec(i), delimiter, skip_empty_);
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
@@ -105,6 +117,9 @@ class StringSplitOp : public OpKernel {
}
}
}
+
+ private:
+ bool skip_empty_;
};
REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp);
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 5e082ce8f5..1290d3103e 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -349,6 +349,20 @@ Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);
+Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"re"}, "Real", {"x"}},
+ {{"im"}, "Imag", {"x"}},
+ {{"z"}, "Complex", {"im", "re"}},
+ {{"z_inv"}, "Reciprocal", {"z"}},
+ {{"neg"}, "Neg", {"z_inv"}},
+ {{"dx"}, "Mul", {"neg", "dy"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Angle", AngleGrad);
+
Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 1393bffb91..2b4b35547b 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -612,6 +612,7 @@ TEST_F(MathGradTest, Cos) {
// TODO(zhifengc)
// TEST_F(MathGradSComplexTest, Real) {}
// TEST_F(MathGradSComplexTest, Imag) {}
+// TEST_F(MathGradSComplexTest, Angle) {}
// TEST_F(MathGradSComplexTest, Conj) {}
// TEST_F(MathGradTernary, Select) {}
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 36f999ff60..6ff05bd2a6 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -2065,6 +2065,34 @@ tf.imag(input) ==> [4.75, 5.75]
```
)doc");
+REGISTER_OP("Angle")
+ .Input("input: T")
+ .Output("output: Tout")
+ .Attr("T: {complex64, complex128} = DT_COMPLEX64")
+ .Attr("Tout: {float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Returns the argument of a complex number.
+
+Given a tensor `input` of complex numbers, this operation returns a tensor of
+type `float` that is the argument of each element in `input`. All elements in
+`input` must be complex numbers of the form \\(a + bj\\), where *a*
+is the real part and *b* is the imaginary part.
+
+The argument returned by this operation is of the form \\(atan2(b, a)\\).
+
+For example:
+
+```
+# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.angle(input) ==> [2.0132, 1.056]
+```
+
+@compatibility(numpy)
+Equivalent to np.angle.
+@end_compatibility
+)doc");
+
REGISTER_OP("Conj")
.Input("input: T")
.Output("output: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b68d05a8e8..5419301992 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -593,6 +593,45 @@ op {
is_stateful: true
}
op {
+ name: "Angle"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "Tout"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_COMPLEX64
+ }
+ allowed_values {
+ list {
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ attr {
+ name: "Tout"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ summary: "Returns the argument of a complex number."
+ description: "Given a tensor `input` of complex numbers, this operation returns a tensor of\ntype `float` that is the argument of each element in `input`. All elements in\n`input` must be complex numbers of the form \\(a + bj\\), where *a*\nis the real part and *b* is the imaginary part.\n\nThe argument returned by this operation is of the form \\(atan2(b, a)\\).\n\nFor example:\n```\n # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]\ntf.angle(input) ==> [2.0132, 1.056]\n```"
+}
+op {
name: "Any"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index d6fe8ec342..5e99187d50 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -202,6 +202,7 @@ REGISTER_OP("StringSplit")
.Output("indices: int64")
.Output("values: string")
.Output("shape: int64")
+ .Attr("skip_empty: bool = true")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
@@ -238,6 +239,7 @@ For example:
input: 1-D. Strings to split.
delimiter: 0-D. Delimiter characters (bytes), or empty string.
+skip_empty: A `bool`. If `True`, skip the empty strings from the result.
indices: A dense matrix of int64 representing the indices of the sparse tensor.
values: A vector of strings corresponding to the splited values.
shape: a length-2 vector of int64 representing the shape of the sparse
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 2fefa67d7d..ccb861c93a 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -24,7 +24,7 @@ limitations under the License.
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX "-rc2"
+#define TF_VERSION_SUFFIX ""
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/docs_src/api_guides/python/math_ops.md b/tensorflow/docs_src/api_guides/python/math_ops.md
index b3c7a0c010..dee7f1618a 100644
--- a/tensorflow/docs_src/api_guides/python/math_ops.md
+++ b/tensorflow/docs_src/api_guides/python/math_ops.md
@@ -122,6 +122,7 @@ functions to your graph.
* @{tf.complex}
* @{tf.conj}
* @{tf.imag}
+* @{tf.angle}
* @{tf.real}
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index ec37311624..7ebf5c4a2c 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -35,7 +35,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for Mac OS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.3.0-rc2.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.3.0.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index b7dc033efc..b991fd0f93 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -35,7 +35,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.3.0-rc2.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.3.0.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index f9b7b322ca..2adcd4da73 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -34,7 +34,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.3.0-rc2</version>
+ <version>1.3.0</version>
</dependency>
```
@@ -63,7 +63,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.3.0-rc2</version>
+ <version>1.3.0</version>
</dependency>
</dependencies>
</project>
@@ -122,7 +122,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or Mac OS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.3.0-rc2.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.3.0.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -141,7 +141,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.3.0-rc2.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.3.0.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -149,10 +149,10 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.3.0-rc2.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.3.0.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.3.0-rc2.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.3.0.zip).
3. Extract this .zip file.
@@ -200,7 +200,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.3.0-rc2.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.3.0.jar HelloTF.java</b></pre>
### Running
@@ -214,11 +214,11 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and Mac OS X:
-<pre><b>java -cp libtensorflow-1.3.0-rc2.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.3.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.3.0-rc2.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.3.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 85182cc74f..a961c6b1ce 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -172,7 +172,7 @@ Take the following steps to install TensorFlow with Virtualenv:
virtualenv environment:
<pre>(tensorflow)$ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common_installation_problems).
@@ -277,7 +277,7 @@ take the following steps:
<pre>
$ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl</b>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl</b>
</pre>
If this step fails, see
@@ -464,7 +464,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -632,14 +632,14 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0rc2-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -651,14 +651,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0rc2-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -670,14 +670,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0rc2-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl
</pre>
@@ -689,14 +689,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0rc2-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.3.0-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0rc2-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.3.0-cp36-cp36m-linux_x86_64.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 733ecc37fb..6bae3c03d1 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -109,7 +109,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0rc2-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -230,7 +230,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0rc2-py2-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
@@ -339,7 +339,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0rc2-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -512,7 +512,7 @@ This section documents the relevant values for Mac OS installations.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0rc2-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py2-none-any.whl
</pre>
@@ -520,7 +520,7 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0rc2-py2-none-a
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0rc2-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.3.0-py3-none-any.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index a69f982d76..810a27d1ab 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -343,10 +343,10 @@ Invoke `pip install` to install that pip package.
The filename of the `.whl` file depends on your platform.
For example, the following command will install the pip package
-for TensorFlow 1.3.0rc2 on Linux:
+for TensorFlow 1.3.0 on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.3.0rc2-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.3.0-py2-none-any.whl</b>
</pre>
## Validate your installation
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index a9d7dd955a..a06ab88046 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -115,12 +115,12 @@ Take the following steps to install TensorFlow in an Anaconda environment:
environment. To install the CPU-only version of TensorFlow, enter the
following command:
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.3.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
+ <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.3.0-cp35-cp35m-win_amd64.whl</b> </pre>
To install the GPU version of TensorFlow, enter the following command
(on a single line):
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.3.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
+ <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.3.0-cp35-cp35m-win_amd64.whl</b> </pre>
## Validate your installation
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index e464d98a15..4dcd88befa 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -9981,6 +9981,52 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
+// ArgAttr is an optional argument to Arg.
+type ArgAttr func(optionalAttr)
+
+// ArgTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_FLOAT
+func ArgTout(value tf.DataType) ArgAttr {
+ return func(m optionalAttr) {
+ m["Tout"] = value
+ }
+}
+
+// Returns the argument of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// type `float` that is the argument of each element in `input`. All elements in
+// `input` must be complex numbers of the form \\(a + bj\\), where *a*
+// is the real part and *b* is the imaginary part.
+//
+//
+// The argument returned by this operation is of the form \\(atan2(b, a)\\).
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.angle(input) ==> [2.0132, 1.056]
+// ```
+func Angle(scope *Scope, input tf.Output, optional ...ArgAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Angle",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes fingerprints of the input strings.
//
// Arguments:
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 42fdd72053..40589528f2 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4038,6 +4038,7 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
],
+ tags = ["manual"],
)
py_library(
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index b4f0fd6f40..32c738f0f1 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1701,43 +1701,6 @@ class SessionTest(test_util.TensorFlowTestCase):
server = server_lib.Server.create_local_server()
self.runTestBuildGraphError(session.Session(server.target))
- def testGraphOptimizer(self):
- rewrite_options = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=False, constant_folding=True)
- graph_options = config_pb2.GraphOptions(
- rewrite_options=rewrite_options, build_cost_model=1)
- config = config_pb2.ConfigProto(graph_options=graph_options)
-
- with ops.Graph().as_default() as g:
- r1 = random_ops.random_normal(shape=[2, 3], name='R1')
- r2 = random_ops.random_normal(shape=[2, 3], name='R2')
- copy1 = array_ops.stop_gradient(r1)
- copy2 = array_ops.identity(r2)
- result = copy1 + copy2
-
- with session.Session(graph=g, config=config) as sess:
- metadata = config_pb2.RunMetadata()
- sess.run(result, run_metadata=metadata)
-
- # Check that we optimized the graph by looking at the cost model: the add
- # node should have been reconnected directly to the R1 and R2 nodes.
- found_valid_nodes = 0
- for node in metadata.cost_graph.node:
- if node.name == 'R1':
- r1_cost_id = node.id
- found_valid_nodes += 1
- if node.name == 'R2':
- r2_cost_id = node.id
- found_valid_nodes += 1
- if node.name == 'add':
- if node.input_info[0].preceding_node == r1_cost_id:
- self.assertEqual(node.input_info[1].preceding_node, r2_cost_id)
- found_valid_nodes += 1
- elif node.input_info[0].preceding_node == r2_cost_id:
- self.assertEqual(node.input_info[1].preceding_node, r1_cost_id)
- found_valid_nodes += 1
- self.assertEqual(3, found_valid_nodes)
-
def testDeviceAttributes(self):
attrs = session._DeviceAttributes(
'/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 847b27b904..149425436a 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -46,6 +46,64 @@ except ImportError:
HAS_PANDAS = False
+def _fill_array(arr, seq, fillvalue=0):
+ """Recursively fills padded arr with elements from seq.
+
+ If lenght of seq is less then arr padded length, fillvalue used.
+ Args:
+ arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
+ seq: Non-padded list of data sampels of shape
+ [batch_size, ..., padded_dim(None)]
+ fillvalue: Default fillvalue to use.
+ """
+ if arr.ndim == 1:
+ try:
+ len_ = len(seq)
+ except TypeError:
+ len_ = 0
+ arr[:len_] = seq
+ arr[len_:] = fillvalue
+ else:
+ for subarr, subseq in six.moves.zip_longest(arr, seq, fillvalue=()):
+ _fill_array(subarr, subseq, fillvalue)
+
+
+def _pad_if_needed(batch_key_item, fillvalue=0):
+ """ Returns padded batch.
+
+ Args:
+ batch_key_item: List of data samples of any type with shape
+ [batch_size, ..., padded_dim(None)].
+ fillvalue: Default fillvalue to use.
+
+ Returns:
+ Padded with zeros tensor of same type and shape
+ [batch_size, ..., max_padded_dim_len].
+
+ Raises:
+ ValueError if data samples have different shapes (except last padded dim).
+ """
+ shapes = [
+ seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item
+ ]
+ if not all(shapes[0] == x for x in shapes):
+ raise ValueError("Array shapes must match.")
+
+ last_length = [
+ seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item
+ ]
+ if all([x == last_length[0] for x in last_length]):
+ return batch_key_item
+
+ batch_size = len(batch_key_item)
+ max_sequence_length = max(last_length)
+ result_batch = np.zeros(
+ shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
+ dtype=batch_key_item[0].dtype)
+ _fill_array(result_batch, batch_key_item, fillvalue)
+ return result_batch
+
+
def _get_integer_indices_for_next_batch(
batch_indices_start, batch_size, epoch_end, array_length,
current_epoch, total_epochs):
@@ -229,7 +287,8 @@ class _GeneratorFeedFn(object):
batch_size,
random_start=False,
seed=None,
- num_epochs=None):
+ num_epochs=None,
+ pad_value=None):
first_sample = next(generator())
if len(placeholders) != len(first_sample):
raise ValueError("Expected {} placeholders; got {}.".format(
@@ -241,6 +300,7 @@ class _GeneratorFeedFn(object):
self._batch_size = batch_size
self._num_epochs = num_epochs
self._epoch = 0
+ self._pad_value = pad_value
random.seed(seed)
def __call__(self):
@@ -264,7 +324,17 @@ class _GeneratorFeedFn(object):
list_dict.setdefault(self._col_placeholders[index],
list()).append(data_row[key])
list_dict_size += 1
- feed_dict = {key: np.asarray(item) for key, item in list(list_dict.items())}
+
+ if self._pad_value is not None:
+ feed_dict = {
+ key: np.asarray(_pad_if_needed(item, self._pad_value))
+ for key, item in list(list_dict.items())
+ }
+ else:
+ feed_dict = {
+ key: np.asarray(item)
+ for key, item in list(list_dict.items())
+ }
return feed_dict
@@ -276,7 +346,8 @@ def _enqueue_data(data,
seed=None,
name="enqueue_input",
enqueue_size=1,
- num_epochs=None):
+ num_epochs=None,
+ pad_value=None):
"""Creates a queue filled from a numpy array or pandas `DataFrame`.
Returns a queue filled with the rows of the given (`OrderedDict` of) array
@@ -298,6 +369,7 @@ def _enqueue_data(data,
name: a scope name identifying the data.
enqueue_size: the number of rows to enqueue per step.
num_epochs: limit enqueuing to a specified number of epochs, if provided.
+ pad_value: default value for dynamic padding of data samples, if provided.
Returns:
A queue filled with the rows of the given (`OrderedDict` of) array or
@@ -306,6 +378,8 @@ def _enqueue_data(data,
Raises:
TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy
arrays, a numpy `ndarray`, or a generator producing these.
+ NotImplementedError: padding and shuffling data at the same time.
+ NotImplementedError: padding usage with non generator data type.
"""
with ops.name_scope(name):
if isinstance(data, np.ndarray):
@@ -336,6 +410,14 @@ def _enqueue_data(data,
"data must be either a numpy array or pandas DataFrame if pandas is "
"installed; got {}".format(type(data).__name__))
+ pad_data = pad_value is not None
+ if pad_data and get_feed_fn is not _GeneratorFeedFn:
+ raise NotImplementedError(
+ "padding is only available with generator usage")
+ if shuffle and pad_data:
+ raise NotImplementedError(
+ "padding and shuffling data at the same time is not implemented")
+
# TODO(jamieas): TensorBoard warnings for all warnings below once available.
if num_threads > 1 and num_epochs is not None:
@@ -368,6 +450,13 @@ def _enqueue_data(data,
dtypes=types,
shapes=queue_shapes,
seed=seed)
+ elif pad_data:
+ min_after_dequeue = 0 # just for the summary text
+ queue_shapes = list(
+ map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
+ queue_shapes))
+ queue = data_flow_ops.PaddingFIFOQueue(
+ capacity, dtypes=types, shapes=queue_shapes)
else:
min_after_dequeue = 0 # just for the summary text
queue = data_flow_ops.FIFOQueue(
@@ -383,14 +472,26 @@ def _enqueue_data(data,
enqueue_ops.append(queue.enqueue_many(placeholders))
seed_i = None if seed is None else (i + 1) * seed
- feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs))
+
+ if not pad_data:
+ feed_fns.append(
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs))
+ else:
+ feed_fns.append(
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs,
+ pad_value=pad_value))
runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access
queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns)
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py b/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py
index 0e602d3f33..3508f5aa3c 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py
@@ -290,6 +290,102 @@ class _FeedingFunctionsTestCase(test.TestCase):
actual = aff()
self.assertEqual(expected, vals_to_list(actual))
+ def testFillArraySmall(self):
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[32, 36], dtype=np.int32).tolist())
+ actual = np.ones(shape=[64, 36], dtype=np.int32)
+ ff._fill_array(actual, a)
+ expected = np.ones(shape=[64, 36], dtype=np.int32)
+ expected[:32, 32:] = 0
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testFillArrayLarge(self):
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ actual = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
+ ff._fill_array(actual, a)
+ expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
+ expected[:8, ..., 32:] = 0
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testFillArraySmallWithSpecifiedValue(self):
+ fill_value = 8
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[32, 36], dtype=np.int32).tolist())
+ actual = np.ones(shape=[64, 36], dtype=np.int32)
+ ff._fill_array(actual, a, fill_value)
+ expected = np.ones(shape=[64, 36], dtype=np.int32)
+ expected[:32, 32:] = fill_value
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testFillArrayLargeWithSpecifiedValue(self):
+ fill_value = 8
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ actual = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
+ ff._fill_array(actual, a, fill_value)
+ expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
+ expected[:8, ..., 32:] = fill_value
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testPadIfNeededSmall(self):
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[32, 36], dtype=np.int32).tolist())
+ a = list(map(np.array, a))
+ actual = ff._pad_if_needed(a)
+ expected = np.ones(shape=[64, 36], dtype=np.int32)
+ expected[:32, 32:] = 0
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testPadIfNeededLarge(self):
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ a = list(map(np.array, a))
+ actual = ff._pad_if_needed(a)
+ expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
+ expected[:8, ..., 32:] = 0
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testPadIfNeededSmallWithSpecifiedValue(self):
+ fill_value = 8
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[32, 36], dtype=np.int32).tolist())
+ a = list(map(np.array, a))
+ actual = ff._pad_if_needed(a, fill_value)
+ expected = np.ones(shape=[64, 36], dtype=np.int32)
+ expected[:32, 32:] = fill_value
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testPadIfNeededLargeWithSpecifiedValue(self):
+ fill_value = 8
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
+ shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ a = list(map(np.array, a))
+ actual = ff._pad_if_needed(a, fill_value)
+ expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
+ expected[:8, ..., 32:] = fill_value
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testPadIfNeededSmallWithSpecifiedNonNumericValue(self):
+ fill_value = False
+ a = (np.ones(shape=[32, 32], dtype=np.bool).tolist() + np.ones(
+ shape=[32, 36], dtype=np.bool).tolist())
+ a = list(map(np.array, a))
+ actual = ff._pad_if_needed(a, fill_value)
+ expected = np.ones(shape=[64, 36], dtype=np.bool)
+ expected[:32, 32:] = fill_value
+ self.assertEqual(expected.tolist(), actual.tolist())
+
+ def testPadIfNeededLargeWithSpecifiedNonNumericValue(self):
+ fill_value = False
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.bool).tolist() + np.ones(
+ shape=[8, 8, 8, 8, 36], dtype=np.bool).tolist())
+ a = list(map(np.array, a))
+ actual = ff._pad_if_needed(a, fill_value)
+ expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.bool)
+ expected[:8, ..., 32:] = fill_value
+ self.assertEqual(expected.tolist(), actual.tolist())
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index d03a04b24e..95cd27ebad 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -71,9 +71,26 @@ def _sparsify(x, thresh=0.5, index_dtype=np.int64):
indices=x_indices, values=x_values, dense_shape=x_shape), x_values
+def _default_tolerance(dtype):
+ """Returns a sensible default tolerance for comparing results of a given
+ type"""
+ if dtype == np.float16:
+ return 5e-3
+ elif dtype in (np.float32, np.complex64):
+ return 1e-3
+ elif dtype in (np.float64, np.complex128):
+ return 1e-5
+ else:
+ return None # Fail fast for unexpected types
+
+
class UnaryOpTest(test.TestCase):
- def _compareCpu(self, x, np_func, tf_func):
+ def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
+ if grad_rtol is None:
+ grad_rtol = _default_tolerance(x.dtype)
+ if grad_atol is None:
+ grad_atol = _default_tolerance(x.dtype)
np_ans = np_func(x)
with self.test_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
@@ -102,17 +119,17 @@ class UnaryOpTest(test.TestCase):
_, jacob_n = gradient_checker.compute_gradient(
inxf, s, yf, s, x_init_value=xf, delta=1e-2)
jacob_n = jacob_n.astype(np.float16)
- self.assertAllClose(jacob_t, jacob_n, rtol=5e-3, atol=5e-3)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
elif x.dtype in (np.float32, np.complex64):
s = list(np.shape(x))
jacob_t, jacob_n = gradient_checker.compute_gradient(
inx, s, y, s, x_init_value=x, delta=1e-3)
- self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
elif x.dtype in (np.float64, np.complex128):
s = list(np.shape(x))
jacob_t, jacob_n = gradient_checker.compute_gradient(
inx, s, y, s, x_init_value=x, delta=1e-5)
- self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
def _check(self, result_tensor, result_np, input_sp_t, tol):
self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
@@ -407,8 +424,13 @@ class UnaryOpTest(test.TestCase):
self._compareCpu(x, np.sinh, math_ops.sinh)
self._compareCpu(x, np.cosh, math_ops.cosh)
self._compareCpu(x, np.tanh, math_ops.tanh)
- self._compareCpu(y, np.arcsinh, math_ops.asinh)
- self._compareCpu(y, np.arccosh, math_ops.acosh)
+
+ # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
+ # of precision.
+ # Small gradient values + low precision --> High relative error
+ self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
+ self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
+
self._compareCpu(y, np.arctanh, math_ops.atanh)
self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
self._compareCpu(x, np.sin, math_ops.sin)
@@ -1930,6 +1952,33 @@ class ComplexMakeRealImagTest(test.TestCase):
self._compareRealImag(cplx, use_gpu=False)
self._compareRealImag(cplx, use_gpu=True)
+ def _compareAngle(self, cplx, use_gpu):
+ np_angle = np.angle(cplx)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inx = ops.convert_to_tensor(cplx)
+ tf_angle = math_ops.angle(inx)
+ tf_angle_val = tf_angle.eval()
+ self.assertAllEqual(np_angle, tf_angle_val)
+ self.assertShapeEqual(np_angle, tf_angle)
+
+ def testAngle64(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
+ cplx = real + 1j * imag
+ self._compareAngle(cplx, use_gpu=False)
+ # TODO: Enable GPU tests for angle op after resolving
+ # build failures on GPU (See #10643 for context).
+ # self._compareAngle(cplx, use_gpu=True)
+
+ def testAngle(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64)
+ cplx = real + 1j * imag
+ self._compareAngle(cplx, use_gpu=False)
+ # TODO: Enable GPU tests for angle op after resolving
+ # build failures on GPU (See #10643 for context).
+ # self._compareAngle(cplx, use_gpu=True)
+
def testRealReal(self):
for dtype in dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.float32, dtypes_lib.float64:
x = array_ops.placeholder(dtype)
@@ -2062,6 +2111,28 @@ class AccumulateTest(test.TestCase):
tf_val = math_ops.accumulate_n([])
tf_val.eval()
+ def testWrongShape(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ a = variables.Variable(0.2)
+ b = variables.Variable(0.1)
+ tf_val = math_ops.accumulate_n(
+ [a, b], shape=[2, 2]) # Should be shape=[]
+
+ def testWrongType(self):
+ with self.test_session():
+ with self.assertRaises(TypeError):
+ a = variables.Variable(0.2, dtype=np.float32)
+ b = variables.Variable(0.1, dtype=np.float32)
+ tf_val = math_ops.accumulate_n([a, b], tensor_dtype=np.int32)
+
+ def testWrongTypeOneInput(self):
+ # Scenario that used to trigger a bug, even when testWrongType() worked
+ with self.test_session():
+ with self.assertRaises(TypeError):
+ a = variables.Variable(0.2, dtype=np.float32)
+ tf_val = math_ops.accumulate_n([a], tensor_dtype=np.int32)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index 60ba16c1ac..f3731bad38 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -126,6 +126,24 @@ class StringSplitOpTest(test.TestCase):
[b"hello", b"cruel", b"world", b"hello cruel world"])
self.assertAllEqual(shape, [2, 3])
+ def testStringSplitWithNoSkipEmpty(self):
+ strings = ["#a", "b#", "#c#"]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split(strings, "#", skip_empty=False)
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0],
+ [2, 1], [2, 2]])
+ self.assertAllEqual(values, [b"", b"a", b"b", b"", b"", b"c", b""])
+ self.assertAllEqual(shape, [3, 3])
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split(strings, "#")
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(values, [b"a", b"b", b"c"])
+ self.assertAllEqual(indices, [[0, 0], [1, 0], [2, 0]])
+ self.assertAllEqual(shape, [3, 1])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8f88a9b30d..3d70465e68 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1023,6 +1023,19 @@ def _ImagGrad(_, grad):
return math_ops.complex(zero, grad)
+@ops.RegisterGradient("Angle")
+def _AngleGrad(op, grad):
+ """Returns -grad / (Im(x) + iRe(x))"""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ re = math_ops.real(x)
+ im = math_ops.imag(x)
+ z = math_ops.reciprocal(math_ops.complex(im, re))
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ complex_grad = math_ops.complex(grad, zero)
+ return -complex_grad * z
+
+
@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
"""Returns the complex conjugate of grad."""
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 7ee095745a..dc1a17829e 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -100,6 +100,7 @@ See the @{$python/math_ops} guide.
@@complex
@@conj
@@imag
+@@angle
@@real
@@fft
@@ifft
@@ -622,10 +623,9 @@ def imag(input, name=None):
r"""Returns the imaginary part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
- type `float32` or `float64` that is the imaginary part of each element in
- `input`. All elements in `input` must be complex numbers of the form \\(a +
- bj\\), where *a* is the real part and *b* is the imaginary part returned by
- this operation.
+ type `float` that is the argument of each element in `input`. All elements in
+ `input` must be complex numbers of the form \\(a + bj\\), where *a*
+ is the real part and *b* is the imaginary part returned by the operation.
For example:
@@ -646,6 +646,35 @@ def imag(input, name=None):
return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
+def angle(input, name=None):
+ r"""Returns the argument of a complex number.
+
+ Given a tensor `input` of complex numbers, this operation returns a tensor of
+ type `float32` or `float64` that is the argument of each element in `input`.
+ All elements in `input` must be complex numbers of the form \\(a + bj\\),
+ where *a* is the real part and *b* is the imaginary part.
+
+ The argument returned by this function is of the form \\(atan2(b, a)\\).
+
+ For example:
+
+ ```
+ # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+ tf.angle(input) ==> [2.0132, 1.056]
+ ```
+
+ Args:
+ input: A `Tensor`. Must be one of the following types: `complex64`,
+ `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `float32` or `float64`.
+ """
+ with ops.name_scope(name, "Angle", [input]) as name:
+ return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
+
+
# pylint: enable=redefined-outer-name,redefined-builtin
@@ -2027,10 +2056,13 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
for input_tensor in inputs:
if isinstance(input_tensor, ops.Tensor):
shape = shape.merge_with(input_tensor.get_shape())
- if len(inputs) == 1:
- return inputs[0]
if tensor_dtype is None:
tensor_dtype = inputs[0].dtype
+ if tensor_dtype != inputs[0].dtype:
+ raise TypeError("tensor_dtype is {}, but input is of type {}"
+ .format(tensor_dtype, inputs[0].dtype))
+ if len(inputs) == 1:
+ return inputs[0]
with ops.name_scope(name, "AccumulateN", inputs) as name:
var = gen_state_ops._temporary_variable(
shape=tensor_shape.vector(0), dtype=tensor_dtype)
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 46a5792474..b36cfeb2eb 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -317,6 +317,21 @@ class AddNTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x[0] * num_inputs,
math_ops.add_n([tf_x[0]] * num_inputs).eval())
+ def testGrad(self):
+ np.random.seed(42)
+ for num_inputs in range(1, 10):
+ with self.test_session(use_gpu=True) as sess:
+ input_vars = [
+ variables.Variable(10.0 * np.random.random())
+ for i in range(0, num_inputs)
+ ]
+ addn = math_ops.add_n(input_vars)
+ sess.run(variables.global_variables_initializer())
+ add_n_grad = gradients.gradients(addn, input_vars)
+ self.assertAllEqual(
+ np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
+ [g.eval() for g in add_n_grad])
+
class DivAndModTest(test_util.TensorFlowTestCase):
# TODO(aselle): Test more types before exposing new division operators.
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 97f2f761a6..f30e79a108 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -50,7 +50,7 @@ from tensorflow.python.util import deprecation
# pylint: enable=wildcard-import
-def string_split(source, delimiter=" "): # pylint: disable=invalid-name
+def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
Let N be the size of source (typically N will be the batch size). Split each
@@ -78,6 +78,7 @@ def string_split(source, delimiter=" "): # pylint: disable=invalid-name
source: `1-D` string `Tensor`, the strings to split.
delimiter: `0-D` string `Tensor`, the delimiter character, the string should
be length 0 or 1.
+ skip_empty: A `bool`. If `True`, skip the empty strings from the result.
Raises:
ValueError: If delimiter is not a string.
@@ -92,7 +93,7 @@ def string_split(source, delimiter=" "): # pylint: disable=invalid-name
# pylint: disable=protected-access
indices, values, shape = gen_string_ops._string_split(
- source, delimiter=delimiter)
+ source, delimiter=delimiter, skip_empty=skip_empty)
# pylint: enable=protected-access
indices.set_shape([None, 2])
values.set_shape([None])
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index e5484e02b5..1562f65675 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -566,6 +566,8 @@ class _MonitoredSession(object):
h.end(self._coordinated_creator.tf_sess)
finally:
try:
+ if self._sess is None:
+ raise RuntimeError('Session is already closed.')
self._sess.close()
finally:
self._sess = None
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index ffd9ed311f..a7c34cdd1b 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -1414,6 +1414,13 @@ class MonitoredSessionTest(test.TestCase):
isinstance(hook.run_metadata_list[0], config_pb2.RunMetadata))
self.assertGreater(len(hook.run_metadata_list[0].partition_graphs), 0)
+ def test_with_statement_and_close(self):
+ # Test case for https://github.com/tensorflow/tensorflow/issues/12224
+ # where close() inside the with should have a better error message.
+ with self.assertRaisesRegexp(RuntimeError, 'Session is already closed'):
+ with monitored_session.MonitoredSession() as session:
+ session.close()
+
class SingularMonitoredSessionTest(test.TestCase):
"""Tests SingularMonitoredSession."""
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index f9733b4963..a7a044f64e 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -561,6 +561,10 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "angle"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "arg_max"
argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
}
@@ -1878,7 +1882,7 @@ tf_module {
}
member_method {
name: "string_split"
- argspec: "args=[\'source\', \'delimiter\'], varargs=None, keywords=None, defaults=[\' \'], "
+ argspec: "args=[\'source\', \'delimiter\', \'skip_empty\'], varargs=None, keywords=None, defaults=[\' \', \'True\'], "
}
member_method {
name: "string_to_hash_bucket"
diff --git a/tensorflow/tools/dist_test/Dockerfile b/tensorflow/tools/dist_test/Dockerfile
index 83bbeeca8a..cd64e2c518 100644
--- a/tensorflow/tools/dist_test/Dockerfile
+++ b/tensorflow/tools/dist_test/Dockerfile
@@ -26,7 +26,6 @@ RUN apt-get update
RUN apt-get install -y \
curl \
python \
- python-numpy \
python-pip \
&& \
apt-get clean && \
diff --git a/tensorflow/tools/dist_test/Dockerfile.local b/tensorflow/tools/dist_test/Dockerfile.local
index 0cfb8d529e..7a896ab611 100644
--- a/tensorflow/tools/dist_test/Dockerfile.local
+++ b/tensorflow/tools/dist_test/Dockerfile.local
@@ -23,7 +23,6 @@ MAINTAINER Shanqing Cai <cais@google.com>
# Pick up some TF dependencies.
RUN apt-get update && apt-get install -y \
- python-numpy \
python-pip \
&& \
apt-get clean && \
diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile
index 581bded65d..5af753226f 100644
--- a/tensorflow/tools/gcs_test/Dockerfile
+++ b/tensorflow/tools/gcs_test/Dockerfile
@@ -7,7 +7,6 @@ RUN apt-get install -y \
curl \
libcurl4-openssl-dev \
python \
- python-numpy \
python-pip
# Install Google Cloud SDK
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index cdce308a50..f48fdcc9ec 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -54,11 +54,10 @@ function main() {
while true; do
if [[ "$1" == "--nightly_flag" ]]; then
NIGHTLY_BUILD=1
- fi
- if [[ "$1" == "--gpu" ]]; then
+ elif [[ "$1" == "--gpu" ]]; then
GPU_BUILD=1
elif [[ "$1" == "--gpudirect" ]]; then
- GPU_FLAG="--project_name tensorflow_gpudirect"
+ PKG_NAME_FLAG="--project_name tensorflow_gpudirect"
fi
shift
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 2b840c1a3a..1c45e9e5f6 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -29,11 +29,11 @@ from setuptools.dist import Distribution
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.3.0-rc2'
+_VERSION = '1.3.0'
REQUIRED_PACKAGES = [
'enum34 >= 1.1.6',
- 'numpy >= 1.11.0',
+ 'numpy >= 1.12.1',
'six >= 1.10.0',
'protobuf >= 3.3.0',
'tensorflow-tensorboard >= 0.1.0, < 0.2.0',
@@ -55,6 +55,13 @@ else:
# mock comes with unittest.mock for python3, need to install for python2
REQUIRED_PACKAGES.append('mock >= 2.0.0')
+# remove tensorboard from tf-nightly packages
+if 'tf_nightly' in project_name:
+ for package in REQUIRED_PACKAGES:
+ if 'tensorflow-tensorboard' in package:
+ REQUIRED_PACKAGES.remove(package)
+ break
+
# weakref.finalize was introduced in Python 3.4
if sys.version_info < (3, 4):
REQUIRED_PACKAGES.append('backports.weakref >= 1.0rc1')
@@ -70,6 +77,10 @@ CONSOLE_SCRIPTS = [
]
# pylint: enable=line-too-long
+# remove the tensorboard console script if building tf_nightly
+if 'tf_nightly' in project_name:
+ CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:main')
+
TEST_PACKAGES = [
'scipy >= 0.15.1',
]