aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md10
-rw-r--r--configure.py112
-rw-r--r--tensorflow/BUILD2
-rw-r--r--tensorflow/c/eager/BUILD76
-rw-r--r--tensorflow/c/exported_symbols.lds1
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc12
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc15
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc18
-rw-r--r--tensorflow/compiler/jit/BUILD6
-rw-r--r--tensorflow/compiler/tf2xla/BUILD7
-rw-r--r--tensorflow/compiler/tf2xla/xla_cpu_backend.cc38
-rw-r--r--tensorflow/compiler/tf2xla/xla_gpu_backend.cc35
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc31
-rw-r--r--tensorflow/compiler/xla/literal_util.cc2
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc22
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc100
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc17
-rw-r--r--tensorflow/contrib/android/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java63
-rw-r--r--tensorflow/contrib/cmake/external/boringssl.cmake2
-rw-r--r--tensorflow/contrib/crf/__init__.py6
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py10
-rwxr-xr-xtensorflow/contrib/image/__init__.py4
-rw-r--r--tensorflow/contrib/imperative/README.md2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py51
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimators_test.py39
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py46
-rw-r--r--tensorflow/contrib/lookup/__init__.py2
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc2
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h12
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc7
-rw-r--r--tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py57
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn.py20
-rw-r--r--tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py14
-rw-r--r--tensorflow/contrib/slim/python/slim/learning.py11
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py20
-rw-r--r--tensorflow/contrib/verbs/rdma.cc197
-rw-r--r--tensorflow/contrib/verbs/rdma.h7
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc51
-rw-r--r--tensorflow/contrib/verbs/verbs_util.cc49
-rw-r--r--tensorflow/contrib/verbs/verbs_util.h14
-rw-r--r--tensorflow/core/distributed_runtime/master.cc46
-rw-r--r--tensorflow/core/distributed_runtime/master.h4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc53
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc4
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc2
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.cc1
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc4
-rw-r--r--tensorflow/core/kernels/sparse_fill_empty_rows_op.cc6
-rw-r--r--tensorflow/core/kernels/variable_ops.cc1
-rw-r--r--tensorflow/core/lib/core/refcount.h2
-rw-r--r--tensorflow/core/ops/array_ops.cc23
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc2
-rw-r--r--tensorflow/core/profiler/README.md16
-rw-r--r--tensorflow/docs_src/api_guides/python/reading_data.md4
-rw-r--r--tensorflow/docs_src/deploy/hadoop.md4
-rw-r--r--tensorflow/docs_src/get_started/get_started.md22
-rw-r--r--tensorflow/docs_src/install/index.md2
-rw-r--r--tensorflow/docs_src/install/install_linux.md4
-rw-r--r--tensorflow/docs_src/install/install_mac.md14
-rw-r--r--tensorflow/docs_src/install/install_windows.md9
-rw-r--r--tensorflow/docs_src/programmers_guide/datasets.md2
-rw-r--r--tensorflow/docs_src/programmers_guide/threading_and_queues.md2
-rw-r--r--tensorflow/docs_src/programmers_guide/variables.md4
-rw-r--r--tensorflow/docs_src/tutorials/deep_cnn.md6
-rw-r--r--tensorflow/docs_src/tutorials/image_recognition.md18
-rw-r--r--tensorflow/docs_src/tutorials/image_retraining.md4
-rw-r--r--tensorflow/docs_src/tutorials/kernel_methods.md2
-rw-r--r--tensorflow/docs_src/tutorials/layers.md2
-rw-r--r--tensorflow/docs_src/tutorials/recurrent.md8
-rw-r--r--tensorflow/docs_src/tutorials/seq2seq.md28
-rw-r--r--tensorflow/docs_src/tutorials/using_gpu.md2
-rw-r--r--tensorflow/docs_src/tutorials/wide.md4
-rw-r--r--tensorflow/docs_src/tutorials/wide_and_deep.md6
-rw-r--r--tensorflow/docs_src/tutorials/word2vec.md20
-rw-r--r--tensorflow/examples/android/build.gradle2
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py17
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java63
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc20
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h9
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java33
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py73
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions_test.py40
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py12
-rw-r--r--tensorflow/python/kernel_tests/denormal_test.py4
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py5
-rw-r--r--tensorflow/python/ops/array_ops.py8
-rw-r--r--tensorflow/python/ops/control_flow_ops.py14
-rw-r--r--tensorflow/python/ops/io_ops.py4
-rw-r--r--tensorflow/python/ops/math_ops.py2
-rw-r--r--tensorflow/python/ops/math_ops_test.py5
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh4
-rwxr-xr-xtensorflow/tools/ci_build/update_version.py8
-rw-r--r--tensorflow/tools/pip_package/check_load_py_test.py13
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py22
-rw-r--r--tensorflow/tools/pip_package/setup.py12
-rw-r--r--tensorflow/tools/proto_text/BUILD2
99 files changed, 1128 insertions, 791 deletions
diff --git a/README.md b/README.md
index 63bde4235e..87c7b1bfa9 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,9 @@ networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
**If you want to contribute to TensorFlow, be sure to review the [contribution
-guidelines](CONTRIBUTING.md).**
+guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
+[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to
+uphold this code.**
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs. So please see
@@ -57,7 +59,7 @@ $ python
>>> b = tf.constant(32)
>>> sess.run(a + b)
42
->>>
+>>> sess.close()
```
## For more information
@@ -69,3 +71,7 @@ $ python
* [TensorFlow course at Stanford](https://web.stanford.edu/class/cs20si)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.
+
+## License
+
+[Apache License 2.0](LICENSE)
diff --git a/configure.py b/configure.py
index a0fbcdaccb..1a0f71ed94 100644
--- a/configure.py
+++ b/configure.py
@@ -25,6 +25,11 @@ import re
import subprocess
import sys
+try:
+ from shutil import which
+except ImportError:
+ from distutils.spawn import find_executable as which
+
_TF_BAZELRC = '.tf_configure.bazelrc'
_DEFAULT_CUDA_VERSION = '8.0'
_DEFAULT_CUDNN_VERSION = '6'
@@ -53,6 +58,10 @@ def is_ppc64le():
return platform.machine() == 'ppc64le'
+def is_cygwin():
+ return platform.system().startswith('CYGWIN_NT')
+
+
def get_input(question):
try:
try:
@@ -121,13 +130,20 @@ def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
-def run_shell(cmd):
- return subprocess.check_output(cmd, shell=True).decode('UTF-8').strip()
+def run_shell(cmd, allow_non_zero=False):
+ if allow_non_zero:
+ try:
+ output = subprocess.check_output(cmd)
+ except subprocess.CalledProcessError as e:
+ output = e.output
+ else:
+ output = subprocess.check_output(cmd)
+ return output.decode('UTF-8').strip()
def cygpath(path):
"""Convert path from posix to windows."""
- return run_shell('cygpath -m "%s"' % path)
+ return run_shell(['cygpath', '-m', path])
def get_python_path(environ_cp, python_bin_path):
@@ -136,20 +152,14 @@ def get_python_path(environ_cp, python_bin_path):
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
- check_input = [
- python_bin_path, '-c',
- 'import site; print("\\n".join(site.getsitepackages()))'
- ]
- library_paths = subprocess.check_output(check_input).decode(
- 'UTF-8').strip().split('\n')
+ library_paths = run_shell(
+ [python_bin_path, '-c',
+ 'import site; print("\\n".join(site.getsitepackages()))']).split("\n")
except subprocess.CalledProcessError:
- check_input = [
- python_bin_path, '-c', 'from distutils.sysconfig import get_python_lib;'
- + 'print(get_python_lib())'
- ]
- library_paths = [
- subprocess.check_output(check_input).decode('UTF-8').strip()
- ]
+ library_paths = [run_shell(
+ [python_bin_path, '-c',
+ 'from distutils.sysconfig import get_python_lib;'
+ 'print(get_python_lib())'])]
all_paths = set(python_paths + library_paths)
@@ -162,8 +172,7 @@ def get_python_path(environ_cp, python_bin_path):
def get_python_major_version(python_bin_path):
"""Get the python major version."""
- check_input = [python_bin_path, '-c', 'import sys; print(sys.version[0])']
- return subprocess.check_output(check_input).decode('UTF-8').strip()
+ return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])'])
def setup_python(environ_cp, bazel_version):
@@ -177,8 +186,8 @@ def setup_python(environ_cp, bazel_version):
environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
default_python_bin_path)
# Check if the path is valid
- if (os.path.isfile(python_bin_path) and os.access(
- python_bin_path, os.X_OK)) or (os.path.isdir(python_bin_path)):
+ if os.path.isfile(python_bin_path) and os.access(
+ python_bin_path, os.X_OK):
break
elif not os.path.exists(python_bin_path):
print('Invalid python path: %s cannot be found.' % python_bin_path)
@@ -187,7 +196,7 @@ def setup_python(environ_cp, bazel_version):
environ_cp['PYTHON_BIN_PATH'] = ''
# Convert python path to Windows style before checking lib and version
- if is_windows():
+ if is_cygwin():
python_bin_path = cygpath(python_bin_path)
# Get PYTHON_LIB_PATH
@@ -197,12 +206,12 @@ def setup_python(environ_cp, bazel_version):
if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1':
python_lib_path = python_lib_paths[0]
else:
- print('Found possible Python library paths:\n%s' %
- '\n'.join(python_lib_paths))
+ print('Found possible Python library paths:\n %s' %
+ '\n '.join(python_lib_paths))
default_python_lib_path = python_lib_paths[0]
python_lib_path = get_input(
- 'Please input the desired Python library path to use. Default is %s'
- % python_lib_paths[0])
+ 'Please input the desired Python library path to use. '
+ 'Default is [%s]\n' % python_lib_paths[0])
if not python_lib_path:
python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
@@ -210,7 +219,7 @@ def setup_python(environ_cp, bazel_version):
python_major_version = get_python_major_version(python_bin_path)
# Convert python path to Windows style before writing into bazel.rc
- if is_windows():
+ if is_cygwin():
python_lib_path = cygpath(python_lib_path)
# Set-up env variables used by python_configure.bzl
@@ -432,11 +441,10 @@ def check_bazel_version(min_version):
Returns:
The bazel version detected.
"""
- try:
- curr_version = run_shell('bazel --batch version')
- except subprocess.CalledProcessError:
+ if which('bazel') is None:
print('Cannot find bazel. Please install bazel.')
sys.exit(0)
+ curr_version = run_shell(['bazel', '--batch', 'version'])
for line in curr_version.split('\n'):
if 'Build label: ' in line:
@@ -529,7 +537,7 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
def set_clang_cuda_compiler_path(environ_cp):
"""Set CLANG_CUDA_COMPILER_PATH."""
- default_clang_path = run_shell('which clang || true')
+ default_clang_path = which('clang') or ''
ask_clang_path = ('Please specify which clang should be used as device and '
'host compiler. [Default is %s]: ') % default_clang_path
@@ -552,12 +560,12 @@ def set_clang_cuda_compiler_path(environ_cp):
def set_gcc_host_compiler_path(environ_cp):
"""Set GCC_HOST_COMPILER_PATH."""
- default_gcc_host_compiler_path = run_shell('which gcc || true')
+ default_gcc_host_compiler_path = which('gcc') or ''
cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH')
if os.path.islink(cuda_bin_symlink):
# os.readlink is only available in linux
- default_gcc_host_compiler_path = run_shell('readlink %s' % cuda_bin_symlink)
+ default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink)
ask_gcc_path = (
'Please specify which gcc should be used by nvcc as the '
@@ -592,7 +600,7 @@ def set_tf_cuda_version(environ_cp):
# Find out where the CUDA toolkit is installed
default_cuda_path = _DEFAULT_CUDA_PATH
- if is_windows():
+ if is_cygwin():
default_cuda_path = cygpath(
environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
elif is_linux():
@@ -633,7 +641,7 @@ def set_tf_cuda_version(environ_cp):
def set_tf_cunn_version(environ_cp):
"""Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
ask_cudnn_version = (
- '"Please specify the cuDNN version you want to use. '
+ 'Please specify the cuDNN version you want to use. '
'[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION
while True:
@@ -652,7 +660,7 @@ def set_tf_cunn_version(environ_cp):
# unusable. Going through one more level of expansion to handle that.
cudnn_install_path = os.path.realpath(
os.path.expanduser(cudnn_install_path))
- if is_windows():
+ if is_cygwin():
cudnn_install_path = cygpath(cudnn_install_path)
if is_windows():
@@ -674,12 +682,10 @@ def set_tf_cunn_version(environ_cp):
# Try another alternative for Linux
if is_linux():
- if subprocess.call(['which', 'ldconfig']):
- ldconfig_bin = '/sbin/ldconfig'
- else:
- ldconfig_bin = 'ldconfig'
- cudnn_path_from_ldconfig = run_shell(
- r'%s -p | sed -n "s/.*libcudnn.so .* => \(.*\)/\\1/p"' % ldconfig_bin)
+ ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+ cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
+ cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
+ cudnn_path_from_ldconfig).group(1)
if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
@@ -712,11 +718,15 @@ def get_native_cuda_compute_capabilities(environ_cp):
"""
device_query_bin = os.path.join(
environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery')
- cmd = (r'"%s" | grep "Capability" | grep -o "[0-9]*\.[0-9]*" | sed '
- '":a;{N;s/\\n/,/};ba"') % device_query_bin
- try:
- output = run_shell(cmd)
- except subprocess.CalledProcessError:
+ if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK):
+ try:
+ output = run_shell(device_query_bin).split('\n')
+ pattern = re.compile('[0-9]*\\.[0-9]*')
+ output = [pattern.search(x) for x in output if 'Capability' in x]
+ output = ','.join(x.group() for x in output if x is not None)
+ except subprocess.CalledProcessError:
+ output = ''
+ else:
output = ''
return output
@@ -797,7 +807,7 @@ def set_other_cuda_vars(environ_cp):
def set_host_cxx_compiler(environ_cp):
"""Set HOST_CXX_COMPILER."""
- default_cxx_host_compiler = run_shell('which g++ || true')
+ default_cxx_host_compiler = which('g++') or ''
ask_cxx_host_compiler = (
'Please specify which C++ compiler should be used as'
' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler
@@ -820,7 +830,7 @@ def set_host_cxx_compiler(environ_cp):
def set_host_c_compiler(environ_cp):
"""Set HOST_C_COMPILER."""
- default_c_host_compiler = run_shell('which gcc || true')
+ default_c_host_compiler = which('gcc') or ''
ask_c_host_compiler = (
'Please specify which C compiler should be used as the'
' host C compiler. [Default is %s]: ') % default_c_host_compiler
@@ -874,9 +884,9 @@ def set_computecpp_toolkit_path(environ_cp):
def set_mpi_home(environ_cp):
"""Set MPI_HOME."""
- cmd = ('dirname $(dirname $(which mpirun)) || dirname $(dirname $(which '
- 'mpiexec)) || true')
- default_mpi_home = run_shell(cmd)
+ default_mpi_home = which('mpirun') or which('mpiexec') or ''
+ default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
+
ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: '
) % default_mpi_home
while True:
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 80646fb602..566d8963d0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -465,6 +465,7 @@ cc_binary(
"//tensorflow/c:c_api",
"//tensorflow/c:exported_symbols.lds",
"//tensorflow/c:version_script.lds",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/core:tensorflow",
],
)
@@ -474,6 +475,7 @@ cc_binary(
linkshared = 1,
deps = [
"//tensorflow/c:c_api",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/cc:scope",
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 112eea0b40..85c4e4fd93 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -1,28 +1,39 @@
# Experimental extensions to the C API for eager execution of kernels.
licenses(["notice"]) # Apache 2.0
-cc_library(
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+ "tf_copts",
+ "tf_cuda_library",
+)
+
+tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
- visibility = [
- "//tensorflow:internal",
- "//tensorflow/python/eager:__pkg__",
- ],
- deps = [
- ":runtime",
- "//tensorflow/c:c_api",
- "//tensorflow/c:c_api_internal",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//tensorflow:android": [
+ ":c_api_internal",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ ":runtime",
+ "//tensorflow/c:c_api",
+ "//tensorflow/c:c_api_internal",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
-cc_test(
+tf_cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
deps = [
@@ -34,24 +45,31 @@ cc_test(
],
)
-cc_library(
+tf_cuda_library(
name = "runtime",
srcs = ["runtime.cc"],
hdrs = ["runtime.h"],
+ copts = tf_copts(),
visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/c:c_api",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
+ deps = select({
+ "//tensorflow:android": [
+ ":c_api_internal",
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
-cc_test(
+tf_cc_test(
name = "runtime_test",
srcs = ["runtime_test.cc"],
deps = [
diff --git a/tensorflow/c/exported_symbols.lds b/tensorflow/c/exported_symbols.lds
index a14bdaa48b..41f0637c99 100644
--- a/tensorflow/c/exported_symbols.lds
+++ b/tensorflow/c/exported_symbols.lds
@@ -1 +1,2 @@
_TF_*
+_TFE_*
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index bd78331309..45060c33f3 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -737,14 +737,10 @@ TEST_F(CWiseUnaryComplexGradTest, Angle) {
Tensor x = test::AsTensor<complex64>(
{{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
- Tensor dx_expected =
- test::AsTensor<complex64>({{5.5, 5.5},
- {3, 3},
- {2.1666666666666665, 2.1666666666666665},
- {1.75, 1.75},
- {0.9375, 0.9375},
- {0.8888888888888888, 0.8888888888888888}},
- {2, 3});
+ Tensor dx_expected = test::AsTensor<complex64>(
+ {{5.5, 5.5}, {3, 3},
+ {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
+ {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
}
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index f9d69ff896..6fc73c3fa1 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -95,6 +95,21 @@ Status SeluGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Selu", SeluGradHelper);
+Status BiasAddGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ string data_format;
+ BiasAddGrad::Attrs input_attrs;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format));
+ input_attrs.DataFormat(data_format);
+ auto dx_1 = BiasAddGrad(scope, grad_inputs[0], input_attrs);
+ grad_outputs->push_back(Identity(scope, grad_inputs[0]));
+ grad_outputs->push_back(dx_1);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index eab5b44626..a02d36549b 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -47,6 +47,15 @@ class NNGradTest : public ::testing::Test {
EXPECT_LT(max_error, 1e-4);
}
+ void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
+ const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
+ TF_ASSERT_OK(scope_.status());
+ float max_error;
+ TF_ASSERT_OK(
+ ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
+ EXPECT_LT(max_error, 1e-4);
+ }
+
Scope scope_;
};
@@ -113,5 +122,14 @@ TEST_F(NNGradTest, SeluGrad) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NNGradTest, BiasAddGradHelper) {
+ TensorShape shape({4, 5});
+ TensorShape bias_shape({5});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ auto bias = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(bias_shape));
+ auto y = BiasAdd(scope_, x, bias);
+ RunTest({x,bias}, {shape, bias_shape}, {y}, {shape});
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 0c6976bf14..02e7ca64e5 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -24,6 +24,7 @@ package(
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -32,10 +33,11 @@ cc_library(
deps = [
":xla_cpu_device",
":xla_cpu_jit",
+ "//tensorflow/compiler/plugin",
+ ] + if_cuda_is_configured([
":xla_gpu_device",
":xla_gpu_jit",
- "//tensorflow/compiler/plugin",
- ],
+ ]),
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 60e68db2d6..a8d743c071 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -20,6 +20,8 @@ package(
default_visibility = [":internal"],
)
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+
cc_library(
name = "xla_compiler",
srcs = [
@@ -29,7 +31,10 @@ cc_library(
"xla_helpers.cc",
"xla_op_kernel.cc",
"xla_op_registry.cc",
- ],
+ "xla_cpu_backend.cc",
+ ] + if_cuda_is_configured([
+ "xla_gpu_backend.cc",
+ ]),
hdrs = [
"xla_compilation_device.h",
"xla_compiler.h",
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
new file mode 100644
index 0000000000..8286480e0e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+
+namespace tensorflow {
+
+bool CpuOpFilter(KernelDef* kdef) {
+ // TODO(b/34339814): implement inverse erf for double types and remove this
+ // workaround.
+ if (kdef->op() == "RandomStandardNormal") {
+ kdef->clear_constraint();
+ // Change the type constraint to permit only DTD_FLOAT.
+ KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
+ attr_constraint->set_name("dtype");
+ attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
+ DT_FLOAT);
+ return true;
+ }
+ return true;
+}
+
+REGISTER_XLA_BACKEND(DEVICE_CPU_XLA_JIT, kCpuAllTypes, CpuOpFilter);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
new file mode 100644
index 0000000000..d504613d23
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+
+namespace tensorflow {
+
+bool GpuOpFilter(KernelDef* kdef) {
+ // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
+ // slow code.
+ // TODO(b/34969189) The implementation of TruncatedNormal generates illegal
+ // code on GPU.
+ if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
+ kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
+ return false;
+ }
+ return true;
+}
+
+REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 5db3566674..2cf3d4c1f2 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -330,35 +330,4 @@ XlaBackendRegistrar::XlaBackendRegistrar(
registry.RegisterBackend(name.ToString(), types, op_filter);
}
-bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
- return true;
-}
-
-REGISTER_XLA_BACKEND(DEVICE_CPU_XLA_JIT, kCpuAllTypes, CpuOpFilter);
-
-bool GpuOpFilter(KernelDef* kdef) {
- // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
- // slow code.
- // TODO(b/34969189) The implementation of TruncatedNormal generates illegal
- // code on GPU.
- if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
- kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
- return false;
- }
- return true;
-}
-
-REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 8aa40435e2..71995b2307 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -839,6 +839,7 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
+ CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
@@ -865,6 +866,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
CONVERT_IF_DEST_TYPE_MATCHES(U8)
CONVERT_IF_DEST_TYPE_MATCHES(U32)
CONVERT_IF_DEST_TYPE_MATCHES(U64)
+ CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index b50e741b8a..a33c0fe09d 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -878,6 +878,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{0, 1, 0, 1}, {1, 0, 1, 0}},
{{1, 0, 1, 0}, {0, 1, 0, 1}},
}}, layout_r4_dim0major_);
+ auto f16 = Literal::CreateR4WithLayout<half>({{
+ {{half(10.0), half(0.0), half(12.0), half(0.0)},
+ {half(0.0), half(15.0), half(0.0), half(17.0)}},
+ {{half(0.0), half(19.0), half(0.0), half(21.0)},
+ {half(22.0), half(0.0), half(24.0), half(0.0)}},
+ {{half(26.0), half(0.0), half(28.0), half(0.0)},
+ {half(0.0), half(31.0), half(0.0), half(33.0)}},
+ }}, layout_r4_dim0major_);
auto f32 = Literal::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@@ -918,10 +926,20 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = s32->Convert(F32).ConsumeValueOrDie();
EXPECT_TRUE(conv->Equal(*f32));
+ conv = f32->Convert(F16).ConsumeValueOrDie();
+ EXPECT_TRUE(conv->Equal(*f16));
+
+ conv = f64->Convert(F16).ConsumeValueOrDie();
+ EXPECT_TRUE(conv->Equal(*f16));
+
+ conv = s32->Convert(F16).ConsumeValueOrDie();
+ EXPECT_TRUE(conv->Equal(*f16));
+
+ conv = u32->Convert(F16).ConsumeValueOrDie();
+ EXPECT_TRUE(conv->Equal(*f16));
+
EXPECT_EQ(s32->Convert(TUPLE).status().code(),
tensorflow::error::INVALID_ARGUMENT);
- EXPECT_EQ(s32->Convert(F16).status().code(),
- tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(s32->Convert(S16).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(s32->Convert(U16).status().code(),
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 01b8328115..145918db3e 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -51,7 +51,7 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase {
#endif
};
-TEST_F(ConvolutionVariantsTest, Minimal) {
+XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
ComputationBuilder builder(client_, TestName());
const Array4D<float> input_array(1, 1, 1, 1, {2});
@@ -66,7 +66,7 @@ TEST_F(ConvolutionVariantsTest, Minimal) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
+XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
ComputationBuilder builder(client_, TestName());
const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5});
@@ -81,7 +81,7 @@ TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Flat1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(2, 1, 3, 4);
@@ -98,7 +98,7 @@ TEST_F(ConvolutionVariantsTest, Flat1x1) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Deep1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 2, 1, 1, {10, 1});
@@ -113,7 +113,7 @@ TEST_F(ConvolutionVariantsTest, Deep1x1) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 2, {1, 2});
@@ -128,7 +128,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
@@ -143,7 +143,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
@@ -158,7 +158,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
@@ -173,7 +173,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
@@ -188,7 +188,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(
@@ -209,7 +209,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
@@ -224,7 +224,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
@@ -239,7 +239,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
@@ -254,7 +254,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
@@ -269,7 +269,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
@@ -284,7 +284,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 1, {1});
@@ -299,7 +299,7 @@ TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
@@ -314,7 +314,7 @@ TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
@@ -331,7 +331,7 @@ TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 2, 1, 2, {1, 2, 3, 4});
@@ -346,7 +346,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
@@ -361,7 +361,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
@@ -376,7 +376,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(64);
@@ -396,7 +396,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(16 * 1 * 1 * 1);
@@ -417,7 +417,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
ComputationBuilder builder(client_, TestName());
constexpr int bs = 16;
@@ -448,7 +448,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
ComputationBuilder builder(client_, TestName());
constexpr int kx = 2;
@@ -478,7 +478,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(16, 1, 8, 8);
@@ -506,7 +506,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(2 * 8 * 8);
@@ -532,7 +532,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(2 * 2 * 8 * 8);
@@ -558,7 +558,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(32 * 2 * 8 * 8);
@@ -598,7 +598,7 @@ TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
ComputationBuilder builder(client_, TestName());
Array4D<float> input_array(16, 16, 1, 1);
@@ -794,7 +794,7 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
+XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
constexpr int bs = 1;
constexpr int iz = 1;
constexpr int oz = 2;
@@ -827,7 +827,7 @@ TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
constexpr int bs = 1;
constexpr int iz = 16;
constexpr int oz = 1;
@@ -860,7 +860,7 @@ TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
constexpr int bs = 16;
constexpr int iz = 16;
constexpr int oz = 1;
@@ -893,7 +893,7 @@ TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
+XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
constexpr int bs = 16;
constexpr int iz = 16;
constexpr int oz = 16;
@@ -926,7 +926,7 @@ TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) {
+XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) {
constexpr int bs = 16;
constexpr int iz = 16;
constexpr int oz = 16;
@@ -959,7 +959,7 @@ TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) {
ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(1 * 2 * 3 * 1);
@@ -999,7 +999,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(1 * 2 * 3 * 1);
@@ -1039,7 +1039,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(1 * 2 * 3 * 1);
@@ -1076,7 +1076,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_data(1 * 2 * 3 * 2);
@@ -1123,7 +1123,7 @@ TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
// Conv([1,2,3], Reverse([5,6]), padding_low=1)
// into
// BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1)
-TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
@@ -1141,7 +1141,7 @@ TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) {
// Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3)
// into
// BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1))
-TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
@@ -1162,7 +1162,7 @@ TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) {
// Conv([1], Reverse([1,10,100]), padding=(1,1))
// into
// BackwardInputConv([1], [1,10,100], padding=(1,1))
-TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
@@ -1183,7 +1183,7 @@ TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
//
// However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't
// support negative padding on backward convolution yet (b/32744257).
-TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
@@ -1198,7 +1198,7 @@ TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
ComputeAndCompareR4<float>(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) {
ComputationBuilder builder(client_, TestName());
// activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0
@@ -1221,7 +1221,7 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) {
ComputeAndCompareR4<float>(&builder, {{{{24, 130, 240}}}}, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest,
+XLA_TEST_F(ConvolutionVariantsTest,
BackwardFilterLowPaddingGreaterThanHighPadding) {
ComputationBuilder builder(client_, TestName());
@@ -1247,7 +1247,7 @@ TEST_F(ConvolutionVariantsTest,
ComputeAndCompareR4<float>(&builder, {{{{13, 24}}}}, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
ComputationBuilder builder(client_, TestName());
// activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0
@@ -1274,7 +1274,7 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
ComputeAndCompareR4<float>(&builder, {{{{13, 24, 130}}}}, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) {
ComputationBuilder builder(client_, TestName());
auto gradients = builder.ConstantR3FromArray3D<float>(
@@ -1288,7 +1288,7 @@ TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) {
ComputeAndCompareR3<float>(&builder, {{{10}}}, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
ComputationBuilder builder(client_, TestName());
auto activations =
@@ -1307,7 +1307,7 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
ComputeAndCompareR3<float>(&builder, {{{13, 24, 130}}}, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
ComputationBuilder builder(client_, TestName());
auto gradients_flat = Literal::CreateR1<float>({1});
@@ -1331,7 +1331,7 @@ TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
}
-TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
+XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
ComputationBuilder builder(client_, TestName());
auto activations_flat = Literal::CreateR1<float>({1, 2, 3, 4});
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index afeab5860b..bcb85b04ee 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -56,30 +56,30 @@ class CopyOpTest : public HloTestBase {
tensorflow::gtl::ArraySlice<int64> permutation);
};
-TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0<bool>(true)); }
+XLA_TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0<bool>(true)); }
-TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1<uint32>({})); }
+XLA_TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1<uint32>({})); }
-TEST_F(CopyOpTest, CopyR1S3U32) {
+XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
TestCopyOp(*Literal::CreateR1<uint32>({1, 2, 3}));
}
-TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
+XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
-TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
+XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
TestCopyOp(*Literal::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
-TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
+XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
TestCopyOp(*Literal::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
-TEST_F(CopyOpTest, CopyParameterScalar) {
+XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto builder = HloComputation::Builder(TestName());
// Copy literal to device to use as parameter.
@@ -102,7 +102,7 @@ TEST_F(CopyOpTest, CopyParameterScalar) {
LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
}
-TEST_F(CopyOpTest, CopyConstantR2Twice) {
+XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto builder = HloComputation::Builder(TestName());
auto literal = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
@@ -123,7 +123,7 @@ TEST_F(CopyOpTest, CopyConstantR2Twice) {
error_spec_);
}
-TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
+XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 06540e02a6..5da6104cfa 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -193,12 +193,27 @@ class SliceR1Test : public ClientLibraryTestBase,
}
};
-XLA_TEST_P(SliceR1Test, DoIt) {
+XLA_TEST_P(SliceR1Test, DoIt_F32) {
Run<float>(GetParam());
+}
+
+XLA_TEST_P(SliceR1Test, DoIt_F64) {
Run<double>(GetParam());
+}
+
+XLA_TEST_P(SliceR1Test, DoIt_U32) {
Run<uint32>(GetParam());
+}
+
+XLA_TEST_P(SliceR1Test, DoIt_S32) {
Run<int32>(GetParam());
+}
+
+XLA_TEST_P(SliceR1Test, DoIt_U64) {
Run<uint64>(GetParam());
+}
+
+XLA_TEST_P(SliceR1Test, DoIt_S64) {
Run<int64>(GetParam());
}
diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt
index 1f86288cf9..f61e9560ef 100644
--- a/tensorflow/contrib/android/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/android/cmake/CMakeLists.txt
@@ -28,7 +28,7 @@ set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
add_library(lib_nsync STATIC IMPORTED )
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
- ${TARGET_NSYNC_LIB})
+ ${TARGET_NSYNC_LIB}/lib/libnsync.a)
add_library(lib_tf STATIC IMPORTED )
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 587f2941e5..9b7f394258 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -55,23 +55,7 @@ public class TensorFlowInferenceInterface {
* @param model The filepath to the GraphDef proto representing the model.
*/
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
- Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
- try {
- // Hack to see if the native libraries have been loaded.
- new RunStats();
- Log.i(TAG, "TensorFlow native methods already loaded");
- } catch (UnsatisfiedLinkError e1) {
- Log.i(
- TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
- try {
- System.loadLibrary("tensorflow_inference");
- Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
- } catch (UnsatisfiedLinkError e2) {
- throw new RuntimeException(
- "Native TF methods not found; check that the correct native"
- + " libraries are present in the APK.");
- }
- }
+ prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
@@ -102,6 +86,31 @@ public class TensorFlowInferenceInterface {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
+
+ /*
+ * Load a TensorFlow model from provided InputStream.
+ * Note: The InputStream will not be closed after loading model, users need to
+ * close it themselves.
+ *
+ * @param is The InputStream to use to load the model.
+ */
+ public TensorFlowInferenceInterface(InputStream is) {
+ prepareNativeRuntime();
+
+ // modelName is redundant for model loading from input stream, here is for
+ // avoiding error in initialization as modelName is marked final.
+ this.modelName = "";
+ this.g = new Graph();
+ this.sess = new Session(g);
+ this.runner = sess.runner();
+
+ try {
+ loadGraph(is, g);
+ Log.i(TAG, "Successfully loaded model from the input stream");
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to load model from the input stream", e);
+ }
+ }
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
@@ -408,6 +417,26 @@ public class TensorFlowInferenceInterface {
public void fetch(String outputName, ByteBuffer dst) {
getTensor(outputName).writeTo(dst);
}
+
+ private void prepareNativeRuntime() {
+ Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
+ try {
+ // Hack to see if the native libraries have been loaded.
+ new RunStats();
+ Log.i(TAG, "TensorFlow native methods already loaded");
+ } catch (UnsatisfiedLinkError e1) {
+ Log.i(
+ TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
+ try {
+ System.loadLibrary("tensorflow_inference");
+ Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
+ } catch (UnsatisfiedLinkError e2) {
+ throw new RuntimeException(
+ "Native TF methods not found; check that the correct native"
+ + " libraries are present in the APK.");
+ }
+ }
+ }
private void loadGraph(InputStream is, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake
index 2ae591d3fa..04a9664701 100644
--- a/tensorflow/contrib/cmake/external/boringssl.cmake
+++ b/tensorflow/contrib/cmake/external/boringssl.cmake
@@ -17,7 +17,7 @@ include (ExternalProject)
set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include)
#set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src)
set(boringssl_URL https://boringssl.googlesource.com/boringssl)
-set(boringssl_TAG e72df93)
+set(boringssl_TAG 17cf2cb1d226b0ba2401304242df7ddd3b6f1ff2)
set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build)
#set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so)
set(boringssl_STATIC_LIBRARIES
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
index 80a31cc334..bc749339bd 100644
--- a/tensorflow/contrib/crf/__init__.py
+++ b/tensorflow/contrib/crf/__init__.py
@@ -21,7 +21,10 @@ See the @{$python/contrib.crf} guide.
@@crf_log_likelihood
@@crf_unary_score
@@crf_binary_score
+@@crf_decode
@@CrfForwardRnnCell
+@@CrfDecodeForwardRnnCell
+@@CrfDecodeBackwardRnnCell
@@viterbi_decode
"""
@@ -31,11 +34,14 @@ from __future__ import print_function
from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks
from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
+from tensorflow.contrib.crf.python.ops.crf import crf_decode
from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import viterbi_decode
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 470fb15fd6..e595e4d90b 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -77,12 +77,10 @@ def reduce_sum_n(tensors, name=None):
return tensors[0]
return math_ops.add_n(tensors, name=name_scope)
-
-@deprecated(
- None,
- 'Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note '
- 'that order of the inputs and ouputs of labels and predictions have also '
- 'been switched.')
+@deprecated(None,
+ "Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note "
+ "that order of the inputs and ouputs of labels and predictions have also "
+ "been switched.")
def remove_squeezable_dimensions(predictions, labels, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py
index fee1a6c2bc..1ed19265b3 100755
--- a/tensorflow/contrib/image/__init__.py
+++ b/tensorflow/contrib/image/__init__.py
@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""##Ops for image manipulation.
+"""Ops for image manipulation.
### API
This module provides functions for image manipulation; currently, only
projective transforms (including rotation) are supported.
-## Image `Ops`
-
@@angles_to_projective_transforms
@@compose_transforms
@@rotate
diff --git a/tensorflow/contrib/imperative/README.md b/tensorflow/contrib/imperative/README.md
index ea643a45d1..44860c796b 100644
--- a/tensorflow/contrib/imperative/README.md
+++ b/tensorflow/contrib/imperative/README.md
@@ -96,7 +96,7 @@ TensorFlow:
```
* Variables are automatically initialized, no need to run the
- [`tf.global_variables_initializer()`](https://www.tensorflow.org/api_docs/python/state_ops/variable_helper_functions#global_variables_initializer) operation.
+ [`tf.global_variables_initializer()`](https://www.tensorflow.org/api_docs/python/tf/global_variables_initializer) operation.
```python
x = tf.Variable(np.random.normal(size=[2, 2]), dtype=tf.float32)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 0a60685613..fb0e4fdc53 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -29,7 +29,6 @@ import numpy as np
import six
from google.protobuf import message
-from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
@@ -58,6 +57,7 @@ from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
@@ -97,8 +97,7 @@ def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
if x is None:
raise ValueError('Either x or input_fn must be provided.')
- if contrib_framework.is_tensor(x) or (y is not None and
- contrib_framework.is_tensor(y)):
+ if tensor_util.is_tensor(x) or y is not None and tensor_util.is_tensor(y):
raise ValueError('Inputs cannot be tensors. Please provide input_fn.')
if feed_fn is not None:
@@ -847,7 +846,7 @@ class BaseEstimator(
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
- global_step = contrib_framework.create_global_step(g)
+ global_step = training_util.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
@@ -908,7 +907,7 @@ class BaseEstimator(
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
- contrib_framework.create_global_step(g)
+ training_util.create_global_step(g)
features = self._get_features_from_input_fn(input_fn)
infer_ops = self._get_predict_ops(features)
predictions = self._filter_predictions(infer_ops.predictions, outputs)
@@ -978,7 +977,7 @@ class BaseEstimator(
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
- global_step = contrib_framework.create_global_step(g)
+ global_step = training_util.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
model_fn_ops = self._get_train_ops(features, labels)
@@ -1132,13 +1131,14 @@ class Estimator(BaseEstimator):
self._feature_engineering_fn = (
feature_engineering_fn or _identity_feature_engineering_fn)
- def _call_model_fn(self, features, labels, mode):
+ def _call_model_fn(self, features, labels, mode, metrics=None):
"""Calls model function with support of 2, 3 or 4 arguments.
Args:
features: features dict.
labels: labels dict.
mode: ModeKeys
+ metrics: Dict of metrics.
Returns:
A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a
@@ -1161,17 +1161,24 @@ class Estimator(BaseEstimator):
model_fn_results = self._model_fn(features, labels, **kwargs)
if isinstance(model_fn_results, model_fn_lib.ModelFnOps):
- return model_fn_results
-
- # Here model_fn_results should be a tuple with 3 elements.
- if len(model_fn_results) != 3:
- raise ValueError('Unrecognized value returned by model_fn, '
- 'please return ModelFnOps.')
- return model_fn_lib.ModelFnOps(
- mode=mode,
- predictions=model_fn_results[0],
- loss=model_fn_results[1],
- train_op=model_fn_results[2])
+ model_fn_ops = model_fn_results
+ else:
+ # Here model_fn_results should be a tuple with 3 elements.
+ if len(model_fn_results) != 3:
+ raise ValueError('Unrecognized value returned by model_fn, '
+ 'please return ModelFnOps.')
+ model_fn_ops = model_fn_lib.ModelFnOps(
+ mode=mode,
+ predictions=model_fn_results[0],
+ loss=model_fn_results[1],
+ train_op=model_fn_results[2])
+
+ # Custom metrics should overwrite defaults.
+ if metrics:
+ model_fn_ops.eval_metric_ops.update(_make_metrics_ops(
+ metrics, features, labels, model_fn_ops.predictions))
+
+ return model_fn_ops
def _get_train_ops(self, features, labels):
"""Method that builds model graph and returns trainer ops.
@@ -1215,13 +1222,7 @@ class Estimator(BaseEstimator):
ValueError: if `metrics` don't match `labels`.
"""
model_fn_ops = self._call_model_fn(
- features, labels, model_fn_lib.ModeKeys.EVAL)
-
- features, labels = self._feature_engineering_fn(features, labels)
- # Custom metrics should overwrite defaults.
- if metrics:
- model_fn_ops.eval_metric_ops.update(_make_metrics_ops(
- metrics, features, labels, model_fn_ops.predictions))
+ features, labels, model_fn_lib.ModeKeys.EVAL, metrics)
if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops:
model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = (
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
index a458e7abf6..1d89dfb55b 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn.estimators import estimator as estima
from tensorflow.contrib.learn.python.learn.estimators._sklearn import accuracy_score
from tensorflow.contrib.learn.python.learn.estimators._sklearn import train_test_split
from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib
@@ -77,6 +78,44 @@ class FeatureEngineeringFunctionTest(test.TestCase):
# labels = transformed_y (99)
self.assertEqual(99., metrics["label"])
+ def testFeatureEngineeringFnWithSameName(self):
+
+ def input_fn():
+ return {
+ "x": constant_op.constant(["9."])
+ }, {
+ "y": constant_op.constant(["99."])
+ }
+
+ def feature_engineering_fn(features, labels):
+ # Github #12205: raise a TypeError if called twice.
+ _ = string_ops.string_split(features["x"])
+ features["x"] = constant_op.constant([9.])
+ labels["y"] = constant_op.constant([99.])
+ return features, labels
+
+ def model_fn(features, labels):
+ # dummy variable:
+ _ = variables_lib.Variable([0.])
+ _ = labels
+ predictions = features["x"]
+ loss = constant_op.constant([2.])
+ update_global_step = variables.get_global_step().assign_add(1)
+ return predictions, loss, update_global_step
+
+ estimator = estimator_lib.Estimator(
+ model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
+ estimator.fit(input_fn=input_fn, steps=1)
+ prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True))
+ # predictions = transformed_x (9)
+ self.assertEqual(9., prediction)
+ metrics = estimator.evaluate(
+ input_fn=input_fn, steps=1,
+ metrics={"label":
+ metric_spec.MetricSpec(lambda predictions, labels: labels)})
+ # labels = transformed_y (99)
+ self.assertEqual(99., metrics["label"])
+
def testNoneFeatureEngineeringFn(self):
def input_fn():
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 7b49cd475d..c31d5d2d47 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -25,8 +25,6 @@ import six
from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib import lookup as lookup_lib
-# TODO(ptucker): Use tf.metrics.
-from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
@@ -38,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
@@ -766,7 +765,7 @@ class _RegressionHead(_SingleHead):
with ops.name_scope("metrics", values=[eval_loss]):
return {
_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
def _log_loss_with_two_classes(labels, logits, weights=None):
@@ -903,11 +902,11 @@ class _BinaryLogisticHead(_SingleHead):
logistic = predictions[prediction_key.PredictionKey.LOGISTIC]
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(classes, labels, weights))
+ metrics_lib.accuracy(labels, classes, weights))
metrics[_summary_key(self.head_name, mkey.PREDICTION_MEAN)] = (
_predictions_streaming_mean(logistic, weights))
metrics[_summary_key(self.head_name, mkey.LABEL_MEAN)] = (
@@ -1132,12 +1131,11 @@ class _MultiClassHead(_SingleHead):
classes = predictions[prediction_key.PredictionKey.CLASSES]
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(
- classes, self._labels(labels), weights))
+ metrics_lib.accuracy(self._labels(labels), classes, weights))
if not self._label_keys:
# Classes are IDs. Add some metrics.
@@ -1290,13 +1288,13 @@ class _BinarySvmHead(_SingleHead):
with ops.name_scope("metrics", values=(
[eval_loss, labels, weights] + list(six.itervalues(predictions)))):
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
classes = predictions[prediction_key.PredictionKey.CLASSES]
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(classes, labels, weights))
+ metrics_lib.accuracy(labels, classes, weights))
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
return metrics
@@ -1397,11 +1395,11 @@ class _MultiLabelHead(_SingleHead):
logits = predictions[prediction_key.PredictionKey.LOGITS]
metrics = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(eval_loss)}
+ metrics_lib.mean(eval_loss)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_summary_key(self.head_name, mkey.ACCURACY)] = (
- metrics_lib.streaming_accuracy(classes, labels, weights))
+ metrics_lib.accuracy(labels, classes, weights))
metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
probabilities, labels, weights)
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc(
@@ -1946,7 +1944,7 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
if weights is not None:
weights = weights[:, class_id]
labels = labels[:, class_id]
- return metrics_lib.streaming_mean(labels, weights=weights)
+ return metrics_lib.mean(labels, weights)
def _predictions_streaming_mean(predictions,
@@ -1960,7 +1958,7 @@ def _predictions_streaming_mean(predictions,
if weights is not None:
weights = weights[:, class_id]
predictions = predictions[:, class_id]
- return metrics_lib.streaming_mean(predictions, weights=weights)
+ return metrics_lib.mean(predictions, weights)
# TODO(ptucker): Add support for SparseTensor labels.
@@ -1973,7 +1971,7 @@ def _class_id_labels_to_indicator(labels, num_classes):
def _class_predictions_streaming_mean(predictions, weights, class_id):
- return metrics_lib.streaming_mean(
+ return metrics_lib.mean(
array_ops.where(
math_ops.equal(
math_ops.to_int32(class_id), math_ops.to_int32(predictions)),
@@ -1983,7 +1981,7 @@ def _class_predictions_streaming_mean(predictions, weights, class_id):
def _class_labels_streaming_mean(labels, weights, class_id):
- return metrics_lib.streaming_mean(
+ return metrics_lib.mean(
array_ops.where(
math_ops.equal(
math_ops.to_int32(class_id), math_ops.to_int32(labels)),
@@ -2006,8 +2004,7 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None,
weights = weights[:, class_id]
predictions = predictions[:, class_id]
labels = labels[:, class_id]
- return metrics_lib.streaming_auc(
- predictions, labels, weights=weights, curve=curve)
+ return metrics_lib.auc(labels, predictions, weights, curve=curve)
def _assert_class_id(class_id, num_classes=None):
@@ -2024,21 +2021,18 @@ def _assert_class_id(class_id, num_classes=None):
def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold):
threshold_predictions = math_ops.to_float(
math_ops.greater_equal(predictions, threshold))
- return metrics_lib.streaming_accuracy(
- predictions=threshold_predictions, labels=labels, weights=weights)
+ return metrics_lib.accuracy(labels, threshold_predictions, weights)
def _streaming_precision_at_threshold(predictions, labels, weights, threshold):
- precision_tensor, update_op = metrics_lib.streaming_precision_at_thresholds(
- predictions, labels=labels, thresholds=(threshold,),
- weights=_float_weights_or_none(weights))
+ precision_tensor, update_op = metrics_lib.precision_at_thresholds(
+ labels, predictions, (threshold,),_float_weights_or_none(weights))
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
- precision_tensor, update_op = metrics_lib.streaming_recall_at_thresholds(
- predictions, labels=labels, thresholds=(threshold,),
- weights=_float_weights_or_none(weights))
+ precision_tensor, update_op = metrics_lib.recall_at_thresholds(
+ labels, predictions, (threshold,),_float_weights_or_none(weights))
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py
index dbd64cf042..17eafeb2da 100644
--- a/tensorflow/contrib/lookup/__init__.py
+++ b/tensorflow/contrib/lookup/__init__.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# TODO(ptucker): deprecate string_to_index_table_from_file and
-# string_to_index_table_from_tensor 2017-04-10.
"""Ops for lookup operations.
@@string_to_index
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
index 2def4f3f17..c33804906f 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
@@ -15,8 +15,8 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
#include <algorithm>
+#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
index c62a7b20d6..fc3a2da9b3 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#define Sum(a, b) ((a) + (b))
#define Prod(a, b) ((a) * (b))
@@ -58,11 +58,11 @@ inline T negative_infinity() {
} // namespace reduce_functions
-#define CALL_ALL_REDUCEOPS(func, ...) \
- func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \
- func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) func( \
- Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \
- func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__)
+#define CALL_ALL_REDUCEOPS(func, ...) \
+ func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \
+ func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) \
+ func(Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \
+ func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__)
#define ReduceSliceFunctorReduceop(reduceop, dummy) \
template <typename Device, typename T, typename Index> \
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
index 8b205f7dd5..8e6870fadd 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
@@ -17,10 +17,10 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
@@ -68,9 +68,8 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
- Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
- sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
- 0, 0); \
+ Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\
+ ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \
\
ReduceSliceDeviceKernel##reduceop<T, Index> \
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
diff --git a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
index 8c8db295ff..60a193db4c 100644
--- a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
+++ b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
@@ -39,48 +39,44 @@ class ReduceSliceTest(TensorFlowTestCase):
def testReduceSliceSum2D(self):
x = np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900]], dtype=np.int32)
indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
- result = np.array(
- [[1, 2, 3], [741, 852, 963], [40, 50, 60], [740, 850, 960],
- [41, 52, 63]],
- dtype=np.int32)
+ result = np.array([[1, 2, 3], [741, 852, 963], [40, 50, 60],
+ [740, 850, 960], [41, 52, 63]], dtype=np.int32)
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
self.assertAllEqual(y_tf, result)
def testReduceSliceSum3D(self):
- x = np.array(
- [[[1, 2], [3, 4]], [[50, 60], [70, 80]], [[600, 700], [800, 900]]],
- dtype=np.int32)
+ x = np.array([[[1, 2], [3, 4]], [[50, 60], [70, 80]],
+ [[600, 700], [800, 900]]], dtype=np.int32)
indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
- result = np.array(
- [[[1, 2], [3, 4]], [[651, 762], [873, 984]], [[50, 60], [70, 80]],
- [[650, 760], [870, 980]], [[51, 62], [73, 84]]],
- dtype=np.int32)
+ result = np.array([[[1, 2], [3, 4]],
+ [[651, 762], [873, 984]],
+ [[50, 60], [70, 80]],
+ [[650, 760], [870, 980]],
+ [[51, 62], [73, 84]]], dtype=np.int32)
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
self.assertAllEqual(y_tf, result)
def testReduceSliceSumAxis1(self):
- x = np.transpose(
- np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900]], dtype=np.int32))
+ x = np.transpose(np.array([[1, 2, 3], [40, 50, 60],
+ [700, 800, 900]], dtype=np.int32))
indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
- result = np.transpose(
- np.array(
- [[1, 2, 3], [741, 852, 963], [40, 50, 60], [740, 850, 960],
- [41, 52, 63]],
- dtype=np.int32))
+ result = np.transpose(np.array([[1, 2, 3],
+ [741, 852, 963],
+ [40, 50, 60],
+ [740, 850, 960],
+ [41, 52, 63]], dtype=np.int32))
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 1).eval()
self.assertAllEqual(y_tf, result)
def testReduceSliceSum1DIndices(self):
- x = np.array(
- [[1, 2, 3], [40, 50, 60], [700, 800, 900], [1000, 2000, 3000],
- [40000, 50000, 60000]],
- dtype=np.int32)
+ x = np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900],
+ [1000, 2000, 3000], [40000, 50000, 60000]], dtype=np.int32)
indices = np.array([0, 0, 2, 5], dtype=np.int32)
- result = np.array(
- [[0, 0, 0], [41, 52, 63], [41700, 52800, 63900]], dtype=np.int32)
+ result = np.array([[0, 0, 0], [41, 52, 63],
+ [41700, 52800, 63900]], dtype=np.int32)
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
self.assertAllEqual(y_tf, result)
@@ -88,9 +84,8 @@ class ReduceSliceTest(TensorFlowTestCase):
def testReduceSliceProd(self):
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
- result = np.array(
- [[1, 2, 3], [28, 80, 162], [4, 5, 6], [28, 40, 54], [4, 10, 18]],
- dtype=np.int32)
+ result = np.array([[1, 2, 3], [28, 80, 162], [4, 5, 6],
+ [28, 40, 54], [4, 10, 18]], dtype=np.int32)
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_prod(x, indices, 0).eval()
self.assertAllEqual(y_tf, result)
@@ -98,8 +93,8 @@ class ReduceSliceTest(TensorFlowTestCase):
def testReduceSliceMax(self):
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
- result = np.array(
- [[1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9], [4, 5, 6]], dtype=np.int32)
+ result = np.array([[1, 2, 3], [7, 8, 9], [4, 5, 6],
+ [7, 8, 9], [4, 5, 6]], dtype=np.int32)
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_max(x, indices, 0).eval()
self.assertAllEqual(y_tf, result)
@@ -107,8 +102,8 @@ class ReduceSliceTest(TensorFlowTestCase):
def testReduceSliceMin(self):
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
- result = np.array(
- [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6], [1, 2, 3]], dtype=np.int32)
+ result = np.array([[1, 2, 3], [1, 2, 3], [4, 5, 6],
+ [4, 5, 6], [1, 2, 3]], dtype=np.int32)
with self.test_session(use_gpu=True):
y_tf = reduce_slice_ops.reduce_slice_min(x, indices, 0).eval()
self.assertAllEqual(y_tf, result)
diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py
index 47c3e74325..2f0caadda3 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn.py
@@ -83,12 +83,14 @@ def stack_bidirectional_rnn(cells_fw,
raise ValueError("cells_bw must be a list of RNNCells (one per layer).")
if len(cells_fw) != len(cells_bw):
raise ValueError("Forward and Backward cells must have the same depth.")
- if initial_states_fw is not None and (not isinstance(cells_fw, list) or
- len(cells_fw) != len(cells_fw)):
+ if (initial_states_fw is not None and
+ (not isinstance(initial_states_fw, list) or
+ len(initial_states_fw) != len(cells_fw))):
raise ValueError(
"initial_states_fw must be a list of state tensors (one per layer).")
- if initial_states_bw is not None and (not isinstance(cells_bw, list) or
- len(cells_bw) != len(cells_bw)):
+ if (initial_states_bw is not None and
+ (not isinstance(initial_states_bw, list) or
+ len(initial_states_bw) != len(cells_bw))):
raise ValueError(
"initial_states_bw must be a list of state tensors (one per layer).")
states_fw = []
@@ -194,12 +196,14 @@ def stack_bidirectional_dynamic_rnn(cells_fw,
raise ValueError("cells_bw must be a list of RNNCells (one per layer).")
if len(cells_fw) != len(cells_bw):
raise ValueError("Forward and Backward cells must have the same depth.")
- if initial_states_fw is not None and (not isinstance(cells_fw, list) or
- len(cells_fw) != len(cells_fw)):
+ if (initial_states_fw is not None and
+ (not isinstance(initial_states_fw, list) or
+ len(initial_states_fw) != len(cells_fw))):
raise ValueError(
"initial_states_fw must be a list of state tensors (one per layer).")
- if initial_states_bw is not None and (not isinstance(cells_bw, list) or
- len(cells_bw) != len(cells_bw)):
+ if (initial_states_bw is not None and
+ (not isinstance(initial_states_bw, list) or
+ len(initial_states_bw) != len(cells_bw))):
raise ValueError(
"initial_states_bw must be a list of state tensors (one per layer).")
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
index 10ea883e1f..ea8cc0ff61 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
@@ -119,9 +119,13 @@ class ParallelReaderTest(test.TestCase):
self.assertEquals(count0, num_records_per_file)
self.assertEquals(count1, num_records_per_file)
self.assertEquals(count2, num_records_per_file)
- self.assertEquals(all_keys_count, num_files * num_records_per_file)
+ self.assertEquals(
+ all_keys_count,
+ num_files * num_records_per_file)
self.assertEquals(all_values_count, all_keys_count)
- self.assertEquals(count0 + count1 + count2, all_keys_count)
+ self.assertEquals(
+ count0 + count1 + count2,
+ all_keys_count)
def testRandomShuffleQueue(self):
shared_queue = data_flow_ops.RandomShuffleQueue(
@@ -140,16 +144,14 @@ class ParallelReaderTest(test.TestCase):
capacity=55,
min_after_dequeue=28,
dtypes=[dtypes_lib.string, dtypes_lib.string],
- shapes=[tensor_shape.scalar(),
- tensor_shape.scalar()])
+ shapes=[tensor_shape.scalar(), tensor_shape.scalar()])
self._verify_read_up_to_out(shared_queue)
def testReadUpToFromFIFOQueue(self):
shared_queue = data_flow_ops.FIFOQueue(
capacity=99,
dtypes=[dtypes_lib.string, dtypes_lib.string],
- shapes=[tensor_shape.scalar(),
- tensor_shape.scalar()])
+ shapes=[tensor_shape.scalar(), tensor_shape.scalar()])
self._verify_read_up_to_out(shared_queue)
diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py
index 8f690fb549..c7614fd426 100644
--- a/tensorflow/contrib/slim/python/slim/learning.py
+++ b/tensorflow/contrib/slim/python/slim/learning.py
@@ -251,7 +251,6 @@ import os
import sys
import time
-from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.training.python.training import training
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import timeline
@@ -263,7 +262,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as tf_optimizer
@@ -646,7 +645,7 @@ def train(train_op,
graph = graph or ops.get_default_graph()
with graph.as_default():
if global_step is None:
- global_step = variables.get_or_create_global_step()
+ global_step = training_util.get_or_create_global_step()
saver = saver or tf_saver.Saver()
if sync_optimizer is not None:
@@ -657,14 +656,14 @@ def train(train_op,
with ops.name_scope('init_ops'):
if init_op == _USE_DEFAULT:
- init_op = tf_variables.global_variables_initializer()
+ init_op = variables.global_variables_initializer()
if ready_op == _USE_DEFAULT:
- ready_op = tf_variables.report_uninitialized_variables()
+ ready_op = variables.report_uninitialized_variables()
if local_init_op == _USE_DEFAULT:
local_init_op = control_flow_ops.group(
- tf_variables.local_variables_initializer(),
+ variables.local_variables_initializer(),
lookup_ops.tables_initializer())
if sync_optimizer is not None and isinstance(sync_optimizer, list):
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 4b5f2a1c91..e77cae4695 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -52,12 +52,26 @@ setup(
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Alpha',
-
- # Indicate who your project is intended for
+
'Intended Audience :: Developers',
- 'Topic :: Software Development',
+ 'Intended Audience :: Education',
+ 'Intended Audience :: Science/Research',
+
'License :: OSI Approved :: Apache Software License',
+
+ 'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+
+ 'Topic :: Scientific/Engineering',
+ 'Topic :: Scientific/Engineering :: Mathematics',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Software Development',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords='tensorflow performance tpu',)
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 445cbe290a..ec5adfdaa0 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -707,7 +707,6 @@ void RdmaTensorBuffer::SendNextItem() {
bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
// string tensor needs to be serialized
Tensor copy;
- StringPiece copy_buf;
TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
@@ -721,109 +720,133 @@ void RdmaTensorBuffer::SendNextItem() {
host_alloc_attrs.set_on_host(true);
Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
copy = Tensor(alloc, in.dtype(), in.shape());
- s = VerbsUtil::CopyGPUTensorToCPUSync(
- src_dev, send_args.device_context, &in, &copy);
- CHECK(s.ok()) << "copy tensor from gpu sync";
- copy_buf = copy.tensor_data();
+ tensor_bytes = in.TotalBytes();
+ buffer_size += tensor_bytes;
+ GPUUtil::CopyGPUTensorToCPU(
+ src_dev, send_args.device_context, &in, &copy,
+ [this, copy, tensor_bytes, buffer_size, key, in, step_id,
+ key_with_step_id, is_dead](const Status& s) {
+ CHECK(s.ok()) << "copy tensor from gpu sync";
+ StringPiece copy_buf;
+ copy_buf = copy.tensor_data();
+ PostCopyOperations(true, buffer_size, tensor_bytes, key, in,
+ step_id, is_dead, key_with_step_id, &copy,
+ NULL, &copy_buf);
+ });
} else {
- // "val" is on a GPU. Uses GPUUtil to fill the proto.
- s = VerbsUtil::SetProtoFromGPUSync(
- in, src_dev, send_args.device_context, &proto, is_dead);
- CHECK(s.ok()) << "set proto from gpu sync";
+ // "val" is on a GPU. No longer uses GPUUtil to fill the proto, use
+ // aync instead
+ GPUUtil::SetProtoFromGPU(
+ in, src_dev, send_args.device_context, &proto, is_dead,
+ [this, proto, buffer_size, key, in, step_id, key_with_step_id,
+ is_dead](const Status& s) mutable {
+ CHECK(s.ok()) << "copy proto from gpu sync";
+ auto tensor_bytes = proto.ByteSize();
+ buffer_size += tensor_bytes;
+ PostCopyOperations(false, buffer_size, tensor_bytes, key, in,
+ step_id, is_dead, key_with_step_id, NULL,
+ &proto, NULL);
+ });
}
} else {
// tensor is in CPU memory.
+ StringPiece copy_buf;
if (can_memcpy) {
copy_buf = in.tensor_data();
+ tensor_bytes = in.TotalBytes();
} else {
in.AsProtoTensorContent(&proto);
+ tensor_bytes = proto.ByteSize();
}
- }
- if (can_memcpy) {
- tensor_bytes = in.TotalBytes();
- } else {
- tensor_bytes = proto.ByteSize();
+ buffer_size += tensor_bytes;
+ PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in,
+ step_id, is_dead, key_with_step_id, &copy, &proto,
+ &copy_buf);
}
// maybe some margin for string tensor?
- buffer_size += tensor_bytes;
- // prepare message
- RdmaMessage rm;
- rm.name_size_ = key.size();
- rm.name_ = key;
- rm.tensor_shape_ = in.shape();
- rm.data_type_ = in.dtype();
- rm.step_id_ = step_id;
- rm.is_dead_ = is_dead;
- rm.tensor_bytes_ = tensor_bytes;
- rm.buffer_size_ = buffer_size;
- mu_.lock();
- if (local_status_ == none ||
- (buffer_size > size_ && local_status_ == idle &&
- remote_status_ == idle)) {
- if ((local_status_ != none) && (buffer_size > size_)) {
- VLOG(2) << "Extend RDMA buffer from " << size_ << " to "
- << buffer_size;
- }
- CreateCPUBuffer(buffer_size, false);
- mu_.unlock();
- // put back the key since it is not sent;
- EnqueueItem(key_with_step_id);
- // ask the remote to create the same buffer
- rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
- rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
- rm.rkey_ = self_->rkey;
- string message = RdmaMessage::CreateMessage(rm);
- channel_->tx_message_buffer_->EnqueueItem(message);
- channel_->tx_message_buffer_->SendNextItem();
- } else if ((local_status_ == idle) && (remote_status_ == idle)) {
- // both buffers are ready, send the tensor
- local_status_ = busy;
- remote_status_ = busy;
- // local/remote_status_ won't be set back to idle
- // unitl Write() is successful
- mu_.unlock();
- if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
- (buffer_size <= size_ && rm.data_type_ == DT_STRING))) {
- VLOG(2) << "Tensor and buffer size do not agree,"
- << " buffer_size = " << size_
- << " requested tensor size = "
- << buffer_size << in.DebugString();
- }
- uint32_t imm_data = LookupBufferIndex(key);
- rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
- string message = RdmaMessage::CreateMessage(rm);
- memcpy(buffer_, message.data(), message.size());
- if (!is_dead) {
- // copy the tensor buffer content
- void* output =
- static_cast<void*>(static_cast<char*>(buffer_) +
- RdmaMessage::kTensorBufferStartIndex);
- CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
- if (can_memcpy) {
- CHECK(copy_buf.size() == tensor_bytes)
- << "unexpected tensor size: "
- << copy_buf.size()
- << " != "
- << tensor_bytes;
- memcpy(output, copy_buf.data(), tensor_bytes);
- } else {
- proto.SerializeToArray(output, tensor_bytes);
- }
- } else {
- buffer_size = RdmaMessage::kMessageTotalBytes;
- }
- Write(imm_data, buffer_size);
- } else {
- mu_.unlock();
- // put back the key since it is not sent;
- EnqueueItem(key_with_step_id);
- }
};
+
channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id,
parsed, cb);
}
}
+void RdmaTensorBuffer::PostCopyOperations(
+ bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key,
+ const Tensor& in, int64 step_id, bool is_dead,
+ const string& key_with_step_id, const Tensor* copy,
+ const TensorProto* proto, const StringPiece* copy_buf) {
+ // prepare message
+ RdmaMessage rm;
+ rm.name_size_ = key.size();
+ rm.name_ = key;
+ rm.tensor_shape_ = in.shape();
+ rm.data_type_ = in.dtype();
+ rm.step_id_ = step_id;
+ rm.is_dead_ = is_dead;
+ rm.tensor_bytes_ = tensor_bytes;
+ rm.buffer_size_ = buffer_size;
+ mu_.lock();
+ if (local_status_ == none || (buffer_size > size_ && local_status_ == idle &&
+ remote_status_ == idle)) {
+ if ((local_status_ != none) && (buffer_size > size_)) {
+ VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size;
+ }
+ CreateCPUBuffer(buffer_size, false);
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ // ask the remote to create the same buffer
+ rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
+ rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
+ rm.rkey_ = self_->rkey;
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+ } else if ((local_status_ == idle) && (remote_status_ == idle)) {
+ // both buffers are ready, send the tensor
+ local_status_ = busy;
+ remote_status_ = busy;
+ // local/remote_status_ won't be set back to idle
+ // unitl Write() is successful
+ mu_.unlock();
+ if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
+ (buffer_size <= size_ && rm.data_type_ == DT_STRING))) {
+ VLOG(2) << "Tensor and buffer size do not agree,"
+ << " buffer_size = " << size_
+ << " requested tensor size = " << buffer_size << in.DebugString();
+ }
+ uint32_t imm_data = LookupBufferIndex(key);
+ rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
+ string message = RdmaMessage::CreateMessage(rm);
+ memcpy(buffer_, message.data(), message.size());
+ if (!is_dead) {
+ // copy the tensor buffer content
+ void* output = static_cast<void*>(static_cast<char*>(buffer_) +
+ RdmaMessage::kTensorBufferStartIndex);
+ CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
+ if (can_memcpy) {
+ CHECK(copy != NULL) << "callback missing pointer to copy tensor";
+ CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer";
+ CHECK(copy_buf->size() == tensor_bytes)
+ << "unexpected tensor size: " << copy_buf->size()
+ << " != " << tensor_bytes;
+ memcpy(output, copy_buf->data(), tensor_bytes);
+ } else {
+ CHECK(proto != NULL) << "callback missing pointer to proto tensor";
+ proto->SerializeToArray(output, tensor_bytes);
+ }
+ } else {
+ buffer_size = RdmaMessage::kMessageTotalBytes;
+ }
+ Write(imm_data, buffer_size);
+ } else {
+ mu_.unlock();
+ // put back the key since it is not sent;
+ EnqueueItem(key_with_step_id);
+ }
+}
+
// Create a RdmaMessage according to the pre-defined format
// Args:
// rm: the message structure
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index 10cbbe58d9..16ef58bc62 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -28,6 +28,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/env.h"
@@ -225,6 +226,12 @@ class RdmaTensorBuffer : public RdmaBuffer {
explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
virtual ~RdmaTensorBuffer() override {}
void SendNextItem() override;
+ void PostCopyOperations(bool can_memcpy, size_t buffer_size,
+ size_t tensor_bytes, const string& key,
+ const Tensor& in, int64 step_id, bool is_dead,
+ const string& key_with_step_id, const Tensor* copy,
+ const TensorProto* proto,
+ const StringPiece* copy_buf);
};
struct RdmaMessage {
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 3ba6510711..ce82ca2883 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -33,6 +34,11 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr)
: BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {}
+ void RecvPostCopyOps(const string& key, const string& key_with_step_id,
+ const Rendezvous::Args& recv_args,
+ const DoneCallback& done, const RdmaMessage& rm,
+ RdmaChannel* rc, Tensor& val, const Status& s);
+
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
@@ -113,10 +119,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
- s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context,
- dst_dev, &gpu_copy);
- CHECK(s.ok()) << "copy tensor to gpu sync";
- val = std::move(gpu_copy);
+
+ GPUUtil::CopyCPUTensorToGPU(
+ &copy, recv_args.device_context, dst_dev, &gpu_copy,
+ [this, gpu_copy, key, key_with_step_id, recv_args, done, rm,
+ rc](const Status& s) {
+ CHECK(s.ok()) << "copy tensor to gpu sync";
+ Tensor val;
+ val = std::move(gpu_copy);
+ RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc,
+ val, s);
+ });
+ return;
} else {
AllocatorAttributes host_alloc_attrs;
host_alloc_attrs.set_gpu_compatible(true);
@@ -135,18 +149,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
}
}
-
- rc->RemoveRecvCallback(key_with_step_id);
- // create message
- RdmaMessage br;
- br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
- br.name_size_ = key.size();
- br.name_ = key;
- string message = RdmaMessage::CreateMessage(br);
- RdmaBuffer* tb = rc->tx_message_buffer_;
- tb->EnqueueItem(message);
- tb->SendNextItem();
- done(s, Args(), recv_args, val, rm.is_dead_);
+ RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s);
});
// append key to message queue
RdmaBuffer* rb = rc->tx_message_buffer_;
@@ -160,6 +163,22 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
rb->SendNextItem();
}
+void RdmaRemoteRendezvous::RecvPostCopyOps(
+ const string& key, const string& key_with_step_id,
+ const Rendezvous::Args& recv_args, const DoneCallback& done,
+ const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) {
+ rc->RemoveRecvCallback(key_with_step_id);
+ RdmaMessage br;
+ br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
+ br.name_size_ = key.size();
+ br.name_ = key;
+ string message = RdmaMessage::CreateMessage(br);
+ RdmaBuffer* tb = rc->tx_message_buffer_;
+ tb->EnqueueItem(message);
+ tb->SendNextItem();
+ done(s, Args(), recv_args, val, rm.is_dead_);
+}
+
RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env)
: BaseRendezvousMgr(env) {}
diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc
index 76e44d34a9..4f5c731a18 100644
--- a/tensorflow/contrib/verbs/verbs_util.cc
+++ b/tensorflow/contrib/verbs/verbs_util.cc
@@ -20,55 +20,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
-// static sync wrapper:
-Status VerbsUtil::CopyGPUTensorToCPUSync(Device* gpu_device,
- const DeviceContext* device_context,
- const Tensor* gpu_tensor,
- Tensor* cpu_tensor) {
- Notification n;
- Status status;
- GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context,
- gpu_tensor, cpu_tensor,
- [&n, &status](const Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- return status;
-}
-
-// static sync wrapper:
-Status VerbsUtil::CopyCPUTensorToGPUSync(const Tensor* cpu_tensor,
- const DeviceContext* device_context,
- Device* gpu_device,
- Tensor* gpu_tensor) {
- Notification n;
- Status status;
- GPUUtil::CopyCPUTensorToGPU(cpu_tensor, device_context,
- gpu_device, gpu_tensor,
- [&n, &status](const Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- return status;
-}
-
-// static sync wrapper:
-Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
- const DeviceContext* device_context,
- TensorProto* proto, bool is_dead) {
- Notification n;
- Status status;
- GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
- [&n, &status](const Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- return status;
-}
-
// static
string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
return strings::StrCat(key, ";", step_id);
diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h
index d9da396228..8b44adaedc 100644
--- a/tensorflow/contrib/verbs/verbs_util.h
+++ b/tensorflow/contrib/verbs/verbs_util.h
@@ -28,20 +28,6 @@ class TensorProto;
class VerbsUtil {
public:
- // synchronous wrapper of CopyGPUTensorToCPU
- static Status CopyGPUTensorToCPUSync(Device* gpu_device,
- const DeviceContext* device_context,
- const Tensor* gpu_tensor,
- Tensor* cpu_tensor);
- // synchronous wrapper of CopyCPUTensorToGPU
- static Status CopyCPUTensorToGPUSync(const Tensor* cpu_tensor,
- const DeviceContext* device_context,
- Device* gpu_device,
- Tensor* gpu_tensor);
- // synchronous wrapper of SetProtoFromGPU
- static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
- const DeviceContext* device_context,
- TensorProto* proto, bool is_dead);
static string AppendStepidToKey(const string& key, int64 step_id);
static void GetKeyAndStepId(const string& key_with_step_id, string& key,
int64& step_id);
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index 4ff2d0f5e3..d1dc622ce7 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -116,6 +116,18 @@ void Master::GC() {
}
}
+MasterSession* Master::FindMasterSession(const string& handle) {
+ MasterSession* session = nullptr;
+ {
+ mutex_lock l(mu_);
+ session = gtl::FindPtrOrNull(sessions_, handle);
+ if (session != nullptr) {
+ session->Ref();
+ }
+ }
+ return session;
+}
+
class DeviceFinder {
public:
static Status GetRemoteDevices(
@@ -429,16 +441,11 @@ void Master::CreateSession(const CreateSessionRequest* req,
void Master::ExtendSession(const ExtendSessionRequest* req,
ExtendSessionResponse* resp, MyClosure done) {
- mu_.lock();
- MasterSession* session = nullptr;
- session = gtl::FindPtrOrNull(sessions_, req->session_handle());
+ auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
- mu_.unlock();
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
- session->Ref();
- mu_.unlock();
SchedClosure([session, req, resp, done]() {
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
@@ -452,15 +459,11 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp, MyClosure done) {
- mu_.lock();
- MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle());
+ auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
- mu_.unlock();
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
- session->Ref();
- mu_.unlock();
SchedClosure([this, session, req, resp, done]() {
Status s = session->PartialRunSetup(req, resp);
@@ -471,16 +474,12 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req,
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
MutableRunStepResponseWrapper* resp, MyClosure done) {
- mu_.lock();
- uint64 start_time = env_->env->NowMicros();
- MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle());
+ auto start_time = env_->env->NowMicros();
+ auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
- mu_.unlock();
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
- session->Ref();
- mu_.unlock();
SchedClosure([this, start_time, session, opts, req, resp, done]() {
Status status = session->Run(opts, *req, resp);
@@ -526,18 +525,11 @@ void Master::ListDevices(const ListDevicesRequest* req,
ListDevicesResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
if (!req->session_handle().empty()) {
- MasterSession* session = nullptr;
- {
- mutex_lock l(mu_);
- session = gtl::FindPtrOrNull(sessions_, req->session_handle());
- if (session != nullptr) {
- session->Ref();
- }
- }
+ auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
done(errors::InvalidArgument(
- "Session ", req->session_handle(),
- " is not found. Possibly, this master has restarted."));
+ "Session ", req->session_handle(),
+ " is not found. Possibly, this master has restarted."));
return;
}
core::ScopedUnref ref(session);
diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h
index ce05a6508b..678fc46bd7 100644
--- a/tensorflow/core/distributed_runtime/master.h
+++ b/tensorflow/core/distributed_runtime/master.h
@@ -94,6 +94,10 @@ class Master {
// Cleanup unused session.
void GC();
+ // Find master session by session handle, and increments the reference count
+ // on the returned MasterSession if not null.
+ MasterSession* FindMasterSession(const string& handle);
+
TF_DISALLOW_COPY_AND_ASSIGN(Master);
};
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index c04aa44941..70418f6368 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -32,6 +32,8 @@ namespace tensorflow {
// GrpcRemoteMaster is an implementation of the MasterInterface
// that uses gRPC to talk to the Master service.
class GrpcRemoteMaster : public MasterInterface {
+ using MasterServiceStub = grpc::MasterService::Stub;
+
public:
explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel)
: stub_(grpc::MasterService::NewStub(client_channel)) {}
@@ -42,63 +44,56 @@ class GrpcRemoteMaster : public MasterInterface {
const CreateSessionRequest* request,
CreateSessionResponse* response) override {
::grpc::ClientContext ctx;
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->CreateSession(&ctx, *request, response));
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::CreateSession);
}
Status ExtendSession(CallOptions* call_options,
const ExtendSessionRequest* request,
ExtendSessionResponse* response) override {
::grpc::ClientContext ctx;
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::ExtendSession);
}
Status PartialRunSetup(CallOptions* call_options,
const PartialRunSetupRequest* request,
PartialRunSetupResponse* response) override {
::grpc::ClientContext ctx;
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->PartialRunSetup(&ctx, *request, response));
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::PartialRunSetup);
}
Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
MutableRunStepResponseWrapper* response) override {
::grpc::ClientContext ctx;
auto trace = TraceRpc("RunStep/Client", &ctx);
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->RunStep(&ctx, request->ToProto(),
- get_proto_from_wrapper(response)));
+ return Call(&ctx, call_options, &request->ToProto(),
+ get_proto_from_wrapper(response),
+ &MasterServiceStub::RunStep);
}
Status CloseSession(CallOptions* call_options,
const CloseSessionRequest* request,
CloseSessionResponse* response) override {
::grpc::ClientContext ctx;
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->CloseSession(&ctx, *request, response));
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::CloseSession);
}
Status ListDevices(CallOptions* call_options,
const ListDevicesRequest* request,
ListDevicesResponse* response) override {
::grpc::ClientContext ctx;
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->ListDevices(&ctx, *request, response));
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::ListDevices);
}
Status Reset(CallOptions* call_options, const ResetRequest* request,
ResetResponse* response) override {
::grpc::ClientContext ctx;
- ctx.set_fail_fast(false);
- SetDeadline(&ctx, call_options->GetTimeout());
- return FromGrpcStatus(stub_->Reset(&ctx, *request, response));
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::Reset);
}
private:
@@ -110,13 +105,23 @@ class GrpcRemoteMaster : public MasterInterface {
return port::Tracing::TraceMe(name, trace_id);
}
- std::unique_ptr<grpc::MasterService::Stub> stub_;
-
void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms) {
if (time_in_ms > 0) {
ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
}
}
+
+ template <typename Request, typename Response>
+ Status Call(::grpc::ClientContext* ctx, CallOptions* call_options,
+ const Request* request, Response* response,
+ ::grpc::Status (MasterServiceStub::*pfunc)(
+ ::grpc::ClientContext*, const Request&, Response*)) {
+ ctx->set_fail_fast(false);
+ SetDeadline(ctx, call_options->GetTimeout());
+ return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response));
+ }
+
+ std::unique_ptr<MasterServiceStub> stub_;
};
MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) {
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 34b53e965f..80c8f3ad3d 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -34,8 +34,8 @@ void Worker::GetStatusAsync(const GetStatusRequest* request,
std::vector<DeviceAttributes> devices;
dm->ListDeviceAttributes(&devices);
response->mutable_device_attributes()->Reserve(devices.size());
- for (size_t i = 0; i < devices.size(); ++i) {
- response->add_device_attributes()->Swap(&devices[i]);
+ for (auto& d : devices) {
+ response->add_device_attributes()->Swap(&d);
}
done(Status::OK());
}
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index 8691450e9b..cdf5c3cf3b 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -27,7 +27,7 @@ class WorkerFreeListCache : public WorkerCacheInterface {
: wrapped_(std::move(w)) {}
~WorkerFreeListCache() final {
- for (auto p : workers_) {
+ for (auto& p : workers_) {
wrapped_->ReleaseWorker(p.first, p.second.worker);
}
}
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 81551ee26f..329e7f9531 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -112,7 +112,6 @@ struct FusedBatchNorm<CPUDevice, T> {
batch_var.device(d) = variance * rest_size_adjust;
saved_var.device(d) = variance;
} else {
- mean.device(d) = estimated_mean;
variance.device(d) = estimated_variance;
}
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index ba597f5c67..64bdef0008 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -122,7 +122,9 @@ void DoNonMaxSuppressionOp(OpKernelContext* context,
for (int i = 0; i < num_boxes; ++i) {
if (selected.size() >= output_size) break;
bool should_select = true;
- for (int j = 0; j < num_selected; ++j) {
+ // Overlapping boxes are likely to have similar scores,
+ // therefore we iterate through the selected boxes backwards.
+ for (int j = num_selected - 1; j >= 0; --j) {
if (IOUGreaterThanThreshold(boxes_data, sorted_indices[i],
sorted_indices[selected_indices[j]],
iou_threshold)) {
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
index 81eead11d1..d17b72bc26 100644
--- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
@@ -114,13 +114,11 @@ class SparseFillEmptyRowsOp : public OpKernel {
&scratch_t));
auto scratch = scratch_t.vec<int64>();
scratch.device(d) = scratch.constant(0);
- int64 prev_row = -1;
for (int i = 0; i < N; ++i) {
const int64 row = indices(i, 0);
- OP_REQUIRES(context, indices(i, 0) >= 0 && indices(i, 0) < dense_rows,
+ OP_REQUIRES(context, row >= 0 && row < dense_rows,
errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
- indices(i, 0), " >= ", dense_rows));
- prev_row = row;
+ row, " >= ", dense_rows));
++scratch(indices(i, 0));
}
for (int row = 0; row < dense_rows; ++row) {
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
index 36b8ff09d7..b14e555103 100644
--- a/tensorflow/core/kernels/variable_ops.cc
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -83,6 +83,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
IsVariableInitializedOp);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_bool(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h
index 21919adc2c..eb41f9ff36 100644
--- a/tensorflow/core/lib/core/refcount.h
+++ b/tensorflow/core/lib/core/refcount.h
@@ -87,7 +87,7 @@ inline bool RefCounted::Unref() const {
DCHECK_GT(ref_.load(), 0);
// If ref_==1, this object is owned only by the caller. Bypass a locked op
// in that case.
- if (ref_.load(std::memory_order_acquire) == 1 || ref_.fetch_sub(1) == 1) {
+ if (RefCountIsOne() || ref_.fetch_sub(1) == 1) {
// Make DCHECK in ~RefCounted happy
DCHECK((ref_.store(0), true));
delete this;
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 55262a8f18..651f22c6ea 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1006,7 +1006,8 @@ REGISTER_OP("Reverse")
.Input("dims: bool")
.Output("output: T")
.Attr(
- "T: {uint8, int8, int32, int64, bool, half, float, double, complex64, "
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, "
+ "double, complex64, "
"complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
@@ -1083,7 +1084,8 @@ REGISTER_OP("ReverseV2")
.Output("output: T")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr(
- "T: {uint8, int8, int32, int64, bool, half, float, double, complex64, "
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, "
+ "double, complex64, "
"complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
@@ -1518,8 +1520,8 @@ REGISTER_OP("GatherNd")
if (c->Value(r_dim) > c->Rank(params)) {
return errors::InvalidArgument(
"indices.shape[-1] must be <= params.rank, but saw indices shape: ",
- c->DebugString(indices),
- " and params shape: ", c->DebugString(params));
+ c->DebugString(indices), " and params shape: ",
+ c->DebugString(params));
}
// Remove r_dim from indices to get output.
@@ -2146,12 +2148,12 @@ REGISTER_OP("ReverseSequence")
// Validate batch_dim and seq_dim against input.
const int32 input_rank = c->Rank(input);
if (batch_dim >= input_rank) {
- return errors::InvalidArgument(
- "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
+ return errors::InvalidArgument("batch_dim must be < input rank: ",
+ batch_dim, " vs. ", input_rank);
}
if (seq_dim >= input_rank) {
- return errors::InvalidArgument(
- "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank);
+ return errors::InvalidArgument("seq_dim must be < input rank: ",
+ seq_dim, " vs. ", input_rank);
}
DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
@@ -5094,9 +5096,8 @@ Status ScatterNdShape(InferenceContext* c) {
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
- "The outer ", outer_dims,
- " dimensions of indices.shape=", c->DebugString(indices_shape),
- " must match the outer ", outer_dims,
+ "The outer ", outer_dims, " dimensions of indices.shape=",
+ c->DebugString(indices_shape), " must match the outer ", outer_dims,
" dimensions of updates.shape=", c->DebugString(updates_shape),
": ", s.error_message());
}
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index b0d4f51fe3..0baeac0984 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -164,6 +164,8 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
} else {
hdfs_->hdfsBuilderSetNameNode(builder, nn.c_str());
}
+ // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is the build in
+ // environment variable of Kerberos, so KERB_TICKET_CACHE_PATH and related code are unnecessary.
char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH");
if (ticket_cache_path != nullptr) {
hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path);
diff --git a/tensorflow/core/profiler/README.md b/tensorflow/core/profiler/README.md
index 40fb1f836e..5c50a86c88 100644
--- a/tensorflow/core/profiler/README.md
+++ b/tensorflow/core/profiler/README.md
@@ -54,7 +54,7 @@ with tf.contrib.tfprof.ProfileContext() as pctx:
train_loop()
```
-```python
+```shell
# Profiling from Python API is not interactive.
# Dump the profiles to files and profile with interactive command line.
with tf.contrib.tfprof.ProfileContext() as pctx:
@@ -137,7 +137,7 @@ ApplyAdam 231.65MB (85.28%, 0.31%), 92.66ms (23.43%,
### Auto-profile.
-```
+```shell
tfprof> advise
Not running under xxxx. Skip JobChecker.
@@ -194,8 +194,9 @@ seq2seq_attention_model.py:363:build_graph:self._add_train_o..., cpu: 1.28sec, a
optimizer.py:97:update_op:return optimizer...., cpu: 84.76ms, accelerator: 0us, total: 84.76ms
```
-### Visualize time and memory.
-```
+### Visualize time and memory
+
+```shell
# The following example generates a timeline.
tfprof> graph -step 0 -max_depth 100000 -output timeline:outfile=<filename>
@@ -206,11 +207,10 @@ Timeline file is written to <filename>.
Open a Chrome browser, enter URL chrome://tracing and load the timeline file.
******************************************************
```
-<left>
+
![Timeline](g3doc/graph_timeline.png)
-</left>
-```
+```shell
# The following example generates a pprof graph (only supported by code view).
# Since TensorFlow runs the graph instead of Python code, the pprof graph
# doesn't profile the statistics of Python, but the TensorFlow graph
@@ -226,9 +226,7 @@ tfprof> code -select accelerator_micros -max_depth 100000 -output pprof:outfile=
pprof -png --nodecount=100 --sample_index=1 <filename>
```
-<left>
![PprofGraph](g3doc/pprof.jpg)
-</left>
### Feature Request and Bug Report
diff --git a/tensorflow/docs_src/api_guides/python/reading_data.md b/tensorflow/docs_src/api_guides/python/reading_data.md
index ff8b4f1aa7..c0ddd82d73 100644
--- a/tensorflow/docs_src/api_guides/python/reading_data.md
+++ b/tensorflow/docs_src/api_guides/python/reading_data.md
@@ -57,7 +57,7 @@ A typical pipeline for reading records from files has the following stages:
7. *Optional* preprocessing
8. Example queue
-Note: This section discusses implementing input pipelines useing the
+Note: This section discusses implementing input pipelines using the
queue-based APIs which can be cleanly replaced by the ${$datasets$Dataset API}.
### Filenames, shuffling, and epoch limits
@@ -498,4 +498,4 @@ session.
Note: Regardless of the implementation, many
operations (like ${tf.layers.batch_normalization}, and @{tf.layers.dropout})
need to know if they are in training or evaluation mode, and you must be
-careful to set this apropriately if you change the data source.
+careful to set this appropriately if you change the data source.
diff --git a/tensorflow/docs_src/deploy/hadoop.md b/tensorflow/docs_src/deploy/hadoop.md
index c50c1580a5..7592cf828b 100644
--- a/tensorflow/docs_src/deploy/hadoop.md
+++ b/tensorflow/docs_src/deploy/hadoop.md
@@ -55,10 +55,10 @@ be set:
If the Hadoop cluster is in secure mode, the following environment variable must
be set:
-* **KERB_TICKET_CACHE_PATH**: The path of Kerberos ticket cache file. For example:
+* **KRB5CCNAME**: The path of Kerberos ticket cache file. For example:
```shell
- export KERB_TICKET_CACHE_PATH=/tmp/krb5cc_10002
+ export KRB5CCNAME=/tmp/krb5cc_10002
```
If you are running @{$distributed$Distributed TensorFlow}, then all
diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md
index 815b83e5fb..8eed9b5c5b 100644
--- a/tensorflow/docs_src/get_started/get_started.md
+++ b/tensorflow/docs_src/get_started/get_started.md
@@ -104,6 +104,7 @@ operations (Operations are also nodes). For example, we can add our two
constant nodes and produce a new graph as follows:
```python
+from __future__ import print_function
node3 = tf.add(node1, node2)
print("node3:", node3)
print("sess.run(node3):", sess.run(node3))
@@ -279,15 +280,14 @@ print(sess.run([W, b]))
```
results in the final model parameters:
```
-[array([-0.9999969], dtype=float32), array([ 0.99999082],
- dtype=float32)]
+[array([-0.9999969], dtype=float32), array([ 0.99999082], dtype=float32)]
```
-Now we have done actual machine learning! Although doing this simple linear
-regression doesn't require much TensorFlow core code, more complicated models
-and methods to feed data into your model necessitate more code. Thus TensorFlow
-provides higher level abstractions for common patterns, structures, and
-functionality. We will learn how to use some of these abstractions in the
+Now we have done actual machine learning! Although this simple linear
+regression model does not require much TensorFlow core code, more complicated
+models and methods to feed data into your models necessitate more code. Thus,
+TensorFlow provides higher level abstractions for common patterns, structures,
+and functionality. We will learn how to use some of these abstractions in the
next section.
### Complete program
@@ -354,9 +354,9 @@ Notice how much simpler the linear regression program becomes with
`tf.estimator`:
```python
-import tensorflow as tf
# NumPy is often used to load, manipulate and preprocess data.
import numpy as np
+import tensorflow as tf
# Declare list of features. We only have one numeric feature. There are many
# other types of columns that are more complicated and useful.
@@ -393,10 +393,10 @@ eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("train metrics: %r"% train_metrics)
print("eval metrics: %r"% eval_metrics)
```
-When run, it produces
+When run, it produces something like
```
-train metrics: {'loss': 1.2712867e-09, 'global_step': 1000}
-eval metrics: {'loss': 0.0025279333, 'global_step': 1000}
+train metrics: {'average_loss': 1.4833182e-08, 'global_step': 1000, 'loss': 5.9332727e-08}
+eval metrics: {'average_loss': 0.0025353201, 'global_step': 1000, 'loss': 0.01014128}
```
Notice how our eval data has a higher loss, but it is still close to zero.
That means we are learning properly.
diff --git a/tensorflow/docs_src/install/index.md b/tensorflow/docs_src/install/index.md
index 05a1af320f..3df16139fb 100644
--- a/tensorflow/docs_src/install/index.md
+++ b/tensorflow/docs_src/install/index.md
@@ -4,7 +4,7 @@ The following guides explain how to install a version of TensorFlow
that enables you to write applications in Python:
* @{$install_linux$Installing TensorFlow on Ubuntu}
- * @{$install_mac$Installing TensorFlow on Mac OS X}
+ * @{$install_mac$Installing TensorFlow on macOS}
* @{$install_windows$Installing TensorFlow on Windows}
* @{$install_sources$Installing TensorFlow from Sources}
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index a961c6b1ce..43e09906f7 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -33,7 +33,7 @@ must be installed on your system:
`LD_LIBRARY_PATH` environment variable as described in the
NVIDIA documentation.
* The NVIDIA drivers associated with CUDA Toolkit 8.0.
- * cuDNN v5.1. For details, see
+ * cuDNN v6.0. For details, see
[NVIDIA's documentation](https://developer.nvidia.com/cudnn).
Ensure that you create the `CUDA_HOME` environment variable as
described in the NVIDIA documentation.
@@ -231,7 +231,7 @@ Python is automatically installed on Ubuntu. Take a moment to confirm
versions is already installed on your system:
* Python 2.7
- * Python 3.3+
+ * Python 3.4+
The pip or pip3 package manager is *usually* installed on Ubuntu. Take a
moment to confirm (by issuing a `pip -V` or `pip3 -V` command)
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 6bae3c03d1..6552bff459 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -1,8 +1,8 @@
-# Installing TensorFlow on Mac OS X
+# Installing TensorFlow on macOS
-This guide explains how to install TensorFlow on Mac OS X.
+This guide explains how to install TensorFlow on macOS.
-Note: As of version 1.2, TensorFlow no longer provides GPU support on Mac OS X.
+Note: As of version 1.2, TensorFlow no longer provides GPU support on macOS.
## Determine how to install TensorFlow
@@ -15,7 +15,7 @@ You must pick the mechanism by which you install TensorFlow. The supported choic
[a separate guide](https://www.tensorflow.org/install/install_sources).
**We recommend the virtualenv installation.**
-[Virtualenv](https://virtualenv.pypa.io/en/stable/)
+[Virtualenv](https://virtualenv.pypa.io/en/stable)
is a virtual Python environment isolated from other Python development,
incapable of interfering with or being affected by other Python programs
on the same machine. During the virtualenv installation process,
@@ -33,7 +33,7 @@ to disable System Integrity Protection (SIP) in order to install through native
pip. However, if you understand SIP, pip, and your Python environment, a
native pip installation is relatively easy to perform.
-[Docker](http://docker.com/) completely isolates the TensorFlow installation
+[Docker](http://docker.com) completely isolates the TensorFlow installation
from pre-existing packages on your machine. The Docker container contains
TensorFlow and all its dependencies. Note that the Docker image can be quite
large (hundreds of MBs). You might choose the Docker installation if you are
@@ -58,7 +58,7 @@ Take the following steps to install TensorFlow with Virtualenv:
2. Install pip and virtualenv by issuing the following commands:
<pre> $ <b>sudo easy_install pip</b>
- $ <b>sudo pip install --upgrade virtualenv</b> </pre>
+ $ <b>pip install --upgrade virtualenv</b> </pre>
3. Create a virtualenv environment by issuing a command of one
of the following formats:
@@ -104,7 +104,7 @@ Take the following steps to install TensorFlow with Virtualenv:
Python version. Find the appropriate value for
<i>tfBinaryURL</i> for your system
[here](#the_url_of_the_tensorflow_python_package).
- For example, if you are installing TensorFlow for Mac OS X,
+ For example, if you are installing TensorFlow for macOS,
Python 2.7, the command to install
TensorFlow in the active Virtualenv is as follows:
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index a06ab88046..be6a490ff9 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -29,7 +29,7 @@ installed on your system:
Ensure that you append the relevant Cuda pathnames to the `%PATH%`
environment variable as described in the NVIDIA documentation.
* The NVIDIA drivers associated with CUDA Toolkit 8.0.
- * cuDNN v5.1. For details, see
+ * cuDNN v6.1. For details, see
[NVIDIA's documentation](https://developer.nvidia.com/cudnn).
Note that cuDNN is typically installed in a different location from the
other CUDA DLLs. Ensure that you add the directory where you installed
@@ -40,7 +40,7 @@ installed on your system:
If you have a different version of one of the preceding packages, please
change to the specified versions. In particular, the cuDNN version
-must match exactly: TensorFlow will not load if it cannot find `cuDNN64_5.dll`.
+must match exactly: TensorFlow will not load if it cannot find `cuDNN64_6.dll`.
To use a different version of cuDNN, you must build from source.
## Determine how to install TensorFlow
@@ -76,7 +76,6 @@ install it now:
* [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/)
-TensorFlow only supports version 3.5.x of Python on Windows.
Note that Python 3.5.x comes with the pip3 package manager, which is the
program you'll use to install TensorFlow.
@@ -115,12 +114,12 @@ Take the following steps to install TensorFlow in an Anaconda environment:
environment. To install the CPU-only version of TensorFlow, enter the
following command:
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.3.0-cp35-cp35m-win_amd64.whl</b> </pre>
+ <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade tensorflow</b> </pre>
To install the GPU version of TensorFlow, enter the following command
(on a single line):
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.3.0-cp35-cp35m-win_amd64.whl</b> </pre>
+ <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade tensorflow-gpu</b> </pre>
## Validate your installation
diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md
index 006245b44b..bf3cb5bf19 100644
--- a/tensorflow/docs_src/programmers_guide/datasets.md
+++ b/tensorflow/docs_src/programmers_guide/datasets.md
@@ -120,7 +120,7 @@ dataset3 = dataset3.filter(lambda x, (y, z): ...)
### Creating an iterator
-One you have built a `Dataset` to represent your input data, the next step is to
+Once you have built a `Dataset` to represent your input data, the next step is to
create an `Iterator` to access elements from that dataset. The `Dataset` API
currently supports the following iterators, in increasing level of
sophistication:
diff --git a/tensorflow/docs_src/programmers_guide/threading_and_queues.md b/tensorflow/docs_src/programmers_guide/threading_and_queues.md
index 313de178de..9d8a05c7dc 100644
--- a/tensorflow/docs_src/programmers_guide/threading_and_queues.md
+++ b/tensorflow/docs_src/programmers_guide/threading_and_queues.md
@@ -60,7 +60,7 @@ prepare inputs for training a model as follows:
We recommend using the @{tf.contrib.data.Dataset.shuffle$`shuffle`}
and @{tf.contrib.data.Dataset.batch$`batch`} methods of a
-@{tf.contrib.data.Dataset$`Dataset`} to acomplish this. However, if you'd prefer
+@{tf.contrib.data.Dataset$`Dataset`} to accomplish this. However, if you'd prefer
to use a queue-based version instead, you can find a full implementation in the
@{tf.train.shuffle_batch} function.
diff --git a/tensorflow/docs_src/programmers_guide/variables.md b/tensorflow/docs_src/programmers_guide/variables.md
index b265dbbe3e..f310b89380 100644
--- a/tensorflow/docs_src/programmers_guide/variables.md
+++ b/tensorflow/docs_src/programmers_guide/variables.md
@@ -266,7 +266,7 @@ calling this function repeatedly would not work:
``` python
input1 = tf.random_normal([1,10,10,32])
input2 = tf.random_normal([1,20,20,32])
-x = conv_relu(input1, kernel_shape=[5, 5, 1, 32], bias_shape=[32])
+x = conv_relu(input1, kernel_shape=[5, 5, 32, 32], bias_shape=[32])
x = conv_relu(x, kernel_shape=[5, 5, 32, 32], bias_shape = [32]) # This fails.
```
@@ -278,7 +278,7 @@ however, clarifies that we want to create new variables:
def my_image_filter(input_images):
with tf.variable_scope("conv1"):
# Variables created here will be named "conv1/weights", "conv1/biases".
- relu1 = conv_relu(input_images, [5, 5, 1, 32], [32])
+ relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
with tf.variable_scope("conv2"):
# Variables created here will be named "conv2/weights", "conv2/biases".
return conv_relu(relu1, [5, 5, 32, 32], [32])
diff --git a/tensorflow/docs_src/tutorials/deep_cnn.md b/tensorflow/docs_src/tutorials/deep_cnn.md
index 591b8ea6aa..b57ef24f58 100644
--- a/tensorflow/docs_src/tutorials/deep_cnn.md
+++ b/tensorflow/docs_src/tutorials/deep_cnn.md
@@ -11,8 +11,8 @@ problem is to classify RGB 32x32 pixel images across 10 categories:
airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
```
-For more details refer to the [CIFAR-10 page](http://www.cs.toronto.edu/~kriz/cifar.html)
-and a [Tech Report](http://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
+For more details refer to the [CIFAR-10 page](https://www.cs.toronto.edu/~kriz/cifar.html)
+and a [Tech Report](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
by Alex Krizhevsky.
### Goals
@@ -42,7 +42,7 @@ designing larger and more sophisticated models in TensorFlow:
([wiki](https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer))
and @{tf.nn.local_response_normalization$local response normalization}
(Chapter 3.3 in
-[AlexNet paper](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)).
+[AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)).
* @{$summaries_and_tensorboard$Visualization}
of network activities during training, including input images,
losses and distributions of activations and gradients.
diff --git a/tensorflow/docs_src/tutorials/image_recognition.md b/tensorflow/docs_src/tutorials/image_recognition.md
index 88ae451cd5..ddb771700a 100644
--- a/tensorflow/docs_src/tutorials/image_recognition.md
+++ b/tensorflow/docs_src/tutorials/image_recognition.md
@@ -8,7 +8,7 @@ seem easy because our brains are incredibly good at understanding images.
In the last few years the field of machine learning has made tremendous
progress on addressing these difficult problems. In particular, we've
found that a kind of model called a deep
-[convolutional neural network](http://colah.github.io/posts/2014-07-Conv-Nets-Modular/)
+[convolutional neural network](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/)
can achieve reasonable performance on hard visual recognition tasks --
matching or exceeding human performance in some domains.
@@ -23,11 +23,11 @@ these models but the results are still hard to reproduce.
We're now taking the next step by releasing code for running image recognition
on our latest model, [Inception-v3].
-[QuocNet]: http://static.googleusercontent.com/media/research.google.com/en//archive/unsupervised_icml2012.pdf
-[AlexNet]: http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
-[Inception (GoogLeNet)]: http://arxiv.org/abs/1409.4842
-[BN-Inception-v2]: http://arxiv.org/abs/1502.03167
-[Inception-v3]: http://arxiv.org/abs/1512.00567
+[QuocNet]: https://static.googleusercontent.com/media/research.google.com/en//archive/unsupervised_icml2012.pdf
+[AlexNet]: https://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
+[Inception (GoogLeNet)]: https://arxiv.org/abs/1409.4842
+[BN-Inception-v2]: https://arxiv.org/abs/1502.03167
+[Inception-v3]: https://arxiv.org/abs/1512.00567
Inception-v3 is trained for the [ImageNet] Large Visual Recognition Challenge
using the data from 2012. This is a standard task in computer vision,
@@ -51,7 +51,7 @@ Andrej Karpathy who attempted to measure his own performance. He reached
[ImageNet]: http://image-net.org/
[1000 classes]: http://image-net.org/challenges/LSVRC/2014/browse-synsets
-[blog post]: http://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/
+[blog post]: https://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/
This tutorial will teach you how to use [Inception-v3]. You'll learn how to
classify images into [1000 classes] in Python or C++. We'll also discuss how to
@@ -433,7 +433,7 @@ TensorFlow within your own products.
should be able to transfer some of that understanding to solving related
problems. One way to perform transfer learning is to remove the final
classification layer of the network and extract
-the [next-to-last layer of the CNN](http://arxiv.org/abs/1310.1531), in this case a 2048 dimensional vector.
+the [next-to-last layer of the CNN](https://arxiv.org/abs/1310.1531), in this case a 2048 dimensional vector.
There's a guide to doing this @{$image_retraining$in the how-to section}.
@@ -443,7 +443,7 @@ To learn about neural networks in general, Michael Nielsen's
[free online book](http://neuralnetworksanddeeplearning.com/chap1.html)
is an excellent resource. For convolutional neural networks in particular,
Chris Olah has some
-[nice blog posts](http://colah.github.io/posts/2014-07-Conv-Nets-Modular/),
+[nice blog posts](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),
and Michael Nielsen's book has a
[great chapter](http://neuralnetworksanddeeplearning.com/chap6.html)
covering them.
diff --git a/tensorflow/docs_src/tutorials/image_retraining.md b/tensorflow/docs_src/tutorials/image_retraining.md
index b0e715edcb..90652ac405 100644
--- a/tensorflow/docs_src/tutorials/image_retraining.md
+++ b/tensorflow/docs_src/tutorials/image_retraining.md
@@ -6,7 +6,7 @@ work by taking a fully-trained model for a set of categories like ImageNet, and
retrains from the existing weights for new classes. In this example we'll be
retraining the final layer from scratch, while leaving all the others untouched.
For more information on the approach you can see
-[this paper on Decaf](http://arxiv.org/pdf/1310.1531v1.pdf).
+[this paper on Decaf](https://arxiv.org/pdf/1310.1531v1.pdf).
Though it's not as good as a full training run, this is surprisingly effective
for many applications, and can be run in as little as thirty minutes on a
@@ -213,7 +213,7 @@ the object you actually care about. To avoid this, try to take pictures in as
wide a variety of situations as you can, at different times, and with different
devices. If you want to know more about this problem, you can read about the
classic (and possibly apocryphal)
-[tank recognition problem](http://www.jefftk.com/p/detecting-tanks).
+[tank recognition problem](https://www.jefftk.com/p/detecting-tanks).
You may also want to think about the categories you use. It might be worth
splitting big categories that cover a lot of different physical forms into
diff --git a/tensorflow/docs_src/tutorials/kernel_methods.md b/tensorflow/docs_src/tutorials/kernel_methods.md
index 8506b5228e..324c34fdfa 100644
--- a/tensorflow/docs_src/tutorials/kernel_methods.md
+++ b/tensorflow/docs_src/tutorials/kernel_methods.md
@@ -14,7 +14,7 @@ Machines (SVMs). If you are new to kernel methods, refer to either of the
following sources for an introduction:
* If you have a strong mathematical background:
-[Kernel Methods in Machine Learning](http://www.kernel-machines.org/publications/pdfs/0701907.pdf)
+[Kernel Methods in Machine Learning](https://arxiv.org/pdf/math/0701907.pdf)
* [Kernel method wikipedia page](https://en.wikipedia.org/wiki/Kernel_method)
Currently, TensorFlow supports explicit kernel mappings for dense features only;
diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md
index acf33afe6d..0815cc2a17 100644
--- a/tensorflow/docs_src/tutorials/layers.md
+++ b/tensorflow/docs_src/tutorials/layers.md
@@ -79,7 +79,7 @@ relative measurements of how likely it is that the image falls into each target
class.
> Note: For a more comprehensive walkthrough of CNN architecture, see Stanford
-> University's <a href="http://cs231n.github.io/convolutional-networks/">
+> University's <a href="https://cs231n.github.io/convolutional-networks/">
> Convolutional Neural Networks for Visual Recognition course materials</a>.</p>
## Building the CNN MNIST Classifier {#building_the_cnn_mnist_classifier}
diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md
index 346b6be06c..73d40575d7 100644
--- a/tensorflow/docs_src/tutorials/recurrent.md
+++ b/tensorflow/docs_src/tutorials/recurrent.md
@@ -2,7 +2,7 @@
## Introduction
-Take a look at [this great article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
+Take a look at [this great article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for an introduction to recurrent neural networks and LSTMs in particular.
## Language Modeling
@@ -17,11 +17,11 @@ models, whilst being small and relatively fast to train.
Language modeling is key to many interesting problems such as speech
recognition, machine translation, or image captioning. It is also fun --
-take a look [here](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).
+take a look [here](https://karpathy.github.io/2015/05/21/rnn-effectiveness/).
For the purpose of this tutorial, we will reproduce the results from
-[Zaremba et al., 2014](http://arxiv.org/abs/1409.2329)
-([pdf](http://arxiv.org/pdf/1409.2329.pdf)), which achieves very good quality
+[Zaremba et al., 2014](https://arxiv.org/abs/1409.2329)
+([pdf](https://arxiv.org/pdf/1409.2329.pdf)), which achieves very good quality
on the PTB dataset.
## Tutorial Files
diff --git a/tensorflow/docs_src/tutorials/seq2seq.md b/tensorflow/docs_src/tutorials/seq2seq.md
index 84c8a9c9f3..0dcb23d4aa 100644
--- a/tensorflow/docs_src/tutorials/seq2seq.md
+++ b/tensorflow/docs_src/tutorials/seq2seq.md
@@ -37,8 +37,8 @@ File | What's in it?
## Sequence-to-sequence basics
A basic sequence-to-sequence model, as introduced in
-[Cho et al., 2014](http://arxiv.org/abs/1406.1078)
-([pdf](http://arxiv.org/pdf/1406.1078.pdf)), consists of two recurrent neural
+[Cho et al., 2014](https://arxiv.org/abs/1406.1078)
+([pdf](https://arxiv.org/pdf/1406.1078.pdf)), consists of two recurrent neural
networks (RNNs): an *encoder* that processes the input and a *decoder* that
generates the output. This basic architecture is depicted below.
@@ -51,14 +51,14 @@ a GRU cell or an LSTM cell (see the @{$recurrent$RNN Tutorial}
for an explanation of those). Encoder and decoder can share weights or,
as is more common, use a different set of parameters. Multi-layer cells
have been successfully used in sequence-to-sequence models too, e.g. for
-translation [Sutskever et al., 2014](http://arxiv.org/abs/1409.3215)
-([pdf](http://arxiv.org/pdf/1409.3215.pdf)).
+translation [Sutskever et al., 2014](https://arxiv.org/abs/1409.3215)
+([pdf](https://arxiv.org/pdf/1409.3215.pdf)).
In the basic model depicted above, every input has to be encoded into
a fixed-size state vector, as that is the only thing passed to the decoder.
To allow the decoder more direct access to the input, an *attention* mechanism
-was introduced in [Bahdanau et al., 2014](http://arxiv.org/abs/1409.0473)
-([pdf](http://arxiv.org/pdf/1409.0473.pdf)).
+was introduced in [Bahdanau et al., 2014](https://arxiv.org/abs/1409.0473)
+([pdf](https://arxiv.org/pdf/1409.0473.pdf)).
We will not go into the details of the attention mechanism (see the paper);
suffice it to say that it allows the decoder to peek into the input at every
decoding step. A multi-layer sequence-to-sequence network with LSTM cells and
@@ -129,8 +129,8 @@ All other tensors from this list would be ignored, and instead the previous
output of the decoder would be used. This is used for decoding translations
in our translation model, but it can also be used during training, to make
the model more robust to its own mistakes, similar
-to [Bengio et al., 2015](http://arxiv.org/abs/1506.03099)
-([pdf](http://arxiv.org/pdf/1506.03099.pdf)).
+to [Bengio et al., 2015](https://arxiv.org/abs/1506.03099)
+([pdf](https://arxiv.org/pdf/1506.03099.pdf)).
One more important argument used above is `output_projection`. If not specified,
the outputs of the embedding model will be tensors of shape batch-size by
@@ -140,8 +140,8 @@ When training models with large output vocabularies, i.e., when
tensors. Instead, it is better to return smaller output tensors, which will
later be projected onto a large output tensor using `output_projection`.
This allows to use our seq2seq models with a sampled softmax loss, as described
-in [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
-([pdf](http://arxiv.org/pdf/1412.2007.pdf)).
+in [Jean et al., 2014](https://arxiv.org/abs/1412.2007)
+([pdf](https://arxiv.org/pdf/1412.2007.pdf)).
In addition to `basic_rnn_seq2seq` and `embedding_rnn_seq2seq` there are a few
more sequence-to-sequence models in `seq2seq.py`; take a look there. They all
@@ -243,8 +243,8 @@ Remember that when constructing decoder inputs we prepend the special `GO`
symbol to the input data. This is done in the `get_batch()` function in
`seq2seq_model.py`, which also reverses the input English sentence.
Reversing the inputs was shown to improve results for the neural translation
-model in [Sutskever et al., 2014](http://arxiv.org/abs/1409.3215)
-([pdf](http://arxiv.org/pdf/1409.3215.pdf)).
+model in [Sutskever et al., 2014](https://arxiv.org/abs/1409.3215)
+([pdf](https://arxiv.org/pdf/1409.3215.pdf)).
To put it all together, imagine we have the sentence "I go.", tokenized
as `["I", "go", "."]` as input and the sentence "Je vais." as output,
tokenized `["Je", "vais", "."]`. It will be put in the (5, 10) bucket,
@@ -348,7 +348,7 @@ Finally, the model presented above can be used for any sequence-to-sequence
task, not only for translation. Even if you want to transform a sequence to
a tree, for example to generate a parsing tree, the same model as above can
give state-of-the-art results, as demonstrated in
-[Vinyals & Kaiser et al., 2014](http://arxiv.org/abs/1412.7449)
-([pdf](http://arxiv.org/pdf/1412.7449.pdf)).
+[Vinyals & Kaiser et al., 2014](https://arxiv.org/abs/1412.7449)
+([pdf](https://arxiv.org/pdf/1412.7449.pdf)).
So you can not only build your own translator, you can also build a parser,
a chat-bot, or any program that comes to your mind. Experiment!
diff --git a/tensorflow/docs_src/tutorials/using_gpu.md b/tensorflow/docs_src/tutorials/using_gpu.md
index b6edbe3345..de8d88ce76 100644
--- a/tensorflow/docs_src/tutorials/using_gpu.md
+++ b/tensorflow/docs_src/tutorials/using_gpu.md
@@ -83,7 +83,7 @@ MatMul: /job:localhost/replica:0/task:0/device:GPU:0
## Allowing GPU memory growth
By default, TensorFlow maps nearly all of the GPU memory of all GPUs (subject to
-[`CUDA_VISIBLE_DEVICES`](http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars))
+[`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars))
visible to the process. This is done to more efficiently use the relatively
precious GPU memory resources on the devices by reducing [memory
fragmentation](https://en.wikipedia.org/wiki/Fragmentation_\(computing\)).
diff --git a/tensorflow/docs_src/tutorials/wide.md b/tensorflow/docs_src/tutorials/wide.md
index 3571a55a2e..3055c54021 100644
--- a/tensorflow/docs_src/tutorials/wide.md
+++ b/tensorflow/docs_src/tutorials/wide.md
@@ -33,7 +33,7 @@ To try the code for this tutorial:
$ pip install -U pandas
If you have trouble installing pandas, consult the
- [instructions](http://pandas.pydata.org/pandas-docs/stable/install.html)
+ [instructions](https://pandas.pydata.org/pandas-docs/stable/install.html)
on the pandas site.
4. Execute the tutorial code with the following command to train the linear
@@ -62,7 +62,7 @@ urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/ad
```
Once the CSV files are downloaded, let's read them into
-[Pandas](http://pandas.pydata.org/) dataframes.
+[Pandas](https://pandas.pydata.org/) dataframes.
```python
import pandas as pd
diff --git a/tensorflow/docs_src/tutorials/wide_and_deep.md b/tensorflow/docs_src/tutorials/wide_and_deep.md
index e6344405d5..16f7925e8d 100644
--- a/tensorflow/docs_src/tutorials/wide_and_deep.md
+++ b/tensorflow/docs_src/tutorials/wide_and_deep.md
@@ -15,7 +15,7 @@ combines the strengths of memorization and generalization. It's useful for
generic large-scale regression and classification problems with sparse input
features (e.g., categorical features with a large number of possible feature
values). If you're interested in learning more about how Wide & Deep Learning
-works, please check out our [research paper](http://arxiv.org/abs/1606.07792).
+works, please check out our [research paper](https://arxiv.org/abs/1606.07792).
![Wide & Deep Spectrum of Models](https://www.tensorflow.org/images/wide_n_deep.svg "Wide & Deep")
@@ -58,7 +58,7 @@ To try the code for this tutorial:
$ sudo pip install pandas
If you have trouble installing pandas, consult the
- [instructions](http://pandas.pydata.org/pandas-docs/stable/install.html)
+ [instructions](https://pandas.pydata.org/pandas-docs/stable/install.html)
on the pandas site.
4. Execute the tutorial code with the following command to train the linear
@@ -320,5 +320,5 @@ Note that this tutorial is just a quick example on a small dataset to get you
familiar with the API. Wide & Deep Learning will be even more powerful if you
try it on a large dataset with many sparse feature columns that have a large
number of possible feature values. Again, feel free to take a look at our
-[research paper](http://arxiv.org/abs/1606.07792) for more ideas about how to
+[research paper](https://arxiv.org/abs/1606.07792) for more ideas about how to
apply Wide & Deep Learning in real-world large-scale machine learning problems.
diff --git a/tensorflow/docs_src/tutorials/word2vec.md b/tensorflow/docs_src/tutorials/word2vec.md
index 8e7c19035e..0a1c41c84a 100644
--- a/tensorflow/docs_src/tutorials/word2vec.md
+++ b/tensorflow/docs_src/tutorials/word2vec.md
@@ -1,7 +1,7 @@
# Vector Representations of Words
In this tutorial we look at the word2vec model by
-[Mikolov et al.](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)
+[Mikolov et al.](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)
This model is used for learning vector representations of words, called "word
embeddings".
@@ -78,7 +78,7 @@ model).
Word2vec is a particularly computationally-efficient predictive model for
learning word embeddings from raw text. It comes in two flavors, the Continuous
-Bag-of-Words model (CBOW) and the Skip-Gram model (Section 3.1 and 3.2 in [Mikolov et al.](http://arxiv.org/pdf/1301.3781.pdf)). Algorithmically, these
+Bag-of-Words model (CBOW) and the Skip-Gram model (Section 3.1 and 3.2 in [Mikolov et al.](https://arxiv.org/pdf/1301.3781.pdf)). Algorithmically, these
models are similar, except that CBOW predicts target words (e.g. 'mat') from
source context words ('the cat sits on the'), while the skip-gram does the
inverse and predicts source context-words from the target words. This inversion
@@ -155,14 +155,14 @@ from the noise distribution (i.e. we compute a
This objective is maximized when the model assigns high probabilities
to the real words, and low probabilities to noise words. Technically, this is
called
-[Negative Sampling](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf),
+[Negative Sampling](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf),
and there is good mathematical motivation for using this loss function:
The updates it proposes approximate the updates of the softmax function in the
limit. But computationally it is especially appealing because computing the
loss function now scales only with the number of *noise words* that we
select (\\(k\\)), and not *all words* in the vocabulary (\\(V\\)). This makes it
much faster to train. We will actually make use of the very similar
-[noise-contrastive estimation (NCE)](http://papers.nips.cc/paper/5165-learning-word-embeddings-efficiently-with-noise-contrastive-estimation.pdf)
+[noise-contrastive estimation (NCE)](https://papers.nips.cc/paper/5165-learning-word-embeddings-efficiently-with-noise-contrastive-estimation.pdf)
loss, for which TensorFlow has a handy helper function `tf.nn.nce_loss()`.
Let's get an intuitive feel for how this would work in practice!
@@ -222,7 +222,7 @@ successful at discriminating real words from noise words.
We can visualize the learned vectors by projecting them down to 2 dimensions
using for instance something like the
-[t-SNE dimensionality reduction technique](http://lvdmaaten.github.io/tsne/).
+[t-SNE dimensionality reduction technique](https://lvdmaaten.github.io/tsne/).
When we inspect these visualizations it becomes apparent that the vectors
capture some general, and in fact quite useful, semantic information about
words and their relationships to one another. It was very interesting when we
@@ -230,7 +230,7 @@ first discovered that certain directions in the induced vector space specialize
towards certain semantic relationships, e.g. *male-female*, *verb tense* and
even *country-capital* relationships between words, as illustrated in the figure
below (see also for example
-[Mikolov et al., 2013](http://www.aclweb.org/anthology/N13-1090)).
+[Mikolov et al., 2013](https://www.aclweb.org/anthology/N13-1090)).
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/linear-relationships.png" alt>
@@ -239,9 +239,9 @@ below (see also for example
This explains why these vectors are also useful as features for many canonical
NLP prediction tasks, such as part-of-speech tagging or named entity recognition
(see for example the original work by
-[Collobert et al., 2011](http://arxiv.org/abs/1103.0398)
-([pdf](http://arxiv.org/pdf/1103.0398.pdf)), or follow-up work by
-[Turian et al., 2010](http://www.aclweb.org/anthology/P10-1040)).
+[Collobert et al., 2011](https://arxiv.org/abs/1103.0398)
+([pdf](https://arxiv.org/pdf/1103.0398.pdf)), or follow-up work by
+[Turian et al., 2010](https://www.aclweb.org/anthology/P10-1040)).
But for now, let's just use them to draw pretty pictures!
@@ -351,7 +351,7 @@ to evaluate embeddings is to directly use them to predict syntactic and semantic
relationships like `king is to queen as father is to ?`. This is called
*analogical reasoning* and the task was introduced by
[Mikolov and colleagues
-](http://www.anthology.aclweb.org/N/N13/N13-1090.pdf).
+](https://www.aclweb.org/anthology/N13-1090).
Download the dataset for this task from
[download.tensorflow.org](http://download.tensorflow.org/data/questions-words.txt).
diff --git a/tensorflow/examples/android/build.gradle b/tensorflow/examples/android/build.gradle
index e97faad516..48f566f825 100644
--- a/tensorflow/examples/android/build.gradle
+++ b/tensorflow/examples/android/build.gradle
@@ -79,7 +79,7 @@ android {
if (nativeBuildSystem == 'cmake') {
defaultConfig {
- applicationId = 'com.tensorflow.demo'
+ applicationId = 'org.tensorflow.demo'
minSdkVersion 21
targetSdkVersion 23
ndk {
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 6c93617ae5..d73b1c6373 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -78,10 +78,8 @@ def build_dataset(words, n_words):
data = list()
unk_count = 0
for word in words:
- if word in dictionary:
- index = dictionary[word]
- else:
- index = 0 # dictionary['UNK']
+ index = dictionary.get(word, 0)
+ if index == 0: # dictionary['UNK']
unk_count += 1
data.append(index)
count[0][1] = unk_count
@@ -110,14 +108,13 @@ def generate_batch(batch_size, num_skips, skip_window):
buffer.extend(data[data_index:data_index + span])
data_index += span
for i in range(batch_size // num_skips):
- target = skip_window # target label at the center of the buffer
- targets_to_avoid = [skip_window]
+ context_words = [w for w in range(span) if w != skip_window]
+ random.shuffle(context_words)
+ words_to_use = collections.deque(context_words)
for j in range(num_skips):
- while target in targets_to_avoid:
- target = random.randint(0, span - 1)
- targets_to_avoid.append(target)
batch[i * num_skips + j] = buffer[skip_window]
- labels[i * num_skips + j, 0] = buffer[target]
+ context_word = words_to_use.pop()
+ labels[i * num_skips + j, 0] = buffer[context_word]
if data_index == len(data):
buffer[:] = data[:span]
data_index = span
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index c08fa9b145..58ad3ab193 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -15,6 +15,8 @@ limitations under the License.
package org.tensorflow;
+import java.util.Iterator;
+
/**
* A data flow graph representing a TensorFlow computation.
*
@@ -77,6 +79,16 @@ public final class Graph implements AutoCloseable {
}
/**
+ * Iterator over all the {@link Operation}s in the graph.
+ *
+ * The order of iteration is unspecified. Consumers of the iterator will received no notification
+ * should the underlying graph change during iteration.
+ */
+ public Iterator<Operation> operations() {
+ return new OperationIterator(this);
+ }
+
+ /**
* Returns a builder to add {@link Operation}s to the Graph.
*
* @param type of the Operation (i.e., identifies the computation to be performed)
@@ -179,12 +191,63 @@ public final class Graph implements AutoCloseable {
return new Reference();
}
+ private static final class OperationIterator implements Iterator<Operation> {
+
+ OperationIterator(Graph g) {
+ this.graph = g;
+ this.operation = null;
+ this.position = 0;
+ this.advance();
+ }
+
+ private final void advance() {
+ Graph.Reference reference = this.graph.ref();
+
+ this.operation = null;
+
+ try {
+ long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);
+
+ if ((nativeReturn != null) && (nativeReturn[0] != 0)) {
+ this.operation = new Operation(this.graph, nativeReturn[0]);
+ this.position = (int) nativeReturn[1];
+ }
+ } finally {
+ reference.close();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (this.operation != null);
+ }
+
+ @Override
+ public Operation next() {
+ Operation rhett = this.operation;
+ this.advance();
+ return rhett;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException("remove() is unsupported.");
+ }
+
+ private final Graph graph;
+ private Operation operation;
+ private int position;
+ }
+
private static native long allocate();
private static native void delete(long handle);
private static native long operation(long handle, String name);
+ // This method returns the Operation native handle at index 0 and the new value for pos at index 1 (see TF_GraphNextOperation)
+ private static native long[] nextOperation(long handle, int position);
+
private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
throws IllegalArgumentException;
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 8e9187b437..0fef155275 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -54,6 +54,26 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
return reinterpret_cast<jlong>(op);
}
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint position) {
+ TF_Graph* g = requireHandle(env, handle);
+ if (g == nullptr) return nullptr;
+
+ size_t pos = static_cast<size_t>(position);
+ TF_Operation* operation = TF_GraphNextOperation(g, &pos);
+ if (operation == nullptr) return nullptr;
+
+ jlong handle_and_position[2];
+ handle_and_position[0] = reinterpret_cast<jlong>(operation);
+ handle_and_position[1] = static_cast<jlong>(pos);
+
+ jlongArray rhett = env->NewLongArray(2);
+ env->SetLongArrayRegion(rhett, 0, 2, handle_and_position);
+ return rhett;
+}
+
JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
jstring prefix) {
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index b84c11578e..dd2e038332 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -47,6 +47,15 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv *, jclass,
/*
* Class: org_tensorflow_Graph
+ * Method: operations
+ * Signature: (JI)[J
+ */
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(JNIEnv *,
+ jclass, jlong,
+ jint);
+
+/*
+ * Class: org_tensorflow_Graph
* Method: importGraphDef
* Signature: (J[BLjava/lang/String;)V
*/
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index f6dc3ee1e9..4adc861bf1 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -16,7 +16,12 @@ limitations under the License.
package org.tensorflow;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.HashSet;
+import java.util.Iterator;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -71,6 +76,34 @@ public class GraphTest {
}
@Test
+ public void iterateOverOperations() {
+ try (Graph g = new Graph()) {
+ Iterator<Operation> iterator = g.operations();
+ HashSet<Operation> operations;
+
+ assertFalse(iterator.hasNext());
+
+ operations = new HashSet<>();
+ operations.add(TestUtil.constant(g, "Const-A", Float.valueOf(1.0f)).op());
+ operations.add(TestUtil.constant(g, "Const-B", Integer.valueOf(23)).op());
+ operations.add(TestUtil.constant(g, "Const-C", Double.valueOf(1.618)).op());
+
+ iterator = g.operations();
+
+ assertTrue(iterator.hasNext());
+ assertTrue(operations.remove(iterator.next()));
+
+ assertTrue(iterator.hasNext());
+ assertTrue(operations.remove(iterator.next()));
+
+ assertTrue(iterator.hasNext());
+ assertTrue(operations.remove(iterator.next()));
+
+ assertFalse(iterator.hasNext());
+ }
+ }
+
+ @Test
public void failImportOnInvalidGraphDefs() {
try (Graph g = new Graph()) {
try {
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 149425436a..d7fe4bbfa1 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -47,12 +47,13 @@ except ImportError:
def _fill_array(arr, seq, fillvalue=0):
- """Recursively fills padded arr with elements from seq.
-
+ """
+ Recursively fills padded arr with elements from seq.
If lenght of seq is less then arr padded length, fillvalue used.
+
Args:
arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
- seq: Non-padded list of data sampels of shape
+ seq: Non-padded list of data sampels of shape
[batch_size, ..., padded_dim(None)]
fillvalue: Default fillvalue to use.
"""
@@ -83,23 +84,21 @@ def _pad_if_needed(batch_key_item, fillvalue=0):
Raises:
ValueError if data samples have different shapes (except last padded dim).
"""
- shapes = [
- seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item
- ]
+ shapes = [seq.shape[:-1] if len(seq.shape) > 0 else -1
+ for seq in batch_key_item]
if not all(shapes[0] == x for x in shapes):
raise ValueError("Array shapes must match.")
- last_length = [
- seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item
- ]
+ last_length = [seq.shape[-1] if len(seq.shape) > 0 else 0
+ for seq in batch_key_item]
if all([x == last_length[0] for x in last_length]):
return batch_key_item
batch_size = len(batch_key_item)
max_sequence_length = max(last_length)
result_batch = np.zeros(
- shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
- dtype=batch_key_item[0].dtype)
+ shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
+ dtype=batch_key_item[0].dtype)
_fill_array(result_batch, batch_key_item, fillvalue)
return result_batch
@@ -326,15 +325,11 @@ class _GeneratorFeedFn(object):
list_dict_size += 1
if self._pad_value is not None:
- feed_dict = {
- key: np.asarray(_pad_if_needed(item, self._pad_value))
- for key, item in list(list_dict.items())
- }
+ feed_dict = {key: np.asarray(_pad_if_needed(item, self._pad_value))
+ for key, item in list(list_dict.items())}
else:
- feed_dict = {
- key: np.asarray(item)
- for key, item in list(list_dict.items())
- }
+ feed_dict = {key: np.asarray(item)
+ for key, item in list(list_dict.items())}
return feed_dict
@@ -380,7 +375,7 @@ def _enqueue_data(data,
arrays, a numpy `ndarray`, or a generator producing these.
NotImplementedError: padding and shuffling data at the same time.
NotImplementedError: padding usage with non generator data type.
- """
+ """
with ops.name_scope(name):
if isinstance(data, np.ndarray):
types = [dtypes.int64, dtypes.as_dtype(data.dtype)]
@@ -452,11 +447,11 @@ def _enqueue_data(data,
seed=seed)
elif pad_data:
min_after_dequeue = 0 # just for the summary text
- queue_shapes = list(
- map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
- queue_shapes))
+ queue_shapes = list(map(
+ lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
+ queue_shapes))
queue = data_flow_ops.PaddingFIFOQueue(
- capacity, dtypes=types, shapes=queue_shapes)
+ capacity, dtypes=types, shapes=queue_shapes)
else:
min_after_dequeue = 0 # just for the summary text
queue = data_flow_ops.FIFOQueue(
@@ -475,23 +470,23 @@ def _enqueue_data(data,
if not pad_data:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs))
else:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs,
- pad_value=pad_value))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs,
+ pad_value=pad_value))
runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access
queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns)
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py b/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py
index 3508f5aa3c..30abd82130 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions_test.py
@@ -291,8 +291,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
self.assertEqual(expected, vals_to_list(actual))
def testFillArraySmall(self):
- a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
- shape=[32, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[32, 36], dtype=np.int32).tolist())
actual = np.ones(shape=[64, 36], dtype=np.int32)
ff._fill_array(actual, a)
expected = np.ones(shape=[64, 36], dtype=np.int32)
@@ -300,8 +300,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
self.assertEqual(expected.tolist(), actual.tolist())
def testFillArrayLarge(self):
- a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
- shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
actual = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
ff._fill_array(actual, a)
expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
@@ -310,8 +310,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
def testFillArraySmallWithSpecifiedValue(self):
fill_value = 8
- a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
- shape=[32, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[32, 36], dtype=np.int32).tolist())
actual = np.ones(shape=[64, 36], dtype=np.int32)
ff._fill_array(actual, a, fill_value)
expected = np.ones(shape=[64, 36], dtype=np.int32)
@@ -320,8 +320,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
def testFillArrayLargeWithSpecifiedValue(self):
fill_value = 8
- a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
- shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
actual = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
ff._fill_array(actual, a, fill_value)
expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
@@ -329,8 +329,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
self.assertEqual(expected.tolist(), actual.tolist())
def testPadIfNeededSmall(self):
- a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
- shape=[32, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[32, 36], dtype=np.int32).tolist())
a = list(map(np.array, a))
actual = ff._pad_if_needed(a)
expected = np.ones(shape=[64, 36], dtype=np.int32)
@@ -338,8 +338,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
self.assertEqual(expected.tolist(), actual.tolist())
def testPadIfNeededLarge(self):
- a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
- shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
a = list(map(np.array, a))
actual = ff._pad_if_needed(a)
expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
@@ -348,8 +348,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
def testPadIfNeededSmallWithSpecifiedValue(self):
fill_value = 8
- a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() + np.ones(
- shape=[32, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[32, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[32, 36], dtype=np.int32).tolist())
a = list(map(np.array, a))
actual = ff._pad_if_needed(a, fill_value)
expected = np.ones(shape=[64, 36], dtype=np.int32)
@@ -358,8 +358,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
def testPadIfNeededLargeWithSpecifiedValue(self):
fill_value = 8
- a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() + np.ones(
- shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +
+ np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())
a = list(map(np.array, a))
actual = ff._pad_if_needed(a, fill_value)
expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)
@@ -368,8 +368,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
def testPadIfNeededSmallWithSpecifiedNonNumericValue(self):
fill_value = False
- a = (np.ones(shape=[32, 32], dtype=np.bool).tolist() + np.ones(
- shape=[32, 36], dtype=np.bool).tolist())
+ a = (np.ones(shape=[32, 32], dtype=np.bool).tolist() +
+ np.ones(shape=[32, 36], dtype=np.bool).tolist())
a = list(map(np.array, a))
actual = ff._pad_if_needed(a, fill_value)
expected = np.ones(shape=[64, 36], dtype=np.bool)
@@ -378,8 +378,8 @@ class _FeedingFunctionsTestCase(test.TestCase):
def testPadIfNeededLargeWithSpecifiedNonNumericValue(self):
fill_value = False
- a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.bool).tolist() + np.ones(
- shape=[8, 8, 8, 8, 36], dtype=np.bool).tolist())
+ a = (np.ones(shape=[8, 8, 8, 8, 32], dtype=np.bool).tolist() +
+ np.ones(shape=[8, 8, 8, 8, 36], dtype=np.bool).tolist())
a = list(map(np.array, a))
actual = ff._pad_if_needed(a, fill_value)
expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.bool)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 2cfbb63de8..42e63d8b81 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -285,7 +285,8 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testReverse1DimAuto(self):
for dtype in [
- np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32,
+ np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
+ np.bool, np.float16, np.float32,
np.float64, np.complex64, np.complex128,
np.array(b"").dtype.type
]:
@@ -293,7 +294,8 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testReverse2DimAuto(self):
for dtype in [
- np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32,
+ np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
+ np.bool, np.float16, np.float32,
np.float64, np.complex64, np.complex128,
np.array(b"").dtype.type
]:
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 95cd27ebad..f95d9833d8 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -70,7 +70,6 @@ def _sparsify(x, thresh=0.5, index_dtype=np.int64):
return sparse_tensor.SparseTensor(
indices=x_indices, values=x_values, dense_shape=x_shape), x_values
-
def _default_tolerance(dtype):
"""Returns a sensible default tolerance for comparing results of a given
type"""
@@ -81,7 +80,7 @@ def _default_tolerance(dtype):
elif dtype in (np.float64, np.complex128):
return 1e-5
else:
- return None # Fail fast for unexpected types
+ return None # Fail fast for unexpected types
class UnaryOpTest(test.TestCase):
@@ -1957,7 +1956,7 @@ class ComplexMakeRealImagTest(test.TestCase):
with self.test_session(use_gpu=use_gpu) as sess:
inx = ops.convert_to_tensor(cplx)
tf_angle = math_ops.angle(inx)
- tf_angle_val = tf_angle.eval()
+ tf_angle_val = sess.run([tf_angle])
self.assertAllEqual(np_angle, tf_angle_val)
self.assertShapeEqual(np_angle, tf_angle)
@@ -2116,22 +2115,21 @@ class AccumulateTest(test.TestCase):
with self.assertRaises(ValueError):
a = variables.Variable(0.2)
b = variables.Variable(0.1)
- tf_val = math_ops.accumulate_n(
- [a, b], shape=[2, 2]) # Should be shape=[]
+ tf_val = math_ops.accumulate_n([a,b], shape=[2,2]) # Should be shape=[]
def testWrongType(self):
with self.test_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
b = variables.Variable(0.1, dtype=np.float32)
- tf_val = math_ops.accumulate_n([a, b], tensor_dtype=np.int32)
+ tf_val = math_ops.accumulate_n([a,b], tensor_dtype=np.int32)
def testWrongTypeOneInput(self):
# Scenario that used to trigger a bug, even when testWrongType() worked
with self.test_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
- tf_val = math_ops.accumulate_n([a], tensor_dtype=np.int32)
+ tf_val = math_ops.accumulate_n([a], tensor_dtype=np.int32)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/denormal_test.py b/tensorflow/python/kernel_tests/denormal_test.py
index 2d48cd4163..95fc40f883 100644
--- a/tensorflow/python/kernel_tests/denormal_test.py
+++ b/tensorflow/python/kernel_tests/denormal_test.py
@@ -35,8 +35,8 @@ class DenormalTest(test.TestCase):
self.assertEqual(tiny, tiny / 16 * 16)
def _flushDenormalsTest(self, use_gpu, dtypes):
- if platform.machine() == "ppc64le":
- # Disabled denormal_test on power platform
+ if platform.machine() == "ppc64le" or platform.machine() == "s390x":
+ # Disabled denormal_test on power/s390x platform
# Check relevant discussion - https://github.com/tensorflow/tensorflow/issues/11902
return
with self.test_session(use_gpu=use_gpu):
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index f3731bad38..a5bd1b6ee0 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -132,8 +132,9 @@ class StringSplitOpTest(test.TestCase):
with self.test_session() as sess:
tokens = string_ops.string_split(strings, "#", skip_empty=False)
indices, values, shape = sess.run(tokens)
- self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0],
- [2, 1], [2, 2]])
+ self.assertAllEqual(indices, [[0, 0], [0, 1],
+ [1, 0], [1, 1],
+ [2, 0], [2, 1], [2, 2]])
self.assertAllEqual(values, [b"", b"a", b"b", b"", b"", b"c", b""])
self.assertAllEqual(shape, [3, 3])
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 0bfde675f2..39609255b1 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -82,6 +82,7 @@ from __future__ import division
from __future__ import print_function
import sys
+
import numpy as np
from tensorflow.python.eager import context
@@ -100,7 +101,6 @@ from tensorflow.python.ops import gen_math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.util import deprecation
-from tensorflow.python.util.deprecation import deprecated
# pylint: enable=wildcard-import
# Used for slicing to specify a new 1 size dimension
@@ -192,8 +192,10 @@ def expand_dims(input, axis=None, name=None, dim=None):
# Aliases for some automatically-generated names.
# pylint: disable=protected-access
-@deprecated("2016-11-30", "This op will be removed after the deprecation date. "
- "Please switch to tf.setdiff1d().")
+@deprecation.deprecated(
+ "2016-11-30",
+ "This op will be removed after the deprecation date. "
+ "Please switch to tf.setdiff1d().")
def listdiff(x, y, out_idx=None, name=None):
return gen_array_ops._list_diff(x, y, out_idx, name)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 5374817118..89de88a530 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2671,8 +2671,8 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
call to `while_loop`, and not at all during `Session.run()`). `while_loop`
stitches together the graph fragments created during the `cond` and `body`
- calls with some additional graph nodes to make something the repeats
- `body` until `cond` returns false.
+ calls with some additional graph nodes to create the graph flow that
+ repeats `body` until `cond` returns false.
For correctness, `tf.while_loop()` strictly enforces shape invariants for
the loop variables. A shape invariant is a (possibly partial) shape that
@@ -2708,11 +2708,11 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
memory consumption and execution order. For correct programs, `while_loop`
should return the same result for any parallel_iterations > 0.
- For training, TensorFlow remembers the tensors that are produced in the
- forward inference but needed in back propagation. These tensors can be a
- main source of memory consumption and often cause OOM problems when training
- on GPUs. When the flag swap_memory is true, we swap out these tensors from
- GPU to CPU. This for example allows us to train RNN models with very long
+ For training, TensorFlow stores the tensors that are produced in the
+ forward inference and are needed in back propagation. These tensors are a
+ main source of memory consumption and often cause OOM errors when training
+ on GPUs. When the flag swap_memory is true, we swap out these tensors from
+ GPU to CPU. This for example allows us to train RNN models with very long
sequences and large batches.
Args:
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 975494e45f..5cd5d7ba2f 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -170,7 +170,7 @@ class ReaderBase(object):
return self._reader_ref
def read(self, queue, name=None):
- """Returns the next record (key, value pair) produced by a reader.
+ """Returns the next record (key, value) pair produced by a reader.
Will dequeue a work unit from queue if necessary (e.g. when the
Reader needs to start reading from a new file since it has
@@ -200,7 +200,7 @@ class ReaderBase(object):
def read_up_to(self, queue, num_records, # pylint: disable=invalid-name
name=None):
- """Returns up to num_records (key, value pairs) produced by a reader.
+ """Returns up to num_records (key, value) pairs produced by a reader.
Will dequeue a work unit from queue if necessary (e.g., when the
Reader needs to start reading from a new file since it has
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index ea8e2afa53..fd8a5daa33 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2060,7 +2060,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
tensor_dtype = inputs[0].dtype
if tensor_dtype != inputs[0].dtype:
raise TypeError("tensor_dtype is {}, but input is of type {}"
- .format(tensor_dtype, inputs[0].dtype))
+ .format(tensor_dtype, inputs[0].dtype))
if len(inputs) == 1:
return inputs[0]
with ops.name_scope(name, "AccumulateN", inputs) as name:
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index b36cfeb2eb..f1103f209c 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -328,9 +328,8 @@ class AddNTest(test_util.TensorFlowTestCase):
addn = math_ops.add_n(input_vars)
sess.run(variables.global_variables_initializer())
add_n_grad = gradients.gradients(addn, input_vars)
- self.assertAllEqual(
- np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
- [g.eval() for g in add_n_grad])
+ self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
+ [g.eval() for g in add_n_grad])
class DivAndModTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index e8e0a1dfc3..61a27911a3 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -91,5 +91,5 @@ pip2 install grpcio
pip3 install grpcio
# Eager execution needs autograd:
-pip2 install --upgrade autograd
-pip3 install --upgrade autograd
+pip2 install --upgrade autograd>=1.1.12
+pip3 install --upgrade autograd>=1.1.12
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index f58c1f7fe5..e525e11397 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -71,8 +71,7 @@ def check_all_files():
def replace_with_sed(query, filename):
"""Replace with sed when regex is required."""
- subprocess.check_call("sed -i -r -e \"%s\" \"%s\"" % (query, filename),
- shell=True)
+ subprocess.check_call(['sed', '-i', '-r', '-e', query, filename])
class Version(object):
@@ -277,9 +276,8 @@ def check_for_lingering_string(lingering_string):
"""Check for given lingering strings."""
formatted_string = lingering_string.replace(".", r"\.")
try:
- linger_strs = subprocess.check_output("grep -rnoH \"%s\" \"%s\""
- % (formatted_string, TF_SRC_DIR),
- shell=True).split("\n")
+ linger_strs = subprocess.check_output(
+ ['grep', '-rnoH', formatted_string, TF_SRC_DIR]).split("\n")
except subprocess.CalledProcessError:
linger_strs = []
diff --git a/tensorflow/tools/pip_package/check_load_py_test.py b/tensorflow/tools/pip_package/check_load_py_test.py
index 7a132a8de3..79d11b08ce 100644
--- a/tensorflow/tools/pip_package/check_load_py_test.py
+++ b/tensorflow/tools/pip_package/check_load_py_test.py
@@ -42,11 +42,11 @@ def main():
# Get all py_test target, note bazel query result will also include
# cuda_py_test etc.
try:
- targets = subprocess.check_output(
- 'bazel query "kind(py_test, //tensorflow/contrib/... + '
+ targets = subprocess.check_output([
+ 'bazel', 'query',
+ 'kind(py_test, //tensorflow/contrib/... + '
'//tensorflow/python/... - '
- '//tensorflow/contrib/tensorboard/...)"',
- shell=True).strip()
+ '//tensorflow/contrib/tensorboard/...)']).strip()
except subprocess.CalledProcessError as e:
targets = e.output
@@ -68,9 +68,8 @@ def main():
files_missing_load = []
for build_file in build_files:
updated_build_file = subprocess.check_output(
- 'buildozer -stdout "new_load //tensorflow:tensorflow.bzl py_test" ' +
- build_file,
- shell=True)
+ ['buildozer', '-stdout', 'new_load //tensorflow:tensorflow.bzl py_test',
+ build_file])
with open(build_file, 'r') as f:
if f.read() != updated_build_file:
files_missing_load.append(build_file)
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 9cdd6410e0..cc46dd5162 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -25,16 +25,16 @@ from __future__ import print_function
import subprocess
-PIP_PACKAGE_QUERY = """bazel query \
- 'deps(//tensorflow/tools/pip_package:build_pip_package)'"""
+PIP_PACKAGE_QUERY_EXPRESSION = \
+ 'deps(//tensorflow/tools/pip_package:build_pip_package)'
-PY_TEST_QUERY = """bazel query 'deps(\
+PY_TEST_QUERY_EXPRESSION = 'deps(\
filter("^((?!benchmark).)*$",\
kind(py_test,\
//tensorflow/python/... \
+ //tensorflow/contrib/... \
- //tensorflow/contrib/tensorboard/... \
- - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'"""
+ - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'
# Hard-coded blacklist of files if not included in pip package
# TODO(amitpatankar): Clean up blacklist.
@@ -83,15 +83,15 @@ def main():
"""
# pip_package_dependencies_list is the list of included files in pip packages
- pip_package_dependencies = subprocess.check_output(
- PIP_PACKAGE_QUERY, shell=True)
+ pip_package_dependencies = subprocess.check_output([
+ 'bazel', 'query', PIP_PACKAGE_QUERY_EXPRESSION])
pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
print("Pip package superset size: %d" % len(pip_package_dependencies_list))
# tf_py_test_dependencies is the list of dependencies for all python
# tests in tensorflow
- tf_py_test_dependencies = subprocess.check_output(
- PY_TEST_QUERY, shell=True)
+ tf_py_test_dependencies = subprocess.check_output([
+ 'bazel', 'query', PY_TEST_QUERY_EXPRESSION])
tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
@@ -124,9 +124,9 @@ def main():
for missing_dependency in missing_dependencies:
print("\nMissing dependency: %s " % missing_dependency)
print("Affected Tests:")
- rdep_query = """bazel query 'rdeps(kind(py_test, \
- //tensorflow/python/...), %s)'""" % missing_dependency
- affected_tests = subprocess.check_output(rdep_query, shell=True)
+ rdep_query = 'rdeps(kind(py_test, \
+ //tensorflow/python/...), %s)' % missing_dependency
+ affected_tests = subprocess.check_output(['bazel', 'query', rdep_query])
affected_tests_list = affected_tests.split("\n")[:-2]
print("\n".join(affected_tests_list))
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 1c45e9e5f6..b009859e41 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -224,10 +224,18 @@ setup(
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
+ 'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
- 'Topic :: Software Development :: Libraries :: Python Modules',
- 'Topic :: Software Development :: Libraries',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Software Development',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords='tensorflow tensor machine learning',)
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index 3a60c8c958..6607f629e7 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -34,7 +34,7 @@ cc_binary(
visibility = ["//tensorflow:internal"],
deps = [
":gen_proto_text_functions_lib",
- "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:lib",
],
)