aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ISSUE_TEMPLATE.md24
-rwxr-xr-x[-rw-r--r--]tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt11
-rw-r--r--tensorflow/contrib/cmake/tf_c.cmake28
-rw-r--r--tensorflow/contrib/cmake/tf_cc_ops.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake4
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake39
-rw-r--r--tensorflow/contrib/cmake/tf_shared_lib.cmake87
-rw-r--r--tensorflow/contrib/cmake/tf_tools.cmake9
-rw-r--r--tensorflow/contrib/cmake/tools/create_def_file.py78
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bernoulli.py13
-rwxr-xr-x[-rw-r--r--]tensorflow/contrib/image/BUILD28
-rwxr-xr-x[-rw-r--r--]tensorflow/contrib/image/__init__.py2
-rwxr-xr-xtensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc424
-rwxr-xr-xtensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc93
-rwxr-xr-xtensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py125
-rw-r--r--tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_functions.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py80
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/run_config.py4
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/self_adjoint_eig_v2_op.cc10
-rw-r--r--tensorflow/core/ops/array_ops.cc3
-rw-r--r--tensorflow/core/ops/linalg_ops.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt53
-rw-r--r--tensorflow/docs_src/extend/adding_an_op.md8
-rw-r--r--tensorflow/docs_src/install/install_sources.md3
-rw-r--r--tensorflow/docs_src/performance/quantization.md4
-rw-r--r--tensorflow/examples/learn/iris_custom_decay_dnn.py2
-rw-r--r--tensorflow/examples/learn/text_classification_character_cnn.py5
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_with_summaries.py4
-rw-r--r--tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py9
-rw-r--r--tensorflow/python/layers/normalization.py17
-rw-r--r--tensorflow/python/ops/nn_grad.py10
-rw-r--r--tensorflow/python/ops/nn_ops.py12
-rw-r--r--tensorflow/python/summary/writer/event_file_writer.py22
-rw-r--r--tensorflow/python/summary/writer/writer_test.py9
-rw-r--r--tensorflow/tools/graph_transforms/README.md4
39 files changed, 1122 insertions, 120 deletions
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
index 50f67963bf..d0979e87f7 100644
--- a/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -1,14 +1,32 @@
-NOTE: Issues that are not bugs or feature requests will be closed. Please ask usage questions on StackOverflow.
+Please go to Stack Overflow for help and support. http://stackoverflow.com/questions/tagged/tensorflow
+If you open a GitHub issue, here is our policy:
-### You must complete this information or else your issue will be closed
+1. It must be a bug or feature request.
+2. The form below must be filled out.
+
+**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g. fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
+
+------------------------
+
+Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request.
+
+### System Information
- *Have I written custom code (as opposed to using a stock example script provided in TensorFlow)?*:
+- *OS Platform and Distribution (i.e. Linux Ubuntu 16.0)*:
- *TensorFlow installed from (source or binary)?*:
-- *TensorFlow version*:
+- *TensorFlow version* (use command below):
- *Bazel version (if compiling from source)*:
- *CUDA/cuDNN version*:
- *GPU Model and Memory*:
- *Exact command to reproduce*:
+You can collect some of this information using our environment capture script https://github.com/tensorflow/tensorflow/blob/master/tools/
+You can collect the TensorFlow version with
+```sh
+python -c "import tensorflow as tf; print (tf.GIT_VERSION, tf.VERSION)"
+```
+
+
### Describe the problem clearly
### Source Code / Logs
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 17997844df..3ff76c3054 100644..100755
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -28,6 +28,7 @@ py_library(
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
"//tensorflow/contrib/image:image_py",
+ "//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
"//tensorflow/contrib/imperative",
"//tensorflow/contrib/input_pipeline:input_pipeline_py",
"//tensorflow/contrib/integrate:integrate_py",
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 31a3d45a98..af7b4fb386 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -29,6 +29,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
+option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
@@ -198,7 +199,7 @@ if (tensorflow_ENABLE_GPU)
# add cudnn
include_directories(${CUDNN_HOME})
set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES}
- ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)
+ ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)
# create cuda_config.h
FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h
@@ -219,6 +220,7 @@ if (tensorflow_ENABLE_GPU)
${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_HOME}/include/cudnn.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h
+ ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h
DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include
)
include_directories(${tensorflow_source_dir}/third_party/gpus)
@@ -244,7 +246,9 @@ include(tf_core_kernels.cmake)
if(tensorflow_ENABLE_GRPC_SUPPORT)
include(tf_core_distributed_runtime.cmake)
endif()
+# We include tf_cc_ops first, because tf_c depends on tf_cc.
include(tf_cc_ops.cmake)
+include(tf_c.cmake)
if(tensorflow_BUILD_CC_EXAMPLE)
include(tf_tutorials.cmake)
include(tf_label_image_example.cmake)
@@ -254,6 +258,9 @@ if(tensorflow_BUILD_PYTHON_BINDINGS)
include(tensorboard)
include(tf_python.cmake)
endif()
-if (tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS)
+if(tensorflow_BUILD_SHARED_LIB)
+ include(tf_shared_lib.cmake)
+endif()
+if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS)
include(tf_tests.cmake)
endif()
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
new file mode 100644
index 0000000000..069cdfa352
--- /dev/null
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+########################################################
+# tf_c_framework library
+########################################################
+set(tf_c_srcs
+ "${tensorflow_source_dir}/tensorflow/c/c_api.cc"
+ "${tensorflow_source_dir}/tensorflow/c/c_api.h"
+ "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
+ "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
+ "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
+ "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.h"
+)
+
+add_library(tf_c OBJECT ${tf_c_srcs})
+add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc)
diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake
index bb03593d54..b53f428461 100644
--- a/tensorflow/contrib/cmake/tf_cc_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake
@@ -19,6 +19,7 @@ set(tf_cc_framework_srcs
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.h"
+ "${tensorflow_source_dir}/tensorflow/cc/framework/scope_internal.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
)
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index c13c28910f..1b2096b645 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -116,6 +116,10 @@ if(WIN32)
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc"
+ # temporarily disable nccl (nccl itself needs to be ported to windows first)
+ "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc"
)
list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_windows_exclude_srcs})
endif(WIN32)
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 35629ba3b3..983f976c95 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -686,19 +686,7 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_writer.cc"
"${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.h"
"${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.cc"
- "${tensorflow_source_dir}/tensorflow/c/c_api.cc"
- "${tensorflow_source_dir}/tensorflow/c/c_api.h"
- "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
- "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
- "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
- "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.h"
- "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h"
- "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc"
- "${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.h"
- "${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.cc"
- "${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
- "${tensorflow_source_dir}/tensorflow/cc/framework/scope_internal.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
"${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.cc"
)
@@ -715,9 +703,11 @@ if(WIN32)
#
add_library(pywrap_tensorflow_internal_static STATIC
${pywrap_tensorflow_internal_src}
+ $<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
+ $<TARGET_OBJECTS:tf_cc>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
@@ -727,33 +717,43 @@ if(WIN32)
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
+
target_include_directories(pywrap_tensorflow_internal_static PUBLIC
${PYTHON_INCLUDE_DIR}
${NUMPY_INCLUDE_DIR}
)
- target_link_libraries(pywrap_tensorflow_internal_static
- tf_protos_cc
- tf_python_protos_cc
+ #target_link_libraries(pywrap_tensorflow_internal_static
+ # tf_protos_cc
+ # tf_python_protos_cc
+ #)
+ add_dependencies(pywrap_tensorflow_internal_static tf_protos_cc tf_python_protos_cc)
+ set(pywrap_tensorflow_internal_static_dependencies
+ $<TARGET_FILE:pywrap_tensorflow_internal_static>
+ $<TARGET_FILE:tf_protos_cc>
+ $<TARGET_FILE:tf_python_protos_cc>
)
+
set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow.def")
set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE)
add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py
- --input $<TARGET_FILE:pywrap_tensorflow_internal_static>
- --output ${pywrap_tensorflow_deffile}
+ --input "${pywrap_tensorflow_internal_static_dependencies}"
+ --output "${pywrap_tensorflow_deffile}"
+ --target _pywrap_tensorflow_internal.pyd
)
endif(WIN32)
-
# pywrap_tensorflow_internal is a shared library containing all of the
# TensorFlow runtime and the standard ops and kernels. These are installed into
# tf_python/tensorflow/python/.
add_library(pywrap_tensorflow_internal SHARED
${pywrap_tensorflow_internal_src}
+ $<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
+ $<TARGET_OBJECTS:tf_cc>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
@@ -773,7 +773,8 @@ target_include_directories(pywrap_tensorflow_internal PUBLIC
${PYTHON_INCLUDE_DIR}
${NUMPY_INCLUDE_DIR}
)
-target_link_libraries(pywrap_tensorflow_internal
+
+target_link_libraries(pywrap_tensorflow_internal PRIVATE
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc
diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake
new file mode 100644
index 0000000000..47289fd9d2
--- /dev/null
+++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+if(WIN32)
+ # Windows: build a static library with the same objects as tensorflow.dll.
+ # This can be used to build for a standalone exe and also helps us to
+ # find all symbols that need to be exported from the dll which is needed
+ # to provide the tensorflow c/c++ api in tensorflow.dll.
+ # From the static library we create the def file with all symbols that need to
+ # be exported from tensorflow.dll. Because there is a limit of 64K sybmols
+ # that can be exported, we filter the symbols with a python script to the namespaces
+ # we need.
+ #
+ add_library(tensorflow_static STATIC
+ $<TARGET_OBJECTS:tf_c>
+ $<TARGET_OBJECTS:tf_cc>
+ $<TARGET_OBJECTS:tf_cc_framework>
+ $<TARGET_OBJECTS:tf_cc_ops>
+ $<TARGET_OBJECTS:tf_core_lib>
+ $<TARGET_OBJECTS:tf_core_cpu>
+ $<TARGET_OBJECTS:tf_core_framework>
+ $<TARGET_OBJECTS:tf_core_ops>
+ $<TARGET_OBJECTS:tf_core_direct_session>
+ $<TARGET_OBJECTS:tf_tools_transform_graph_lib>
+ $<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
+ $<TARGET_OBJECTS:tf_core_kernels>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
+ )
+
+ add_dependencies(tensorflow_static tf_protos_cc)
+ set(tensorflow_static_dependencies
+ $<TARGET_FILE:tensorflow_static>
+ $<TARGET_FILE:tf_protos_cc>
+ )
+
+ set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def")
+ set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE)
+
+ add_custom_command(TARGET tensorflow_static POST_BUILD
+ COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py
+ --input "${tensorflow_static_dependencies}"
+ --output "${tensorflow_deffile}"
+ --target tensorflow.dll
+ )
+endif(WIN32)
+
+# tensorflow is a shared library containing all of the
+# TensorFlow runtime and the standard ops and kernels.
+add_library(tensorflow SHARED
+ $<TARGET_OBJECTS:tf_c>
+ $<TARGET_OBJECTS:tf_cc>
+ $<TARGET_OBJECTS:tf_cc_framework>
+ $<TARGET_OBJECTS:tf_cc_ops>
+ $<TARGET_OBJECTS:tf_core_lib>
+ $<TARGET_OBJECTS:tf_core_cpu>
+ $<TARGET_OBJECTS:tf_core_framework>
+ $<TARGET_OBJECTS:tf_core_ops>
+ $<TARGET_OBJECTS:tf_core_direct_session>
+ $<TARGET_OBJECTS:tf_tools_transform_graph_lib>
+ $<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
+ $<TARGET_OBJECTS:tf_core_kernels>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
+ ${tensorflow_deffile}
+)
+
+target_link_libraries(tensorflow PRIVATE
+ ${tf_core_gpu_kernels_lib}
+ ${tensorflow_EXTERNAL_LIBRARIES}
+ tf_protos_cc
+)
+
+if(WIN32)
+ add_dependencies(tensorflow tensorflow_static)
+endif(WIN32)
diff --git a/tensorflow/contrib/cmake/tf_tools.cmake b/tensorflow/contrib/cmake/tf_tools.cmake
index c518414f0c..6ef9598963 100644
--- a/tensorflow/contrib/cmake/tf_tools.cmake
+++ b/tensorflow/contrib/cmake/tf_tools.cmake
@@ -73,10 +73,13 @@ add_executable(${transform_graph}
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<TARGET_OBJECTS:tf_core_kernels>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_link_libraries(${transform_graph} PUBLIC
tf_protos_cc
+ ${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
)
@@ -92,10 +95,13 @@ add_executable(${summarize_graph}
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<TARGET_OBJECTS:tf_core_kernels>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_link_libraries(${summarize_graph} PUBLIC
tf_protos_cc
+ ${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
)
@@ -111,10 +117,13 @@ add_executable(${compare_graphs}
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<TARGET_OBJECTS:tf_core_kernels>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_link_libraries(${compare_graphs} PUBLIC
tf_protos_cc
+ ${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
)
diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py
index ebdc918168..9bd287d0d7 100644
--- a/tensorflow/contrib/cmake/tools/create_def_file.py
+++ b/tensorflow/contrib/cmake/tools/create_def_file.py
@@ -47,8 +47,16 @@ DUMPBIN = "dumpbin.exe"
EXCLUDE_RE = re.compile(r"deleting destructor|::internal::")
# Include if matched before exclude
-INCLUDEPRE_RE = re.compile(r"tensorflow::internal::LogMessage|"
- r"tensorflow::internal::CheckOpMessageBuilder")
+INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|"
+ r"tensorflow::internal::LogMessage|"
+ r"tensorflow::internal::LogString|"
+ r"tensorflow::internal::CheckOpMessageBuilder|"
+ r"tensorflow::internal::PickUnusedPortOrDie|"
+ r"tensorflow::internal::ValidateDevice|"
+ r"tensorflow::ops::internal::Enter|"
+ r"tensorflow::strings::internal::AppendPieces|"
+ r"tensorflow::strings::internal::CatPieces|"
+ r"tensorflow::io::internal::JoinPathImpl")
# Include if matched after exclude
INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
@@ -56,12 +64,27 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
r"functor::|"
r"perftools::gputools")
-
+# We want to identify data members explicitly in the DEF file, so that no one
+# can implicitly link against the DLL if they use one of the variables exported
+# from the DLL and the header they use does not decorate the symbol with
+# __declspec(dllimport). It is easier to detect what a data symbol does
+# NOT look like, so doing it with the below regex.
+DATA_EXCLUDE_RE = re.compile(r"[)(]|"
+ r"vftable|"
+ r"vbtable|"
+ r"vcall|"
+ r"RTTI|"
+ r"protobuf::internal::ExplicitlyConstructed")
+
def get_args():
"""Parse command line."""
+ filename_list = lambda x: x.split(";")
parser = argparse.ArgumentParser()
- parser.add_argument("--input", help="input library", required=True)
+ parser.add_argument("--input", type=filename_list,
+ help="paths to input libraries separated by semicolons",
+ required=True)
parser.add_argument("--output", help="output deffile", required=True)
+ parser.add_argument("--target", help="name of the target", required=True)
args = parser.parse_args()
return args
@@ -70,25 +93,26 @@ def main():
"""main."""
args = get_args()
- # Pipe dumpbin to extract all linkable symbols from a lib.
+ # Pipe dumpbin to extract all linkable symbols from libs.
# Good symbols are collected in candidates and also written to
# a temp file.
candidates = []
tmpfile = tempfile.NamedTemporaryFile(mode="w", delete=False)
- proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", args.input],
- stdout=subprocess.PIPE)
- for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
- cols = line.split()
- if len(cols) < 2:
- continue
- sym = cols[1]
- tmpfile.file.write(sym + "\n")
- candidates.append(sym)
+ for lib_path in args.input:
+ proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", lib_path],
+ stdout=subprocess.PIPE)
+ for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
+ cols = line.split()
+ if len(cols) < 2:
+ continue
+ sym = cols[1]
+ tmpfile.file.write(sym + "\n")
+ candidates.append(sym)
+ exit_code = proc.wait()
+ if exit_code != 0:
+ print("{} failed, exit={}".format(DUMPBIN, exit_code))
+ return exit_code
tmpfile.file.close()
- exit_code = proc.wait()
- if exit_code != 0:
- print("{} failed, exit={}".format(DUMPBIN, exit_code))
- return exit_code
# Run the symbols through undname to get their undecorated name
# so we can filter on something readable.
@@ -96,9 +120,8 @@ def main():
# track dupes
taken = set()
- # Header for the def file. Since the tensorflow.dll is actually called
- # _pywrap_tensorflow.pyd in the python wheel, hint that in the def file.
- def_fp.write("LIBRARY _pywrap_tensorflow_internal.pyd\n")
+ # Header for the def file.
+ def_fp.write("LIBRARY " + args.target + "\n")
def_fp.write("EXPORTS\n")
def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n")
@@ -118,8 +141,17 @@ def main():
continue
if not INCLUDE_RE.search(line):
continue
-
- def_fp.write("\t" + decorated + "\n")
+
+ if "deleting destructor" in line:
+ # Some of the symbols convered by INCLUDEPRE_RE export deleting
+ # destructor symbols, which is a bad idea.
+ # So we filter out such symbols here.
+ continue
+
+ if DATA_EXCLUDE_RE.search(line):
+ def_fp.write("\t" + decorated + "\n")
+ else:
+ def_fp.write("\t" + decorated + " DATA\n")
taken.add(decorated)
exit_code = proc.wait()
if exit_code != 0:
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
index 6ba872ef9c..87b2331a1d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py
@@ -148,6 +148,15 @@ class BernoulliTest(test.TestCase):
p: [0.2, 0.3, 0.4]
}), [[0.2, 0.7, 0.4]])
+ def testPmfInvalid(self):
+ p = [0.1, 0.2, 0.7]
+ with self.test_session():
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ with self.assertRaisesOpError("must be non-negative."):
+ dist.prob([1, 1, -1]).eval()
+ with self.assertRaisesOpError("is not less than or equal to 1."):
+ dist.prob([2, 0, 1]).eval()
+
def testPmfWithP(self):
p = [[0.2, 0.4], [0.3, 0.6]]
self._testPmf(probs=p)
diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py
index 33e6dbd78b..c491cb5d42 100644
--- a/tensorflow/contrib/distributions/python/ops/bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -120,6 +121,7 @@ class Bernoulli(distribution.Distribution):
return math_ops.cast(sample, self.dtype)
def _log_prob(self, event):
+ event = self._maybe_assert_valid_sample(event)
# TODO(jaana): The current sigmoid_cross_entropy_with_logits has
# inconsistent behavior for logits = inf/-inf.
event = math_ops.cast(event, self.logits.dtype)
@@ -160,6 +162,17 @@ class Bernoulli(distribution.Distribution):
"""Returns `1` if `prob > 0.5` and `0` otherwise."""
return math_ops.cast(self.probs > 0.5, self.dtype)
+ def _maybe_assert_valid_sample(self, event, check_integer=True):
+ if not self.validate_args:
+ return event
+ event = distribution_util.embed_check_nonnegative_discrete(
+ event, check_integer=check_integer)
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_less_equal(
+ event, array_ops.ones_like(event),
+ message="event is not less than or equal to 1."),
+ ], event)
+
class BernoulliWithSigmoidProbs(Bernoulli):
"""Bernoulli with `probs = nn.sigmoid(logits)`."""
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index c31ac7b324..7599406aca 100644..100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -87,6 +87,7 @@ cuda_py_test(
srcs = ["python/kernel_tests/image_ops_test.py"],
additional_deps = [
":image_py",
+ ":single_image_random_dot_stereograms_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -96,6 +97,33 @@ cuda_py_test(
],
)
+tf_custom_op_library(
+ name = "python/ops/_single_image_random_dot_stereograms.so",
+ srcs = [
+ "kernels/single_image_random_dot_stereograms_ops.cc",
+ "ops/single_image_random_dot_stereograms_ops.cc",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["single_image_random_dot_stereograms_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "single_image_random_dot_stereograms_ops",
+ deps = [":single_image_random_dot_stereograms_ops_op_lib"],
+)
+
+py_library(
+ name = "single_image_random_dot_stereograms_py",
+ srcs = glob(["python/ops/single*.py"]) + ["__init__.py"],
+ data = [":python/ops/_single_image_random_dot_stereograms.so"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":single_image_random_dot_stereograms_ops",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py
index 4ad599b1f8..aa70d42339 100644..100755
--- a/tensorflow/contrib/image/__init__.py
+++ b/tensorflow/contrib/image/__init__.py
@@ -25,6 +25,7 @@ projective transforms (including rotation) are supported.
@@compose_transforms
@@rotate
@@transform
+@@single_image_random_dot_stereograms
"""
from __future__ import absolute_import
from __future__ import division
@@ -35,6 +36,7 @@ from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_t
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
from tensorflow.contrib.image.python.ops.image_ops import rotate
from tensorflow.contrib.image.python.ops.image_ops import transform
+from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc
new file mode 100755
index 0000000000..23efd359d5
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc
@@ -0,0 +1,424 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+template <typename T>
+class SingleImageRandomDotStereogramsOp : public OpKernel {
+ private:
+ int E2Epixels; // Pixels from eye to eye = eye_to_eye_inches * DPI
+
+ int input_Xvalue; // X value of input Z values (width)
+ int input_Yvalue; // Y value of input Z values (height)
+
+ int output_Ximage; // X value of output image (width)
+ int output_Yimage; // Y value of output image (height)
+ int output_Cimage; // color value of output image (color, 1 or 3) (3 not
+ // implemented)
+
+ int data_box_left; // X starting value for DATA window
+ int data_box_top; // Y starting value for DATA window
+ int data_box_width; // width of scan line
+ int data_box_height; // hight of image
+
+ int converge_dot_box_end; // Row convergences dots end on
+
+ uint8* outputImage; // Output Image flat as a buffer (Tensor Connection)
+ double* ZBuffer; // For internal use, allow for MASK, etc later, actual Z
+ // used for Stereogram, XxY (X is the row index, y is col
+ // index like a screen)
+ // 0 (far) -> 1.0(near) range
+ bool hidden_surface_removal;
+ int convergence_dots_size;
+ int dots_per_inch;
+ float eye_separation;
+ float mu;
+ bool normalize;
+ float normalize_max;
+ float normalize_min;
+ float border_level;
+ int number_colors;
+ ::tensorflow::TensorShapeProto output_image_shape;
+ ::tensorflow::TensorShapeProto output_data_window;
+
+ uint8 Cblack = (uint8)0;
+ uint8 Cwhite = (uint8)255;
+
+ int indexMode = 0; // 0 - truncate XY, 1 - round XY, 2 - Interpolate XY (not
+ // implemented yet, keep default of 0)
+ int interp_x, interp_y; // 1 - yes, 0 - no interpolation directions (not
+ // implemented yet)
+
+ bool debugging = false;
+
+ inline int separation(double z) {
+ return (std::round((1 - mu * z) * E2Epixels / (2 - mu * z)));
+ }
+
+ inline int get_far_width() { return (separation(0.0)); }
+ inline int get_near_width() { return (separation(1.0)); }
+
+ public:
+ explicit SingleImageRandomDotStereogramsOp(OpKernelConstruction* context)
+ : OpKernel(context) { // Constructor
+ OP_REQUIRES_OK(context, context->GetAttr("hidden_surface_removal",
+ &hidden_surface_removal));
+ OP_REQUIRES_OK(context, context->GetAttr("convergence_dots_size",
+ &convergence_dots_size));
+ OP_REQUIRES_OK(context, context->GetAttr("dots_per_inch", &dots_per_inch));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("eye_separation", &eye_separation));
+ OP_REQUIRES_OK(context, context->GetAttr("mu", &mu));
+ OP_REQUIRES_OK(context, context->GetAttr("normalize", &normalize));
+ OP_REQUIRES_OK(context, context->GetAttr("normalize_max", &normalize_max));
+ OP_REQUIRES_OK(context, context->GetAttr("normalize_min", &normalize_min));
+ OP_REQUIRES_OK(context, context->GetAttr("border_level", &border_level));
+ OP_REQUIRES_OK(context, context->GetAttr("number_colors", &number_colors));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("output_image_shape", &output_image_shape));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("output_data_window", &output_data_window));
+
+ E2Epixels =
+ eye_separation * dots_per_inch; // Initialize pixels from eye to eye
+ }
+
+ ~SingleImageRandomDotStereogramsOp() { // Destructor
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_tensor = context->input(0);
+ input_Xvalue = input_tensor.shape().dim_size(
+ 1); // X value is the number of columns of the input matrix
+ input_Yvalue =
+ input_tensor.shape().dim_size(0); // Y value is the number of rows
+
+ output_Ximage = output_image_shape.dim(0).size();
+ output_Yimage = output_image_shape.dim(1).size();
+ output_Cimage = output_image_shape.dim(2).size();
+
+ if (number_colors > 256) // Go to full color image
+ output_Cimage = 3;
+
+ int data_Xwindow = output_data_window.dim(0).size();
+ int data_Ywindow = output_data_window.dim(1).size();
+
+ int deltaX_border_image = output_Ximage - data_Xwindow;
+ int deltaY_border_image = output_Yimage - data_Ywindow;
+
+ if (convergence_dots_size >
+ 0) // 3 frame sections in Y direction due to DOTS
+ {
+ deltaY_border_image =
+ deltaY_border_image -
+ convergence_dots_size; // Take off space for Convergence Dots
+ deltaY_border_image = std::max(0, deltaY_border_image);
+ data_box_top = deltaY_border_image / 3;
+
+ if (deltaY_border_image >= 0) {
+ converge_dot_box_end = output_Yimage - 1 - data_box_top;
+ } else {
+ converge_dot_box_end = output_Yimage - 1;
+ }
+ } else // Otherwise only 2, no convergence dot
+ {
+ data_box_top = deltaY_border_image / 2; // Center DATA in Y dimension
+ converge_dot_box_end = output_Yimage - 1;
+ }
+
+ data_box_left = deltaX_border_image / 2; // Center DATA in X dimension
+ data_box_width = data_Xwindow; // width of scan line
+ data_box_height = data_Ywindow; // hight of image
+
+ const T* inputZ = input_tensor.flat<T>().data(); // Flatten input Z buffer
+
+ BuildZBuffer(inputZ);
+
+ // Output a scalar string.
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_output(
+ 0, TensorShape({output_Yimage, output_Ximage, output_Cimage}),
+ &output_tensor));
+
+ outputImage = output_tensor->flat<uint8>().data();
+
+ generate_stereogram();
+
+ delete[] ZBuffer;
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+ // Move input into standard Z format to reduce complexity of algorithm
+ //
+ void BuildZBuffer(const T* Z, bool log = false) {
+ double MaxValue = 1.0;
+ double MinValue = 0.0;
+ ZBuffer = new double[input_Xvalue * input_Yvalue]; // Used to computer
+ // final Z values before
+ // rendering to output
+
+ if (normalize) {
+ // Init Min/Max to first value
+ if (normalize_max < normalize_min) // Autoscale if MIN>MAX
+ {
+ MaxValue = (double)*Z;
+ MinValue = (double)*Z;
+
+ for (int y = 0; y < input_Yvalue; ++y)
+ for (int x = 0; x < input_Xvalue; ++x) {
+ double value = getZfromInputImage(Z, x, y);
+ if (value > MaxValue) MaxValue = value;
+ if (value < MinValue) MinValue = value;
+ }
+ } else {
+ MaxValue = normalize_max;
+ MinValue = normalize_min;
+ }
+ }
+
+ for (int y = 0; y < input_Yvalue; ++y)
+ for (int x = 0; x < input_Xvalue; ++x) {
+ double value = getZfromInputImage(Z, x, y);
+
+ if (normalize) {
+ value = (value - MinValue) / (MaxValue - MinValue);
+ }
+
+ if (value > 1.0) value = 1.0;
+ if (value < 0.0) value = 0.0;
+
+ *(ZBuffer + (input_Xvalue * y + x)) = value;
+ }
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+ double getZfromInputImage(const T* Z, int x, int y) {
+ double return_val;
+
+ return_val = (double)*(Z + input_Xvalue * y + x); // Get value
+ return return_val;
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+ // All normalized, not checking required
+ // Possible Projection issue if DATA is bigger or smaller than Input
+ // Modes include:
+ // Truncate value (Default)
+ // Round-off value
+ // Interpolate between values
+ //
+ double getZfromZbuffer(double x, double y) {
+ int xi, yi;
+
+ switch (indexMode) {
+ case 0: // Truncate
+ xi = int(x);
+ yi = int(y);
+ return (*(ZBuffer + (xi + input_Xvalue * yi)));
+ break;
+ case 1: // Round-off
+ xi = std::round(x);
+ yi = std::round(y);
+ return (*(ZBuffer + (xi + input_Xvalue * yi)));
+ break;
+ case 2: // Interpolate (Not implemented yet, will need 4 points
+ // [x,y],[x+1,y],[x,y+1],[x+1,y+1], then interpolate)
+ xi = int(x);
+ yi = int(y);
+ return (*(ZBuffer + (xi + input_Xvalue * yi)));
+ break;
+ default: // Round-off is the default
+ xi = int(x + 0.5);
+ yi = int(y + 0.5);
+ return (*(ZBuffer + (xi + input_Xvalue * yi)));
+ break;
+ }
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+
+ int getOutputImageIndex(int x, int y,
+ int channel) { // No error checking for some
+ // optimization, calling routine
+ // required to make sure there is no
+ // violation
+ return ((output_Ximage * output_Cimage) * y + x * output_Cimage + channel);
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+
+ double getZFromOutputPixel(int x, int y) {
+ double xofz, yofz, returnval;
+
+ // Convert pixel units to Z units, do this as "double"
+
+ xofz =
+ (double)input_Xvalue * (x - data_box_left) / ((double)data_box_width);
+ yofz =
+ (double)input_Yvalue * (y - data_box_top) / ((double)data_box_height);
+
+ if ((xofz < 0) || (yofz < 0) || (yofz >= input_Yvalue) ||
+ (xofz >= input_Xvalue)) { // Top of left side border hit or Right
+ // side or bottom border hit
+ // Send BORDER Z value
+ return (border_level);
+ }
+
+ { // in data set Z interpolate if need
+ double gz;
+
+ gz = getZfromZbuffer(xofz, yofz);
+
+ returnval = gz;
+ }
+
+ return (returnval);
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+
+ void generate_stereogram() {
+ int s, left, right, visible, t, l;
+ double zt, gz;
+ // Scan line
+ uint8* pix; // Scan row color for each pixel
+ int* same; // Used to determine if Pixel needs to be the same as another
+ // pixel in the row
+
+ pix = new uint8[output_Ximage * output_Cimage];
+ same = new int[output_Ximage];
+
+ for (int y = 0; y < output_Yimage; ++y) {
+ // Set no dependencies on any pixels, tie each one back to itself
+ for (int x = 0; x < output_Ximage; ++x) same[x] = x;
+
+ for (int x = 0; x < output_Ximage; ++x) {
+ gz = getZFromOutputPixel(x, y);
+ s = separation(gz);
+ left = x - s / 2;
+ right = left + s;
+
+ if ((left >= 0) && (right < output_Ximage)) {
+ t = 1;
+ visible = 1;
+ if (hidden_surface_removal) do {
+ zt = gz + 2 * (2 - mu * gz) * t / (mu * E2Epixels);
+ visible = (getZFromOutputPixel(x - t, y) < zt) &&
+ (getZFromOutputPixel(x + t, y) < zt);
+ ++t;
+ } while ((visible) && (zt < 1));
+
+ if (visible) {
+ l = same[left];
+ while ((l != left) && (l != right))
+ if (l < right) {
+ left = l;
+ l = same[left];
+ } else {
+ same[left] = right;
+ left = right;
+ l = same[left];
+ right = l;
+ }
+ same[left] = right;
+ }
+ }
+ }
+ // Set colors for scan row, use channels and number_colors
+ for (int x = output_Ximage - 1; x >= 0; x--) {
+ for (int channel = 0; channel < output_Cimage; ++channel) {
+ if (same[x] == x) { // Pick a random color
+ if (number_colors == 2) {
+ if ((rand() % 2) == 0) {
+ pix[x * output_Cimage + channel] = Cblack;
+ } else {
+ pix[x * output_Cimage + channel] = Cwhite;
+ }
+ } else {
+ pix[x * output_Cimage + channel] = rand() % 256;
+ }
+ } else
+ pix[x * output_Cimage + channel] =
+ pix[same[x] * output_Cimage + channel];
+
+ setpixel(x, y, channel, pix[x * output_Cimage + channel]);
+ }
+ }
+ }
+
+ draw_convergence_dots();
+
+ delete[] pix;
+ delete[] same;
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+
+ void draw_convergence_dots() {
+ int x1, x2; // center position for convergence dots
+
+ if (convergence_dots_size == 0) // No dot, return
+ return;
+
+ x1 = output_Ximage / 2 - get_far_width() / 2;
+ x2 = output_Ximage / 2 + get_far_width() / 2;
+
+ for (int lloop = 0; lloop < convergence_dots_size; ++lloop)
+ for (int wloop = 0; wloop < convergence_dots_size; ++wloop)
+ for (int channel = 0; channel < output_Cimage; ++channel) {
+ setpixel(x1 - (convergence_dots_size / 2) + wloop,
+ converge_dot_box_end - lloop, channel, Cblack);
+ setpixel(x2 - (convergence_dots_size / 2) + wloop,
+ converge_dot_box_end - lloop, channel, Cblack);
+ }
+ }
+
+ //***************************************************************************
+ //***************************************************************************
+
+ void setpixel(int x, int y, int channel, uint8 color) {
+ *(outputImage + getOutputImageIndex(x, y, channel)) = color;
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("SingleImageRandomDotStereograms") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ SingleImageRandomDotStereogramsOp<T>);
+
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+
+#undef REGISTER_KERNEL
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc
new file mode 100755
index 0000000000..8a7cc56256
--- /dev/null
+++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+REGISTER_OP("SingleImageRandomDotStereograms")
+ .Attr("T: {double,float,int64,int32}")
+ .Input("depth_values: T")
+ .Output("image: uint8")
+ .Attr("hidden_surface_removal: bool = true")
+ .Attr("convergence_dots_size: int = 8")
+ .Attr("dots_per_inch: int = 72")
+ .Attr("eye_separation: float = 2.5")
+ .Attr("mu: float = .3333")
+ .Attr("normalize: bool = true")
+ .Attr("normalize_max: float = -100.0")
+ .Attr("normalize_min: float = 100.0")
+ .Attr("border_level: float = 0.0")
+ .Attr("number_colors: int = 256")
+ .Attr(
+ "output_image_shape: shape = { dim {size:1024} dim {size: 768} dim "
+ "{size: 1}}")
+ .Attr("output_data_window: shape = { dim {size:1022} dim {size: 757}}")
+ .Doc(R"doc(
+Outputs a single image random dot stereogram for export via encode_PNG/JPG OP.
+
+Given the 2-D tensor 'depth_values' with encoded Z values, this operation will
+encode 3-D data into a 2-D image. The output of this Op is suitable for the
+encode_PNG/JPG ops. Be careful with image compression as this may corrupt the
+encode 3-D data witin the image.
+
+This Op is based upon:
+'http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper'
+
+Example use which outputs a SIRDS image as picture_out.png:
+```python
+img=[[1,2,3,3,2,1],
+ [1,2,3,4,5,2],
+ [1,2,3,4,5,3],
+ [1,2,3,4,5,4],
+ [6,5,4,4,5,5]]
+
+session = tf.InteractiveSession()
+
+sirds = single_image_random_dot_stereograms(img,convergence_dots_size=8,number_colors=256,normalize=True)
+
+out = sirds.eval()
+
+png = tf.image.encode_png(out).eval()
+
+with open('picture_out.png', 'wb') as f:
+ f.write(png)
+```
+
+depth_values: Z values of data to encode into 'output_data_window' window,
+ lower values are further away {0.0 floor(far), 1.0 ceiling(near) after normalization}, must be 2-D tensor
+hidden_surface_removal: Activate hidden surface removal
+convergence_dots_size: Black dot size in pixels to help view converge image, drawn on bottom of image
+dots_per_inch: Output device in dots/inch
+eye_separation: Separation between eyes in inches
+mu: Depth of field, Fraction of viewing distance (eg. 1/3 = .3333)
+normalize: Normalize input data to [0.0, 1.0]
+normalize_max: Fix MAX value for Normalization - if < MIN, autoscale
+normalize_min: Fix MIN value for Normalization - if > MAX, autoscale
+border_level: Value of border depth 0.0 {far} to 1.0 {near}
+number_colors: 2 (Black & White),256 (grayscale), and Numbers > 256 (Full Color) are all that are supported currently
+output_image_shape: Output size of returned image in X,Y, Channels 1-grayscale, 3 color (1024, 768, 1),
+ channels will be updated to 3 if 'number_colors' > 256
+output_data_window: Size of "DATA" window, must be equal to or smaller than 'output_image_shape', will be centered
+ and use 'convergence_dots_size' for best fit to avoid overlap if possible
+
+image:= A tensor of size 'output_image_shape' with the encloded 'depth_values'
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py
new file mode 100755
index 0000000000..79261c5e75
--- /dev/null
+++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py
@@ -0,0 +1,125 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python layer for image_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+
+_sirds_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile(
+ "_single_image_random_dot_stereograms.so"))
+
+def single_image_random_dot_stereograms(
+ depth_values,
+ hidden_surface_removal=None,
+ convergence_dots_size=None,
+ dots_per_inch=None,
+ eye_separation=None, mu=None,
+ normalize=None, normalize_max=None,
+ normalize_min=None,
+ border_level=None,
+ number_colors=None,
+ output_image_shape=None,
+ output_data_window=None):
+ """Output a RandomDotStereogram Tensor for export via encode_PNG/JPG OP.
+
+ Given the 2-D tensor 'depth_values' with encoded Z values, this operation
+ will encode 3-D data into a 2-D image. The output of this Op is suitable
+ for the encode_PNG/JPG ops. Be careful with image compression as this may
+ corrupt the encode 3-D data witin the image.
+
+ Based upon [this paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper).
+
+ This outputs a SIRDS image as picture_out.png:
+
+ ```python
+ img=[[1,2,3,3,2,1],
+ [1,2,3,4,5,2],
+ [1,2,3,4,5,3],
+ [1,2,3,4,5,4],
+ [6,5,4,4,5,5]]
+ session = tf.InteractiveSession()
+ sirds = single_image_random_dot_stereograms(
+ img,
+ convergence_dots_size=8,
+ number_colors=256,normalize=True)
+
+ out = sirds.eval()
+ png = tf.image.encode_png(out).eval()
+ with open('picture_out.png', 'wb') as f:
+ f.write(png)
+ ```
+
+ Args:
+ depth_values: A `Tensor`. Must be one of the following types:
+ `float64`, `float32`, `int64`, `int32`. Z values of data to encode
+ into 'output_data_window' window, lower further away {0.0 floor(far),
+ 1.0 ceiling(near) after norm}, must be 2-D tensor
+ hidden_surface_removal: An optional `bool`. Defaults to `True`.
+ Activate hidden surface removal
+ convergence_dots_size: An optional `int`. Defaults to `8`.
+ Black dot size in pixels to help view converge image, drawn on bottom
+ of the image
+ dots_per_inch: An optional `int`. Defaults to `72`.
+ Output device in dots/inch
+ eye_separation: An optional `float`. Defaults to `2.5`.
+ Separation between eyes in inches
+ mu: An optional `float`. Defaults to `0.3333`.
+ Depth of field, Fraction of viewing distance (eg. 1/3 = 0.3333)
+ normalize: An optional `bool`. Defaults to `True`.
+ Normalize input data to [0.0, 1.0]
+ normalize_max: An optional `float`. Defaults to `-100`.
+ Fix MAX value for Normalization (0.0) - if < MIN, autoscale
+ normalize_min: An optional `float`. Defaults to `100`.
+ Fix MIN value for Normalization (0.0) - if > MAX, autoscale
+ border_level: An optional `float`. Defaults to `0`.
+ Value of bord in depth 0.0 {far} to 1.0 {near}
+ number_colors: An optional `int`. Defaults to `256`. 2 (Black &
+ White), 256 (grayscale), and Numbers > 256 (Full Color) are
+ supported
+ output_image_shape: An optional `tf.TensorShape` or list of `ints`.
+ Defaults to shape `[1024, 768, 1]`. Defines output shape of returned
+ image in '[X,Y, Channels]' 1-grayscale, 3 color; channels will be
+ updated to 3 if number_colors > 256
+ output_data_window: An optional `tf.TensorShape` or list of `ints`.
+ Defaults to `[1022, 757]`. Size of "DATA" window, must be equal to or
+ smaller than `output_image_shape`, will be centered and use
+ `convergence_dots_size` for best fit to avoid overlap if possible
+
+ Returns:
+ A `Tensor` of type `uint8` of shape 'output_image_shape' with encoded
+ 'depth_values'
+ """
+
+ result = _sirds_ops.single_image_random_dot_stereograms(
+ depth_values=depth_values,
+ hidden_surface_removal=hidden_surface_removal,
+ convergence_dots_size=convergence_dots_size,
+ dots_per_inch=dots_per_inch,
+ eye_separation=eye_separation, mu=mu,
+ normalize=normalize,
+ normalize_max=normalize_max,
+ normalize_min=normalize_min,
+ border_level=border_level,
+ number_colors=number_colors,
+ output_image_shape=output_image_shape,
+ output_data_window=output_data_window)
+ return result
+
+ops.NotDifferentiable("SingleImageRandomDotStereograms")
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_functions.py b/tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_functions.py
index b891bf2301..dfe08bb863 100644
--- a/tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_functions.py
+++ b/tensorflow/contrib/learn/python/learn/dataframe/queues/feeding_functions.py
@@ -25,4 +25,5 @@ from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue
from tensorflow.python.estimator.inputs.queues.feeding_functions import _GeneratorFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _OrderedDictNumpyFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _PandasFeedFn
+from tensorflow.python.estimator.inputs.queues.feeding_functions import _GeneratorFeedFn
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 68f71ff022..ae01c678b6 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -611,7 +611,7 @@ def _create_model_fn_ops(features,
if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
weight_tensor = _weight_tensor(features, weight_column_name)
loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor)
- logging_ops.scalar_summary(
+ summary.scalar(
_summary_key(head_name, mkey.LOSS), weighted_average_loss)
if mode == model_fn.ModeKeys.TRAIN:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 9b8cba1526..abaf3a61a1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -124,7 +124,7 @@ class PoissonHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=logits)
self._assert_output_alternatives(model_fn_ops)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["regression_head/loss"])
_assert_no_variables(self)
loss = self._log_poisson_loss(logits, labels)
_assert_metrics(self, loss, {"loss": loss}, model_fn_ops)
@@ -150,7 +150,7 @@ class RegressionHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["regression_head/loss"])
_assert_no_variables(self)
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
@@ -180,7 +180,7 @@ class RegressionHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)
def testRegressionWithLogitsAndLogitsInput(self):
@@ -208,7 +208,7 @@ class RegressionHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithLabelName(self):
@@ -223,7 +223,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithWeights(self):
@@ -238,7 +238,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 2. / len(weights), {"loss": 2. / np.sum(weights)},
model_fn_ops)
@@ -261,7 +261,7 @@ class RegressionHeadTest(test.TestCase):
expected_trainable=("regression_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
- self, ["loss", "regression_head/centered_bias/bias_0"])
+ self, ["regression_head/loss", "regression_head/centered_bias/bias_0"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionErrorInSparseTensorLabels(self):
@@ -330,7 +330,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -347,7 +347,7 @@ class MultiLabelHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, logits=logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = 1.00320443
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
@@ -387,7 +387,7 @@ class MultiLabelHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .69314718
_assert_metrics(self, expected_loss, {
"accuracy": 2. / 3,
@@ -432,7 +432,7 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -451,7 +451,7 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = 1.377779
expected_eval_metrics = {
"accuracy": 1. / 3,
@@ -519,7 +519,7 @@ class MultiLabelHeadTest(test.TestCase):
head_lib.no_op_train_fn, logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -539,7 +539,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, .089985214,
self._expected_eval_metrics(2.69956), model_fn_ops)
@@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, 0.089985214,
self._expected_eval_metrics(0.089985214), model_fn_ops)
@@ -583,7 +583,7 @@ class MultiLabelHeadTest(test.TestCase):
expected_trainable=("multi_label_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, (
- "loss",
+ "multi_label_head/loss",
"multi_label_head/centered_bias/bias_0",
"multi_label_head/centered_bias/bias_1",
"multi_label_head/centered_bias/bias_2"
@@ -608,7 +608,7 @@ class MultiLabelHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=self._logits)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -674,7 +674,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -702,7 +702,7 @@ class BinaryClassificationHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .69314718
label_mean = np.mean(self._labels)
_assert_metrics(self, expected_loss, {
@@ -738,7 +738,7 @@ class BinaryClassificationHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -817,7 +817,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -838,7 +838,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_total_loss = .31326166
_assert_metrics(
self,
@@ -871,7 +871,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_logistic_head/loss"])
# logloss: z:label, x:logit
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
# expected_loss is (total_weighted_loss)/1 since htere is 1 nonzero
@@ -911,7 +911,8 @@ class BinaryClassificationHeadTest(test.TestCase):
expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
- self, ["loss", "binary_logistic_head/centered_bias/bias_0"])
+ self, ["binary_logistic_head/loss",
+ "binary_logistic_head/centered_bias/bias_0"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -960,7 +961,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -999,7 +1000,7 @@ class MultiClassHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.0986123
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
@@ -1050,7 +1051,7 @@ class MultiClassHeadTest(test.TestCase):
expected_trainable=("multi_class_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self,
- ["loss",
+ ["multi_class_head/loss",
"multi_class_head/centered_bias/bias_0",
"multi_class_head/centered_bias/bias_1",
"multi_class_head/centered_bias/bias_2"])
@@ -1068,7 +1069,7 @@ class MultiClassHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -1087,7 +1088,7 @@ class MultiClassHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 3.1698461
expected_eval_metrics = {
"accuracy": 0.,
@@ -1126,7 +1127,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss * weight,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -1150,7 +1151,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447 * weight
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -1257,7 +1258,7 @@ class MultiClassHeadTest(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
expected_eval_metrics = {
"accuracy": 0.,
@@ -1283,7 +1284,7 @@ class MultiClassHeadTest(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 0.5514447
expected_eval_metrics = {
"accuracy": 1.,
@@ -1322,7 +1323,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@@ -1352,7 +1353,7 @@ class BinarySvmHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = 1.
_assert_metrics(self, expected_loss, {
"accuracy": .5,
@@ -1384,7 +1385,7 @@ class BinarySvmHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@@ -1403,7 +1404,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@@ -1422,7 +1423,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
- _assert_summary_tags(self, ["loss"])
+ _assert_summary_tags(self, ["binary_svm_head/loss"])
expected_weighted_sum = np.sum(
np.multiply(weights, self._expected_losses))
_assert_metrics(self, expected_weighted_sum / len(weights), {
@@ -1450,7 +1451,8 @@ class BinarySvmHeadTest(test.TestCase):
expected_trainable=("binary_svm_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
- self, ["loss", "binary_svm_head/centered_bias/bias_0"])
+ self, ["binary_svm_head/loss",
+ "binary_svm_head/centered_bias/bias_0"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index 25da51675a..109c8d25e1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -89,9 +89,9 @@ class ClusterConfig(object):
```
cluster = {'ps': ['host1:2222', 'host2:2222'],
'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
- os.environ['TF_CONFIG'] = json.dumps({
+ os.environ['TF_CONFIG'] = json.dumps(
{'cluster': cluster,
- 'task': {'type': 'worker', 'index': 1}}})
+ 'task': {'type': 'worker', 'index': 1}})
config = ClusterConfig()
assert config.master == 'host4:2222'
assert config.task_id == 1
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index 3848832f38..b132b1e8f8 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -19,8 +19,6 @@ limitations under the License.
#include "tensorflow/core/kernels/segment_reduction_ops.h"
-#include <stdio.h>
-
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc b/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
index c647d3aaac..7a1db4e558 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
@@ -69,7 +69,7 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
errors::InvalidArgument("Self Adjoint Eigen decomposition was not "
"successful. The input might not be valid."));
- outputs->at(0) = eig.eigenvalues();
+ outputs->at(0) = eig.eigenvalues().template cast<Scalar>();
if (compute_v_) {
outputs->at(1) = eig.eigenvectors();
}
@@ -81,7 +81,15 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float);
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<double>), double);
+REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>),
+ complex64);
+REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>),
+ complex128);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<double>),
double);
+REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>),
+ complex64);
+REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>),
+ complex128);
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index e0f448d1c1..11df3c43c7 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1330,10 +1330,9 @@ this operation will permute `params` accordingly.
`indices` are always validated to be within range. If assigned to GPU,
out-of-bound indices result in safe but unspecified behavior, which may include
raising an error.
-`0`, but this may become an error in the future).
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../../images/Gather.png" alt>
+<img style="width:100%" src="../../../images/Gather.png" alt>
</div>
)doc");
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index a2762cf206..872824b885 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -318,7 +318,7 @@ REGISTER_OP("SelfAdjointEigV2")
.Output("e: T")
.Output("v: T")
.Attr("compute_v: bool = True")
- .Attr("T: {double, float}")
+ .Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(SelfAdjointEigV2ShapeFn)
.Doc(R"doc(
Computes the eigen decomposition of one or more square self-adjoint matrices.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 1866a2b4cc..d42d01d6ba 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -26364,6 +26364,59 @@ op {
description: "Read @{$math_ops#segmentation$the section on segmentation} for an explanation of\nsegments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
+ name: "UnsortedSegmentSum"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ description: "A tensor whose shape is a prefix of `data.shape`."
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes the max along segments of a tensor."
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
+}
+op {
name: "Unstage"
output_arg {
name: "values"
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
index 45f7530506..c75c7f111d 100644
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ b/tensorflow/docs_src/extend/adding_an_op.md
@@ -182,10 +182,10 @@ g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -I $TF_INC -O2
On Mac OS X, the additional flag "-undefined dynamic_lookup" is required when
building the `.so` file.
-> Note on gcc version 5: gcc5 uses the new C++
-> [ABI](https://gcc.gnu.org/gcc-5/changes.html#libstdcxx). The binary pip
-> packages available on the TensorFlow website are built with gcc4 that uses
-> the older ABI. If you compile your op library with gcc5, add
+> Note on `gcc` version `>=5`: gcc uses the new C++
+> [ABI](https://gcc.gnu.org/gcc-5/changes.html#libstdcxx) since version `5`. The binary pip
+> packages available on the TensorFlow website are built with `gcc4` that uses
+> the older ABI. If you compile your op library with `gcc>=5`, add
> `-D_GLIBCXX_USE_CXX11_ABI=0` to the command line to make the library
> compatible with the older abi.
> Furthermore if you are using TensorFlow package created from source remember to add `-cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"`
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 7d452feafa..5f351c40b4 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -298,7 +298,7 @@ invoke the following command:
<pre>$ <b>bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package</b> </pre>
-**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow website are built with gcc 4, which uses the older ABI. To make your build compatible with the older ABI, you need to add `-cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. ABI compatibility allows custom ops built against the TensorFlow pip package to continue to work against your built package.
+**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow website are built with gcc 4, which uses the older ABI. To make your build compatible with the older ABI, you need to add `--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. ABI compatibility allows custom ops built against the TensorFlow pip package to continue to work against your built package.
<b>Tip:</b> By default, building TensorFlow from sources consumes
a lot of RAM. If RAM is an issue on your system, you may limit RAM usage
@@ -367,6 +367,7 @@ of one of the following guides:
* @{$install_linux#CommonInstallationProblems$Installing TensorFlow on Linux}
* @{$install_mac#CommonInstallationProblems$Installing TensorFlow on Mac OS}
+ * @{$install_windows#CommonInstallationProblems$Installing TensorFlow on Windows}
Beyond the errors documented in those two guides, the following table
notes additional errors specific to building TensorFlow. Note that we
diff --git a/tensorflow/docs_src/performance/quantization.md b/tensorflow/docs_src/performance/quantization.md
index 86d2b92494..ad23bab443 100644
--- a/tensorflow/docs_src/performance/quantization.md
+++ b/tensorflow/docs_src/performance/quantization.md
@@ -91,8 +91,8 @@ eight-bit computations:
```sh
curl http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -o /tmp/inceptionv3.tgz
tar xzf /tmp/inceptionv3.tgz -C /tmp/
-bazel build tensorflow/tools/quantization/tools:quantize_graph
-bazel-bin/tensorflow/tools/quantization/tools/quantize_graph \
+bazel build tensorflow/tools/quantization:quantize_graph
+bazel-bin/tensorflow/tools/quantization/quantize_graph \
--input=/tmp/classify_image_graph_def.pb \
--output_node_names="softmax" --output=/tmp/quantized_graph.pb \
--mode=eightbit
diff --git a/tensorflow/examples/learn/iris_custom_decay_dnn.py b/tensorflow/examples/learn/iris_custom_decay_dnn.py
index 73c526cd4e..31acbd30cd 100644
--- a/tensorflow/examples/learn/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/learn/iris_custom_decay_dnn.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Example of DNNClassifier for Iris plant dataset, with exponential decay."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/examples/learn/text_classification_character_cnn.py b/tensorflow/examples/learn/text_classification_character_cnn.py
index 0c96976146..5ad53acf9f 100644
--- a/tensorflow/examples/learn/text_classification_character_cnn.py
+++ b/tensorflow/examples/learn/text_classification_character_cnn.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This is an example of using convolutional networks over characters for DBpedia dataset to predict class from description of an entity.
+"""This is an example of using convolutional networks over characters for
+ DBpedia dataset to predict class from description of an entity.
This model is similar to one described in this paper:
"Character-level Convolutional Networks for Text Classification"
@@ -54,7 +55,7 @@ def char_cnn_model(features, target):
# Apply Convolution filtering on input sequence.
conv1 = tf.contrib.layers.convolution2d(
byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID')
- # Add a RELU for non linearity.
+ # Add a ReLU for non linearity.
conv1 = tf.nn.relu(conv1)
# Max pooling across output of Convolution+Relu.
pool1 = tf.nn.max_pool(
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 75ea0b9c67..698c97ca1d 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""A simple MNIST classifier which displays summaries in TensorBoard.
- This is an unimpressive MNIST model, but it is a good example of using
+This is an unimpressive MNIST model, but it is a good example of using
tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
naming summary tags so that they are grouped meaningfully in TensorBoard.
@@ -78,7 +78,7 @@ def train():
def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
"""Reusable code for making a simple neural net layer.
- It does a matrix multiply, bias add, and then uses relu to nonlinearize.
+ It does a matrix multiply, bias add, and then uses ReLU to nonlinearize.
It also sets up name scoping so that the resultant graph is easy to read,
and adds a number of summary ops.
"""
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index 36b3ed33d8..9ab9304572 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -78,7 +78,7 @@ def _GetSelfAdjointEigTest(dtype_, shape_):
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(dtype_)
a += a.T
a = np.tile(a, batch_shape + (1, 1))
- if dtype_ == np.float32:
+ if dtype_ == np.float32 or dtype_ == np.complex64:
atol = 1e-4
else:
atol = 1e-12
@@ -150,13 +150,14 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_):
if __name__ == '__main__':
- for dtype in np.float32, np.float64:
+ for dtype in np.float32, np.float64, np.complex64, np.complex128:
for size in 1, 2, 5, 10:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(size, size) < 10):
shape = batch_dims + (size, size)
name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
setattr(SelfAdjointEigTest, 'testSelfAdjointEig_' + name,
_GetSelfAdjointEigTest(dtype, shape))
- setattr(SelfAdjointEigGradTest, 'testSelfAdjointEigGrad_' + name,
- _GetSelfAdjointEigGradTest(dtype, shape))
+ if dtype in [np.float32, np.float64]:
+ setattr(SelfAdjointEigGradTest, 'testSelfAdjointEigGrad_' + name,
+ _GetSelfAdjointEigGradTest(dtype, shape))
test.main()
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 34b663119e..41846ae0cd 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -364,6 +364,23 @@ def batch_normalization(inputs,
Sergey Ioffe, Christian Szegedy
+ Note: the operations which update the `moving_mean` and `moving_variance`
+ variables will not be added as dependencies of your training operation and so
+ must be run separately. For example:
+ ```
+ extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+ sess.run([train_op, extra_update_ops], ...)
+ ```
+ Alternatively, add the operations as a dependency to your training operation
+ manually, and then just run your training operation as normal:
+ ```
+ extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+ with tf.control_dependencies(extra_update_ops):
+ train_op = optimizer.minimize(loss)
+ ...
+ sess.run([train_op], ...)
+ ```
+
Arguments:
inputs: Tensor input.
axis: Integer, the axis that should be normalized (typically the features
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index ebf17c8a41..b1f50fd341 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -512,6 +512,16 @@ def _MaxPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
+@ops.RegisterGradient("MaxPoolWithArgmax")
+def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
+ return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
+ grad,
+ op.outputs[1],
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"))
+
+
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
return (array_ops.zeros(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 66ccedf546..ccce9402c7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1478,14 +1478,14 @@ def _softmax(logits, compute_op, dim=-1, name=None):
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
"""
- def _swap_axis(logits, dim_index, last_index):
+ def _swap_axis(logits, dim_index, last_index, name=None):
"""Swaps logits's dim_index and last_index."""
return array_ops.transpose(logits,
array_ops.concat([
math_ops.range(dim_index), [last_index],
math_ops.range(dim_index + 1, last_index),
[dim_index]
- ], 0))
+ ], 0), name=name)
logits = ops.convert_to_tensor(logits)
@@ -1501,8 +1501,8 @@ def _softmax(logits, compute_op, dim=-1, name=None):
if is_last_dim:
input_shape = array_ops.shape(logits)
logits = _flatten_outer_dims(logits)
- output = compute_op(logits, name=name)
- output = array_ops.reshape(output, input_shape)
+ output = compute_op(logits)
+ output = array_ops.reshape(output, input_shape, name=name)
return output
# If dim is not the last dimension, we have to do a reshape and transpose so
@@ -1517,11 +1517,11 @@ def _softmax(logits, compute_op, dim=-1, name=None):
logits = _flatten_outer_dims(logits)
# Do the actual softmax on its last dimension.
- output = compute_op(logits, name=name)
+ output = compute_op(logits)
# Transform back the output tensor.
output = array_ops.reshape(output, shape_after_swap)
- output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1))
+ output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1), name=name)
# Make shape inference work since reshape and transpose may erase its static
# shape.
diff --git a/tensorflow/python/summary/writer/event_file_writer.py b/tensorflow/python/summary/writer/event_file_writer.py
index 8940d9b72e..2936a279bd 100644
--- a/tensorflow/python/summary/writer/event_file_writer.py
+++ b/tensorflow/python/summary/writer/event_file_writer.py
@@ -24,6 +24,7 @@ import time
import six
+from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
@@ -67,14 +68,20 @@ class EventFileWriter(object):
self._event_queue = six.moves.queue.Queue(max_queue)
self._ev_writer = pywrap_tensorflow.EventsWriter(
compat.as_bytes(os.path.join(self._logdir, "events")))
+ self._flush_secs = flush_secs
+ self._sentinel_event = self._get_sentinel_event()
if filename_suffix:
self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix))
self._closed = False
self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
- flush_secs)
+ self._flush_secs, self._sentinel_event)
self._worker.start()
+ def _get_sentinel_event(self):
+ """Generate a sentinel event for terminating worker."""
+ return event_pb2.Event()
+
def get_logdir(self):
"""Returns the directory where event file will be written."""
return self._logdir
@@ -88,6 +95,9 @@ class EventFileWriter(object):
Does nothing if the EventFileWriter was not closed.
"""
if self._closed:
+ self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
+ self._flush_secs, self._sentinel_event)
+ self._worker.start()
self._closed = False
def add_event(self, event):
@@ -113,7 +123,9 @@ class EventFileWriter(object):
Call this method when you do not need the summary writer anymore.
"""
+ self.add_event(self._sentinel_event)
self.flush()
+ self._worker.join()
self._ev_writer.Close()
self._closed = True
@@ -121,7 +133,7 @@ class EventFileWriter(object):
class _EventLoggerThread(threading.Thread):
"""Thread that logs events."""
- def __init__(self, queue, ev_writer, flush_secs):
+ def __init__(self, queue, ev_writer, flush_secs, sentinel_event):
"""Creates an _EventLoggerThread.
Args:
@@ -130,6 +142,8 @@ class _EventLoggerThread(threading.Thread):
the visualizer.
flush_secs: How often, in seconds, to flush the
pending file to disk.
+ sentinel_event: A sentinel element in queue that tells this thread to
+ terminate.
"""
threading.Thread.__init__(self)
self.daemon = True
@@ -138,10 +152,14 @@ class _EventLoggerThread(threading.Thread):
self._flush_secs = flush_secs
# The first event will be flushed immediately.
self._next_event_flush_time = 0
+ self._sentinel_event = sentinel_event
def run(self):
while True:
event = self._queue.get()
+ if event is self._sentinel_event:
+ self._queue.task_done()
+ break
try:
self._ev_writer.WriteEvent(event)
# Flush the event writer every so often.
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index b31c41d112..8c34eb82e3 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -258,6 +258,15 @@ class SummaryWriterTestCase(test.TestCase):
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
+ def testNonBlockingClose(self):
+ test_dir = self._CleanTestDir("non_blocking_close")
+ sw = writer.FileWriter(test_dir)
+ # Sleep 1.2 seconds to make sure event queue is empty.
+ time.sleep(1.2)
+ time_before_close = time.time()
+ sw.close()
+ self._assertRecent(time_before_close)
+
# Checks that values returned from session Run() calls are added correctly to
# summaries. These are numpy types so we need to check they fit in the
# protocol buffers correctly.
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 06ae78ef5d..df9191fb96 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -103,7 +103,7 @@ output layers of the model are. The best source for these is the model training
process, where for a classifier the inputs will be the nodes that receive the
data from the training set, and the output will be the predictions. If you're
unsure, the
-[summarize_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph_main.cc)
+[`summarize_graph`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph_main.cc)
tool can inspect the model and provide guesses about likely input and output nodes,
as well as other information that's useful for debugging. Here's an example of
how to use it on the [Inception V3
@@ -315,7 +315,7 @@ themselves contain commas (for example shape definitions).
The --inputs and --outputs are shared across all transforms, since it's common
to need to know what the ingoing and outgoing nodes in the graph are. You should
make sure you set these correctly before calling the graph transform tool, and
-if you're in doubt check with the model's author, or use the `check_graph` tool
+if you're in doubt check with the model's author, or use the [`summarize_graph`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs) tool
to examine likely inputs and outputs.
All transforms can be passed the `ignore_errors` flag, with the value set to