author A. Unique TensorFlower <nobody@tensorflow.org> 2016-01-29 09:34:18 -0800
committer Vijay Vasudevan <vrv@google.com> 2016-01-29 20:15:13 -0800
commit 8a59748c087a2fee535c0d5067dbabb01920e812 (patch)
tree 179f23b84fb0c47cf17d9551f62e9a6c11c32f79
parent faf747a15d4efc8fff03a10a3fdb37393197c2d3 (diff)
Use cc_binary rather than cc_library to reduce size of native library in APK from 5.5mb to 3.2mb (compressed).
Change: 113369407
-rwxr-xr-x configure 61
-rw-r--r-- tensorflow/core/BUILD 1
-rw-r--r-- tensorflow/core/client/tensor_c_api.cc 92
-rw-r--r-- tensorflow/core/common_runtime/copy_tensor.cc 97
-rw-r--r-- tensorflow/core/common_runtime/direct_session.cc 329
-rw-r--r-- tensorflow/core/common_runtime/direct_session.h 94
-rw-r--r-- tensorflow/core/common_runtime/direct_session_test.cc 135
-rw-r--r-- tensorflow/core/common_runtime/gpu/gpu_util.cc 282
-rw-r--r-- tensorflow/core/framework/op_def_util.cc 21
-rw-r--r-- tensorflow/core/framework/op_def_util.h 4
-rw-r--r-- tensorflow/core/graph/tensor_id.h 8
-rw-r--r-- tensorflow/core/graph/validate.cc 47
-rw-r--r-- tensorflow/core/graph/validate.h 24
-rw-r--r-- tensorflow/core/graph/validate_test.cc 65
-rw-r--r-- tensorflow/core/kernels/matrix_solve_ls_op.cc 4
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op.cc 105
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op.h 204
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc 51
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op_test.cc 78
-rw-r--r-- tensorflow/core/ops/nn_ops.cc 23
-rw-r--r-- tensorflow/core/ops/ops.pbtxt 35
-rw-r--r-- tensorflow/core/platform/default/build_config.bzl 10
-rw-r--r-- tensorflow/core/platform/default/build_config/BUILD 3
-rw-r--r-- tensorflow/core/public/session.h 27
-rw-r--r-- tensorflow/core/public/tensor_c_api.h 36
-rw-r--r-- tensorflow/examples/android/BUILD 32
-rw-r--r-- tensorflow/examples/udacity/README.md 18
-rw-r--r-- tensorflow/g3doc/api_docs/python/array_ops.md 3
-rw-r--r-- tensorflow/g3doc/api_docs/python/index.md 1
-rw-r--r-- tensorflow/g3doc/api_docs/python/nn.md 46
-rw-r--r-- tensorflow/g3doc/get_started/os_setup.md 23
-rw-r--r-- tensorflow/g3doc/how_tos/adding_an_op/index.md 4
-rw-r--r-- tensorflow/python/BUILD 1
-rw-r--r-- tensorflow/python/client/session.py 215
-rw-r--r-- tensorflow/python/client/session_test.py 45
-rw-r--r-- tensorflow/python/client/tf_session.i 37
-rw-r--r-- tensorflow/python/client/tf_session_helper.cc 73
-rw-r--r-- tensorflow/python/client/tf_session_helper.h 31
-rw-r--r-- tensorflow/python/framework/ops.py 2
-rw-r--r-- tensorflow/python/framework/python_op_gen.cc 14
-rw-r--r-- tensorflow/python/kernel_tests/rnn_test.py 37
-rw-r--r-- tensorflow/python/kernel_tests/sparse_xent_op_test.py 256
-rw-r--r-- tensorflow/python/kernel_tests/tensor_array_ops_test.py 15
-rw-r--r-- tensorflow/python/ops/control_flow_ops.py 56
-rw-r--r-- tensorflow/python/ops/nn.py 1
-rw-r--r-- tensorflow/python/ops/nn_grad.py 8
-rw-r--r-- tensorflow/python/ops/nn_ops.py 57
-rw-r--r-- tensorflow/python/ops/rnn_cell.py 70
-rw-r--r-- tensorflow/python/ops/tensor_array_grad.py 10
-rw-r--r-- tensorflow/python/ops/tensor_array_ops.py 22
-rw-r--r-- tensorflow/stream_executor/dso_loader.cc 18
-rw-r--r-- tensorflow/tensorboard/BUILD 8
-rw-r--r-- tensorflow/tensorboard/backend/__init__.py 0
-rw-r--r-- tensorflow/tensorboard/backend/float_wrapper.py (renamed from tensorflow/tensorboard/float_wrapper.py) 0
-rw-r--r-- tensorflow/tensorboard/backend/float_wrapper_test.py (renamed from tensorflow/tensorboard/float_wrapper_test.py) 2
-rw-r--r-- tensorflow/tensorboard/backend/tensorboard.py (renamed from tensorflow/tensorboard/tensorboard.py) 2
-rw-r--r-- tensorflow/tensorboard/backend/tensorboard_handler.py (renamed from tensorflow/tensorboard/tensorboard_handler.py) 2
-rw-r--r-- tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html 1
-rw-r--r-- third_party/gpus/cuda/BUILD 36
-rwxr-xr-x third_party/gpus/cuda/cuda_config.sh 32
60 files changed, 2576 insertions, 438 deletions
diff --git a/configure b/configure
index da5fb25f8f..f217de4e93 100755
--- a/configure
+++ b/configure
@@ -1,5 +1,9 @@
#!/bin/bash
+if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then
+ echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" 1>&2
+fi
+
## Set up python-related environment settings
while true; do
fromuser=""
@@ -44,32 +48,55 @@ fi
# Find out where the CUDA toolkit is installed
while true; do
+ # Configure the Cuda SDK version to use.
+ default_cuda_version="7.0"
+ if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then
+ if [ -z "$TF_CUDA_VERSION" ]; then
+ read -p "Please specify the Cuda SDK version you want to use. [Default is $default_cuda_version]: " TF_CUDA_VERSION
+ fi
+ fi
+ if [ -z "$TF_CUDA_VERSION" ]; then
+ TF_CUDA_VERSION=$default_cuda_version
+ fi
+
fromuser=""
if [ -z "$CUDA_TOOLKIT_PATH" ]; then
default_cuda_path=/usr/local/cuda
- read -p "Please specify the location where CUDA 7.0 toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH
+ read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH
fromuser="1"
if [ -z "$CUDA_TOOLKIT_PATH" ]; then
CUDA_TOOLKIT_PATH=$default_cuda_path
fi
fi
- if [ -e "$CUDA_TOOLKIT_PATH/lib64/libcudart.so.7.0" ]; then
+ if [ -e "$CUDA_TOOLKIT_PATH/lib64/libcudart.so.$TF_CUDA_VERSION" ]; then
break
fi
- echo "Invalid path to CUDA 7.0 toolkit. ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.7.0 cannot be found"
+ echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.$TF_CUDA_VERSION cannot be found"
if [ -z "$fromuser" ]; then
exit 1
fi
+ TF_CUDA_VERSION=""
CUDA_TOOLKIT_PATH=""
# Retry
done
# Find out where the cuDNN library is installed
while true; do
+ # Configure the Cudnn version to use.
+ default_cudnn_version="6.5"
+ if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then
+ if [ -z "$TF_CUDNN_VERSION" ]; then
+ read -p "Please specify the Cudnn version you want to use. [Default is $default_cudnn_version]: " TF_CUDNN_VERSION
+ fi
+ fi
+ if [ -z "$TF_CUDNN_VERSION" ]; then
+ TF_CUDNN_VERSION=$default_cudnn_version
+ fi
+
fromuser=""
if [ -z "$CUDNN_INSTALL_PATH" ]; then
default_cudnn_path=${CUDA_TOOLKIT_PATH}
- read -p "Please specify the location where cuDNN v2 library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH
+ read -p "Please specify the location where cuDNN $TF_CUDNN_VERSION library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH
fromuser="1"
if [ -z "$CUDNN_INSTALL_PATH" ]; then
CUDNN_INSTALL_PATH=$default_cudnn_path
@@ -78,21 +105,22 @@ while true; do
# Going through one more level of expansion to handle that.
CUDNN_INSTALL_PATH=$(bash -c "readlink -f $CUDNN_INSTALL_PATH")
fi
- if [ -e "$CUDNN_INSTALL_PATH/libcudnn.so.6.5" -o -e "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.6.5" ]; then
+ if [ -e "$CUDNN_INSTALL_PATH/libcudnn.so.${TF_CUDNN_VERSION}" -o -e "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.${TF_CUDNN_VERSION}" ]; then
break
fi
- echo "Invalid path to cuDNN v2 toolkit. Neither of the following two files can be found:"
- echo "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.6.5"
- echo "$CUDNN_INSTALL_PATH/libcudnn.so.6.5"
+ echo "Invalid path to cuDNN ${TF_CUDNN_VERSION} toolkit. Neither of the following two files can be found:"
+ echo "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.${TF_CUDNN_VERSION}"
+ echo "$CUDNN_INSTALL_PATH/libcudnn.so.${TF_CUDNN_VERSION}"
if [ -z "$fromuser" ]; then
exit 1
fi
+ TF_CUDNN_VERSION=""
CUDNN_INSTALL_PATH=""
# Retry
done
cat > third_party/gpus/cuda/cuda.config <<EOF
-# CUDA_TOOLKIT_PATH refers to the CUDA toolkit. Tensorflow requires Cuda 7.0
+# CUDA_TOOLKIT_PATH refers to the CUDA toolkit. Tensorflow requires Cuda $TF_CUDA_VERSION
# at the moment.
CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
@@ -100,10 +128,23 @@ CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
# files can be either in this directory, or under include/ and lib64/
# directories separately.
CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
+
+# The Cuda SDK version that should be used in this build
+TF_CUDA_VERSION=$TF_CUDA_VERSION
+
+# The Cudnn version that should be used in this build
+TF_CUDNN_VERSION=$TF_CUDNN_VERSION
+
EOF
function UnofficialSetting() {
- echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" 1>&2
+ # Configure the Cuda toolkit version to work with.
+ perl -pi -e "s,CUDA_VERSION = '[0-9\.]*',CUDA_VERSION = '$TF_CUDA_VERSION',s" tensorflow/core/platform/default/build_config.bzl
+ perl -pi -e "s,(GetCudaVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDA_VERSION\",s" tensorflow/stream_executor/dso_loader.cc
+
+ # Configure the Cudnn version to work with.
+ perl -pi -e "s,CUDNN_VERSION = '[0-9\.]*',CUDNN_VERSION = '$TF_CUDNN_VERSION',s" tensorflow/core/platform/default/build_config.bzl
+ perl -pi -e "s,(GetCudnnVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDNN_VERSION\",s" tensorflow/stream_executor/dso_loader.cc
# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
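With this change, an unofficial toolchain can be selected non-interactively, e.g. TF_UNOFFICIAL_SETTING=1 TF_CUDA_VERSION=7.5 TF_CUDNN_VERSION=4 ./configure (the version numbers here are only illustrative); when TF_UNOFFICIAL_SETTING is unset, the script keeps the officially supported defaults of Cuda 7.0 and Cudnn 6.5.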
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 95d5531876..cb4d5e723c 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -298,6 +298,7 @@ tf_cuda_library(
"graph/graph_constructor.h",
"graph/graph_def_builder.h",
"graph/node_builder.h",
+ "graph/validate.h",
"public/session.h",
"public/session_options.h",
"public/tensor_c_api.h",
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index 2140fc8afd..853c309091 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -316,16 +316,16 @@ Status LoadLibrary(const char* library_filename, void** result,
} // namespace tensorflow
-extern "C" {
-
-void TF_Run(TF_Session* s,
- // Input tensors
- const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
- // Output tensors
- const char** c_output_tensor_names, TF_Tensor** c_outputs,
- int noutputs,
- // Target nodes
- const char** c_target_node_names, int ntargets, TF_Status* status) {
+void TF_Run_Helper(TF_Session* s, const char* handle,
+ // Input tensors
+ const char** c_input_names, TF_Tensor** c_inputs,
+ int ninputs,
+ // Output tensors
+ const char** c_output_tensor_names, TF_Tensor** c_outputs,
+ int noutputs,
+ // Target nodes
+ const char** c_target_node_names, int ntargets,
+ TF_Status* status) {
status->status = Status::OK();
for (int i = 0; i < noutputs; i++) {
c_outputs[i] = NULL;
@@ -365,8 +365,13 @@ void TF_Run(TF_Session* s,
for (int i = 0; i < ntargets; i++) {
target_node_names[i] = c_target_node_names[i];
}
- Status result =
- s->session->Run(inputs, output_tensor_names, target_node_names, &outputs);
+ Status result;
+ if (handle == nullptr) {
+ result = s->session->Run(inputs, output_tensor_names, target_node_names,
+ &outputs);
+ } else {
+ result = s->session->PRun(handle, inputs, output_tensor_names, &outputs);
+ }
if (!result.ok()) {
status->status = result;
return;
@@ -392,6 +397,69 @@ void TF_Run(TF_Session* s,
}
}
+extern "C" {
+
+void TF_Run(TF_Session* s,
+ // Input tensors
+ const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
+ // Output tensors
+ const char** c_output_tensor_names, TF_Tensor** c_outputs,
+ int noutputs,
+ // Target nodes
+ const char** c_target_node_names, int ntargets, TF_Status* status) {
+ TF_Run_Helper(s, nullptr, c_input_names, c_inputs, ninputs,
+ c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
+ ntargets, status);
+}
+
+void TF_PRunSetup(TF_Session* s,
+ // Input names
+ const char** c_input_names, int ninputs,
+ // Output names
+ const char** c_output_tensor_names, int noutputs,
+ // Target nodes
+ const char** c_target_node_names, int ntargets, char** handle,
+ TF_Status* status) {
+ status->status = Status::OK();
+
+ std::vector<tensorflow::string> input_names(ninputs);
+ std::vector<tensorflow::string> output_tensor_names(noutputs);
+ std::vector<tensorflow::string> target_node_names(ntargets);
+ for (int i = 0; i < ninputs; i++) {
+ input_names[i] = c_input_names[i];
+ }
+ for (int i = 0; i < noutputs; i++) {
+ output_tensor_names[i] = c_output_tensor_names[i];
+ }
+ for (int i = 0; i < ntargets; i++) {
+ target_node_names[i] = c_target_node_names[i];
+ }
+ tensorflow::string new_handle;
+ Status result;
+ result = s->session->PRunSetup(input_names, output_tensor_names,
+ target_node_names, &new_handle);
+ if (result.ok()) {
+ *handle = new char[new_handle.size() + 1];
+ memcpy(*handle, new_handle.c_str(), new_handle.size() + 1);
+ } else {
+ status->status = result;
+ }
+}
+
+void TF_PRun(TF_Session* s, const char* handle,
+ // Input tensors
+ const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
+ // Output tensors
+ const char** c_output_tensor_names, TF_Tensor** c_outputs,
+ int noutputs,
+ // Target nodes
+ const char** c_target_node_names, int ntargets,
+ TF_Status* status) {
+ TF_Run_Helper(s, handle, c_input_names, c_inputs, ninputs,
+ c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
+ ntargets, status);
+}
+
const void* TF_BufferData(TF_Buffer* buffer) { return buffer->data; }
size_t TF_BufferLength(TF_Buffer* buffer) { return buffer->length; }
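A minimal sketch, not part of this commit, of driving the new partial-run C API; it assumes the session already holds a graph in which "in:0" and "out:0" name real tensors and that `feed` was created with TF_NewTensor:

#include "tensorflow/core/public/tensor_c_api.h"

void PartialRunSketch(TF_Session* session, TF_Tensor* feed) {
  TF_Status* status = TF_NewStatus();
  const char* feeds[] = {"in:0"};
  const char* fetches[] = {"out:0"};
  char* handle = nullptr;
  // Declare the full set of feeds and fetches up front; execution starts now.
  TF_PRunSetup(session, feeds, 1, fetches, 1, nullptr, 0, &handle, status);
  if (TF_GetCode(status) == TF_OK) {
    // Feed and fetch incrementally against the same handle.
    TF_Tensor* fetched = nullptr;
    TF_PRun(session, handle, feeds, &feed, 1, fetches, &fetched, 1, nullptr, 0,
            status);
    if (fetched != nullptr) TF_DeleteTensor(fetched);
    delete[] handle;  // TF_PRunSetup allocated the handle with new char[].
  }
  TF_DeleteStatus(status);
}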
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index e0a5fef6d0..00f3f17d78 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -53,58 +53,57 @@ void CopyTensor::ViaDMA(const string& edge_name,
StatusCallback done) {
initialization_done = true;
port::Tracing::ScopedAnnotation annotation(edge_name);
- VLOG(1) << "CopyViaDMA " << edge_name;
- const size_t total_bytes = input->TotalBytes();
-
- // Note that 0-size tensors have no backing buffer.
- if (total_bytes > 0) {
- const DeviceType src_device_type(src_alloc_attr.on_host()
- ? DEVICE_CPU
- : src->attributes().device_type());
- const DeviceType dst_device_type(dst_alloc_attr.on_host()
- ? DEVICE_CPU
- : dst->attributes().device_type());
- const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
- const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
-
- if (non_cpu_src) {
- if (non_cpu_dst) {
- // Device to device copy. Look through registry for an appropriate
- // CopyFunction.
- std::vector<RegistrationInfo>* registry = MutableRegistry();
- for (const RegistrationInfo& ri : *registry) {
- if (ri.sender_device_type == src_device_type &&
- ri.receiver_device_type == dst_device_type) {
- ri.copy_function(send_dev_context, recv_dev_context, src, dst,
- src_alloc_attr, dst_alloc_attr, input, output,
- done);
- return;
- }
- }
-
- // TODO(josh11b): If no CopyFunction is found, we currently fail
- // but we could copy between devices via CPU.
- done(errors::Unimplemented(
- "No function registered to copy from devices of type ",
- src_device_type.type(), " to devices of type ",
- dst_device_type.type()));
- } else {
- // Device to host copy.
- return send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src,
- output, done);
+ VLOG(1) << "Copy " << edge_name;
+
+ const DeviceType src_device_type(
+ src_alloc_attr.on_host() ? DEVICE_CPU : src->attributes().device_type());
+ const DeviceType dst_device_type(
+ dst_alloc_attr.on_host() ? DEVICE_CPU : dst->attributes().device_type());
+ const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
+ const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
+
+ // E.g., gpu -> gpu
+ if (non_cpu_src && non_cpu_dst) {
+ // Device to device copy. Look through registry for an appropriate
+ // CopyFunction.
+ std::vector<RegistrationInfo>* registry = MutableRegistry();
+ for (const RegistrationInfo& ri : *registry) {
+ if (ri.sender_device_type == src_device_type &&
+ ri.receiver_device_type == dst_device_type) {
+ ri.copy_function(send_dev_context, recv_dev_context, src, dst,
+ src_alloc_attr, dst_alloc_attr, input, output, done);
+ return;
}
- } else if (non_cpu_dst) {
- // Host to Device copy.
- // Note that this is already an async copy.
- recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done);
- } else {
- *output = *input;
- done(Status::OK());
}
- } else {
- // buffer is empty
- done(Status::OK());
+
+ // TODO(josh11b): If no CopyFunction is found, we currently fail
+ // but we could copy between devices via CPU.
+ done(errors::Unimplemented(
+ "No function registered to copy from devices of type ",
+ src_device_type.type(), " to devices of type ",
+ dst_device_type.type()));
+ return;
+ }
+
+ // E.g., gpu -> cpu
+ if (non_cpu_src && !non_cpu_dst) {
+ // Device to host copy.
+ send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
+ done);
+ return;
}
+
+ // E.g., cpu -> gpu
+ if (!non_cpu_src && non_cpu_dst) {
+ // Host to Device copy.
+ recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done);
+ return;
+ }
+
+ // cpu -> cpu
+ CHECK(!non_cpu_src && !non_cpu_dst);
+ *output = *input;
+ done(Status::OK());
}
// static
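As a hedged sketch of the registry consulted above (the "FOO" device type and the empty copy body are hypothetical; the real GPU->GPU registration appears in gpu_util.cc later in this diff):

#include "tensorflow/core/common_runtime/copy_tensor.h"

namespace tensorflow {

void FooToFooCopy(DeviceContext* send_dev_context,
                  DeviceContext* recv_dev_context, Device* src, Device* dst,
                  const AllocatorAttributes src_alloc_attr,
                  const AllocatorAttributes dst_alloc_attr, const Tensor* input,
                  Tensor* output, StatusCallback done) {
  // Schedule the FOO->FOO DMA here, then signal completion (possibly async).
  done(Status::OK());
}

// Runs at static-initialization time; ViaDMA's loop will then find this entry
// whenever both endpoints report device type "FOO".
static CopyTensor::Registration register_foo_foo_copy("FOO", "FOO",
                                                      FooToFooCopy);

}  // namespace tensorflow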
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 199916b3c4..ae8f934c20 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/framework/function.h"
@@ -156,6 +155,9 @@ DirectSession::DirectSession(const SessionOptions& options,
}
DirectSession::~DirectSession() {
+ for (auto& it : partial_runs_) {
+ delete it.second;
+ }
for (auto d : device_mgr_->ListDevices()) {
d->op_segment()->RemoveHold(session_handle_);
}
@@ -240,7 +242,8 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
return Status::OK();
}
-Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
+// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()".
+Status DirectSession::Run(const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs) {
@@ -261,49 +264,213 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
- Status s = GetOrCreateExecutors(input_tensor_names, output_names,
- target_nodes, &executors_and_keys);
- if (!s.ok()) {
- return s;
+ RunStateArgs run_state_args;
+ TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
+ target_nodes, &executors_and_keys,
+ &run_state_args));
+
+ // Create a run state and start execution.
+ RunState run_state(input_tensor_names, output_names);
+ run_state.graph = run_state_args.graph;
+ run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
+
+ // Send inputs.
+ TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
+
+ // Start parallel Executors.
+ const int num_executors = executors_and_keys->items.size();
+ ExecutorBarrier* barrier = new ExecutorBarrier(
+ num_executors, run_state.rendez, [&run_state](const Status& ret) {
+ run_state.status = ret;
+ run_state.executors_done.Notify();
+ });
+
+ Executor::Args args;
+ args.rendezvous = run_state.rendez;
+ args.cancellation_manager = cancellation_manager_;
+ args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
+
+ for (const auto& item : executors_and_keys->items) {
+ item.executor->RunAsync(args, barrier->Get());
}
- IntraProcessRendezvous* rendez =
- new IntraProcessRendezvous(device_mgr_.get());
- core::ScopedUnref rendez_unref(rendez);
+ run_state.executors_done.WaitForNotification();
+ TF_RETURN_IF_ERROR(run_state.status);
- // Insert the input tensors into the local rendezvous by their
- // rendezvous key.
- for (const auto& input : inputs) {
- const string& input_key = executors_and_keys->input_keys[input.first];
- s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
- if (!s.ok()) {
- rendez->StartAbort(s);
- return s;
+ // Receive outputs.
+ TF_RETURN_IF_ERROR(
+ RecvOutputs(output_names, executors_and_keys, &run_state, outputs));
+ return Status::OK();
+}
+
+Status DirectSession::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ {
+ mutex_lock l(graph_def_lock_);
+ if (!graph_created_) {
+ return errors::InvalidArgument(
+ "Session was not created with a graph before PRunSetup()!");
+ }
+ }
+
+ // Check if we already have an executor for these arguments.
+ ExecutorsAndKeys* executors_and_keys;
+ RunStateArgs run_state_args;
+ run_state_args.is_partial_run = true;
+ Status s = GetOrCreateExecutors(input_names, output_names, target_nodes,
+ &executors_and_keys, &run_state_args);
+ TF_RETURN_IF_ERROR(s);
+
+ // Create the run state and save it for future PRun calls.
+ RunState* run_state = new RunState(input_names, output_names);
+ run_state->graph = run_state_args.graph;
+ run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
+ {
+ mutex_lock l(executor_lock_);
+ if (!partial_runs_.insert({run_state_args.handle, run_state}).second) {
+ return errors::Internal("The handle ", run_state_args.handle,
+ " created for this partial"
+ " run is not unique.");
}
}
// Start parallel Executors.
- Notification executors_done;
+ Notification& executors_done = run_state->executors_done;
+ Status* run_status = &run_state->status;
const int num_executors = executors_and_keys->items.size();
ExecutorBarrier* barrier = new ExecutorBarrier(
- num_executors, rendez, [&executors_done, &s](const Status& ret) {
- s = ret;
+ num_executors, run_state->rendez,
+ [&executors_done, run_status, this](const Status& ret) {
+ if (!ret.ok()) {
+ mutex_lock l(executor_lock_);
+ *run_status = ret;
+ }
executors_done.Notify();
});
Executor::Args args;
- args.rendezvous = rendez;
+ args.rendezvous = run_state->rendez;
args.cancellation_manager = cancellation_manager_;
args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
- for (const auto& item : executors_and_keys->items) {
- item.executor->RunAsync(args, barrier->Get());
+ for (auto& item : executors_and_keys->items) {
+ Executor* exec = item.executor;
+ exec->RunAsync(args, barrier->Get());
}
- executors_done.WaitForNotification();
+ *handle = run_state_args.handle;
+ return Status::OK();
+}
- TF_RETURN_IF_ERROR(s);
+Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ std::vector<string> parts = str_util::Split(handle, ';');
+ const string& key = parts[0];
+ // Get the executors for this partial run.
+ ExecutorsAndKeys* executors_and_keys;
+ RunState* run_state;
+ {
+ mutex_lock l(executor_lock_); // could use reader lock
+ auto exc_it = executors_.find(key);
+ if (exc_it == executors_.end()) {
+ return errors::InvalidArgument(
+ "Must run 'setup' before performing partial runs!");
+ }
+ executors_and_keys = exc_it->second;
+
+ auto prun_it = partial_runs_.find(handle);
+ if (prun_it == partial_runs_.end()) {
+ return errors::InvalidArgument(
+ "Must run 'setup' before performing partial runs!");
+ }
+ run_state = prun_it->second;
+
+ // Make sure that this is a new set of feeds that are still pending.
+ for (const auto& input : inputs) {
+ auto it = run_state->pending_inputs.find(input.first);
+ if (it == run_state->pending_inputs.end()) {
+ return errors::InvalidArgument("The feed ", input.first,
+ " had already been fed.");
+ }
+ }
+ // Check that this is a new set of fetches that are still pending.
+ for (const auto& output : output_names) {
+ auto it = run_state->pending_outputs.find(output);
+ if (it == run_state->pending_outputs.end()) {
+ return errors::InvalidArgument("The fetch ", output,
+ " had already been fetched.");
+ }
+ }
+ }
+ // Check that this new set of fetches can be computed from all the
+ // feeds we have supplied.
+ TF_RETURN_IF_ERROR(CheckFetch(inputs, output_names, run_state));
+
+ // Send inputs.
+ Status s = SendInputs(inputs, executors_and_keys, run_state->rendez);
+
+ // Receive outputs.
+ if (s.ok()) {
+ s = RecvOutputs(output_names, executors_and_keys, run_state, outputs);
+ }
+
+ // Delete the run state if there is an error or all fetches are done.
+ {
+ mutex_lock l(executor_lock_);
+ bool done = true;
+ if (s.ok()) {
+ if (!run_state->status.ok()) {
+ LOG(WARNING) << "An error unrelated to this prun has been detected. "
+ << run_state->status;
+ }
+ for (const auto& it : inputs) {
+ run_state->pending_inputs.erase(it.first);
+ }
+ for (const auto& name : output_names) {
+ run_state->pending_outputs.erase(name);
+ }
+ done = run_state->pending_outputs.size() == 0;
+ }
+ if (done) {
+ run_state->executors_done.WaitForNotification();
+ partial_runs_.erase(handle);
+ delete run_state;
+ }
+ }
+ return s;
+}
+
+Status DirectSession::SendInputs(const NamedTensorList& inputs,
+ const ExecutorsAndKeys* executors_and_keys,
+ IntraProcessRendezvous* rendez) {
+ Status s;
+ // Insert the input tensors into the local rendezvous by their
+ // rendezvous key.
+ for (const auto& input : inputs) {
+ auto it = executors_and_keys->input_keys.find(input.first);
+ if (it == executors_and_keys->input_keys.end()) {
+ return errors::InvalidArgument("'", input.first,
+ "' is not a pre-defined feed!");
+ }
+ const string& input_key = it->second;
+ s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
+ if (!s.ok()) {
+ rendez->StartAbort(s);
+ return s;
+ }
+ }
+ return Status::OK();
+}
+
+Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
+ const ExecutorsAndKeys* executors_and_keys,
+ RunState* run_state,
+ std::vector<Tensor>* outputs) {
+ Status s;
if (!output_names.empty()) {
outputs->resize(output_names.size());
}
@@ -311,14 +478,21 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
// Get the outputs from the rendezvous
for (size_t output_offset = 0; output_offset < output_names.size();
++output_offset) {
- const string& output_key =
- executors_and_keys->output_keys[output_names[output_offset]];
+ const string& output_name = output_names[output_offset];
+ auto it = executors_and_keys->output_keys.find(output_name);
+ if (it == executors_and_keys->output_keys.end()) {
+ return errors::InvalidArgument("'", output_name,
+ "' was not defined as a fetch"
+ " target in PRunSetup.");
+ }
+ const string& output_key = it->second;
Tensor output_tensor;
bool is_dead;
// Fetch data from the Rendezvous.
+ IntraProcessRendezvous* rendez = run_state->rendez;
s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
- if (is_dead) {
+ if (is_dead && s.ok()) {
s = errors::InvalidArgument("The tensor returned for ",
output_names[output_offset],
" was not valid.");
@@ -331,14 +505,74 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
(*outputs)[output_offset] = output_tensor;
}
+ return Status::OK();
+}
- return s;
+Status DirectSession::CheckFetch(const NamedTensorList& feeds,
+ const std::vector<string>& fetches,
+ const RunState* run_state) {
+ const Graph* g = run_state->graph;
+ std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node;
+ for (Node* n : g->nodes()) {
+ name_to_node[n->name()] = n;
+ }
+
+ // Build the set of pending feeds that we haven't seen.
+ std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
+ {
+ mutex_lock l(executor_lock_);
+ for (const string& feed : run_state->pending_inputs) {
+ TensorId id(ParseTensorName(feed));
+ auto it = name_to_node.find(id.first);
+ if (it == name_to_node.end()) {
+ return errors::NotFound("Feed ", feed, ": not found");
+ }
+ pending_feeds.insert(id);
+ }
+ }
+ for (const auto& it : feeds) {
+ TensorId id(ParseTensorName(it.first));
+ pending_feeds.erase(id);
+ }
+
+ // Initialize the stack with the fetch nodes.
+ std::vector<const Node*> stack;
+ for (const string& fetch : fetches) {
+ TensorId id(ParseTensorName(fetch));
+ auto it = name_to_node.find(id.first);
+ if (it == name_to_node.end()) {
+ return errors::NotFound("Fetch ", fetch, ": not found");
+ }
+ stack.push_back(it->second);
+ }
+
+ // Any tensor needed for fetches can't be in pending_feeds.
+ std::vector<bool> visited(g->num_node_ids(), false);
+ while (!stack.empty()) {
+ const Node* n = stack.back();
+ stack.pop_back();
+
+ for (const Edge* in_edge : n->in_edges()) {
+ const Node* in_node = in_edge->src();
+ if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
+ return errors::InvalidArgument("Fetch ", in_node->name(), ":",
+ in_edge->src_output(),
+ " can't be computed from the feeds"
+ " that have been fed so far.");
+ }
+ if (!visited[in_node->id()]) {
+ visited[in_node->id()] = true;
+ stack.push_back(in_node);
+ }
+ }
+ }
+ return Status::OK();
}
Status DirectSession::GetOrCreateExecutors(
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
- gtl::ArraySlice<string> target_nodes,
- ExecutorsAndKeys** executors_and_keys) {
+ gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
+ RunStateArgs* run_state_args) {
// Sort the inputs and outputs, so we don't create separate
// executors when a user passes in the same inputs/outputs in
// different orders.
@@ -370,10 +604,9 @@ Status DirectSession::GetOrCreateExecutors(
// being created.
FunctionLibraryDefinition* fdefs;
std::unordered_map<string, Graph*> graphs;
- Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs);
- if (!s.ok()) {
- return s;
- }
+ Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs,
+ run_state_args);
+ TF_RETURN_IF_ERROR(s);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
ek->func_defs = fdefs;
@@ -386,9 +619,7 @@ Status DirectSession::GetOrCreateExecutors(
Device* device;
s = device_mgr_->LookupDevice(partition_name, &device);
- if (!s.ok()) {
- return s;
- }
+ TF_RETURN_IF_ERROR(s);
ek->items.resize(ek->items.size() + 1);
auto* item = &(ek->items.back());
@@ -434,6 +665,12 @@ Status DirectSession::GetOrCreateExecutors(
output, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
}
+ // Set the handle.
+ {
+ mutex_lock l(mu_);
+ run_state_args->handle = strings::StrCat(key, ";", name_counter_++);
+ }
+
// Reacquire the lock, try to insert into the map.
mutex_lock l(executor_lock_);
const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second;
@@ -470,10 +707,12 @@ void DirectSession::RestoreStatefulNodes(Graph* graph) {
}
}
-Status DirectSession::CreateGraphs(
- gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches,
- gtl::ArraySlice<string> target_nodes, FunctionLibraryDefinition** func_defs,
- std::unordered_map<string, Graph*>* outputs) {
+Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
+ gtl::ArraySlice<string> fetches,
+ gtl::ArraySlice<string> target_nodes,
+ FunctionLibraryDefinition** func_defs,
+ std::unordered_map<string, Graph*>* outputs,
+ RunStateArgs* run_state_args) {
std::unique_ptr<FunctionLibraryDefinition> fdefs;
std::unique_ptr<Graph> graph;
GraphConstructorOptions opts{
@@ -511,6 +750,12 @@ Status DirectSession::CreateGraphs(
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, graph.get()));
}
+ // Remember the graph in run state if this is a partial run.
+ if (run_state_args->is_partial_run) {
+ run_state_args->graph = new Graph(fdefs.get());
+ CopyGraph(*graph.get(), run_state_args->graph);
+ }
+
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
graph.get(), feeds, fetches, target_nodes,
device_set_.client_device()->attributes()));
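Note how the pieces fit together: GetOrCreateExecutors mints the partial-run handle as the sorted feed/fetch cache key plus a per-session counter ("<key>;<N>"), so PRun can split the handle on ';' to look up the cached executors by prefix while the full handle keys the RunState in partial_runs_.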
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 178dfe8851..0a7430ff04 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -19,11 +19,13 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
+#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
@@ -46,12 +48,24 @@ class DirectSession : public Session {
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
~DirectSession() override;
+ typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
+
::tensorflow::Status Create(const GraphDef& graph) override;
::tensorflow::Status Extend(const GraphDef& graph) override;
- ::tensorflow::Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
+ ::tensorflow::Status Run(const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs) override;
+
+ // NOTE: PRunSetup and PRun are added to support partial execution. This
+ // feature is experimental and subject to change.
+ ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
+ ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
::tensorflow::Status Close() override;
private:
@@ -85,24 +99,86 @@ class DirectSession : public Session {
}
};
+ // For each live partial execution, the session maintains a RunState.
+ // 'status' is the current status of this partial execution. 'executors_done'
+ // is "notified" when all executors are done. 'graph' is the graph being
+ // executed. 'pending_inputs' is the set of pending feeds and
+ // 'pending_outputs' is the set of pending fetches.
+ struct RunState {
+ Status status;
+ IntraProcessRendezvous* rendez = nullptr;
+ Notification executors_done;
+ Graph* graph = nullptr;
+ std::unordered_set<string> pending_inputs;
+ std::unordered_set<string> pending_outputs;
+
+ RunState(const std::vector<string>& input_names,
+ const std::vector<string>& output_names) {
+ // Initially all the feeds and fetches are pending.
+ for (auto& name : input_names) {
+ pending_inputs.emplace(name);
+ }
+ for (auto& name : output_names) {
+ pending_outputs.emplace(name);
+ }
+ }
+
+ ~RunState() {
+ if (rendez != nullptr) {
+ if (!executors_done.HasBeenNotified()) {
+ rendez->StartAbort(errors::Cancelled("PRun cancellation"));
+ executors_done.WaitForNotification();
+ }
+ rendez->Unref();
+ }
+ delete graph;
+ }
+ };
+
+ struct RunStateArgs {
+ bool is_partial_run = false;
+ string handle;
+ Graph* graph = nullptr;
+ };
+
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
::tensorflow::Status GetOrCreateExecutors(
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> target_nodes,
- ExecutorsAndKeys** executors_and_keys);
+ ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
// Creates several graphs given the existing graph_def_ and the
// input feeds and fetches, given 'devices'.
- ::tensorflow::Status CreateGraphs(
- gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches,
- gtl::ArraySlice<string> target_nodes,
- FunctionLibraryDefinition** func_defs,
- std::unordered_map<string, Graph*>* outputs);
+ ::tensorflow::Status CreateGraphs(gtl::ArraySlice<string> feeds,
+ gtl::ArraySlice<string> fetches,
+ gtl::ArraySlice<string> target_nodes,
+ FunctionLibraryDefinition** func_defs,
+ std::unordered_map<string, Graph*>* outputs,
+ RunStateArgs* run_state_args);
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ // Feeds more inputs to the executors, triggering further execution.
+ ::tensorflow::Status SendInputs(
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const ExecutorsAndKeys* executors_and_keys,
+ IntraProcessRendezvous* rendez);
+
+ // Fetches more outputs from the executors. It waits until the output
+ // tensors are computed.
+ ::tensorflow::Status RecvOutputs(const std::vector<string>& output_names,
+ const ExecutorsAndKeys* executors_and_keys,
+ RunState* run_state,
+ std::vector<Tensor>* outputs);
+
+ // Check if the specified fetches can be computed from the feeds
+ // that we have already provided.
+ ::tensorflow::Status CheckFetch(
+ const std::vector<std::pair<string, Tensor>>& feeds,
+ const std::vector<string>& fetches, const RunState* run_state);
+
const SessionOptions options_;
// Device structures.
@@ -129,6 +205,10 @@ class DirectSession : public Session {
std::unordered_map<string, ExecutorsAndKeys*> executors_
GUARDED_BY(executor_lock_);
+ // Holds mappings from handle to partial run state.
+ std::unordered_map<string, RunState*> partial_runs_
+ GUARDED_BY(executor_lock_);
+
CancellationManager* cancellation_manager_;
// Saves and restores device placements for stateful nodes.
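A minimal sketch, not from this commit, of the experimental flow on the public Session interface; "a:0", "b:0" and "c:0" are hypothetical tensor names in a graph the session is assumed to already hold:

#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

tensorflow::Status RunInTwoSteps(tensorflow::Session* session,
                                 const tensorflow::Tensor& a,
                                 const tensorflow::Tensor& b) {
  tensorflow::string handle;
  TF_RETURN_IF_ERROR(session->PRunSetup({"a:0", "b:0"}, {"c:0"}, {}, &handle));
  std::vector<tensorflow::Tensor> outputs;
  // First step: feed "a:0" only and fetch nothing; "c:0" stays pending.
  TF_RETURN_IF_ERROR(session->PRun(handle, {{"a:0", a}}, {}, &outputs));
  // Second step: feed "b:0" and fetch "c:0". Once every declared fetch has
  // been consumed, the session deletes the run state for this handle.
  return session->PRun(handle, {{"b:0", b}}, {"c:0"}, &outputs);
}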
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 9109da76dd..68009428f8 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -298,8 +298,6 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
TEST(DirectSessionTest, MultipleFeedTest) {
GraphDef def;
Graph g(OpRegistry::Global());
- Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
- var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
Tensor first_value(DT_FLOAT, TensorShape({}));
first_value.scalar<float>()() = 1.0;
@@ -396,5 +394,138 @@ TEST(DirectSessionTest, DarthKernel) {
delete sess;
}
+TEST(DirectSessionTest, PartialRunTest) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+
+ Tensor first_value(DT_FLOAT, TensorShape({}));
+ first_value.scalar<float>()() = 1.0;
+ Node* first_const = test::graph::Constant(&g, first_value);
+ Node* first_identity = test::graph::Identity(&g, first_const);
+
+ Tensor second_value(DT_FLOAT, TensorShape({}));
+ second_value.scalar<float>()() = 2.0;
+ Node* second_const = test::graph::Constant(&g, second_value);
+ Node* second_identity = test::graph::Identity(&g, second_const);
+
+ Node* third = test::graph::Add(&g, first_identity, second_identity);
+ Node* third_identity = test::graph::Identity(&g, third);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<Tensor> outputs;
+
+ string handle;
+ Status s = session->PRunSetup(
+ {first_const->name(), second_const->name()},
+ {first_identity->name() + ":0", second_identity->name() + ":0",
+ third_identity->name() + ":0"},
+ {}, &handle);
+ ASSERT_TRUE(s.ok());
+
+ Tensor value_11(DT_FLOAT, TensorShape({}));
+ value_11.scalar<float>()() = 11.0;
+ Tensor value_22(DT_FLOAT, TensorShape({}));
+ value_22.scalar<float>()() = 22.0;
+
+ // Feed first_const, fetch first_identity
+ s = session->PRun(handle, {{first_const->name(), value_11}},
+ {first_identity->name() + ":0"}, &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+
+ // Feed second_const, fetch second_identity and third_identity
+ s = session->PRun(
+ handle, {{second_const->name(), value_22}},
+ {second_identity->name() + ":0", third_identity->name() + ":0"},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(22.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(11.0 + 22.0, outputs[1].flat<float>()(0));
+}
+
+TEST(DirectSessionTest, PartialRunMissingFeed) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+
+ Tensor first_value(DT_FLOAT, TensorShape({}));
+ first_value.scalar<float>()() = 1.0;
+ Node* first_const = test::graph::Constant(&g, first_value);
+ Node* first_identity = test::graph::Identity(&g, first_const);
+
+ Tensor second_value(DT_FLOAT, TensorShape({}));
+ second_value.scalar<float>()() = 2.0;
+ Node* second_const = test::graph::Constant(&g, second_value);
+ Node* second_identity = test::graph::Identity(&g, second_const);
+
+ Node* third = test::graph::Add(&g, first_identity, second_identity);
+ Node* third_identity = test::graph::Identity(&g, third);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<Tensor> outputs;
+
+ string handle;
+ Status s = session->PRunSetup({first_const->name(), second_const->name()},
+ {third_identity->name() + ":0"}, {}, &handle);
+ ASSERT_TRUE(s.ok());
+
+ // Feed first_const, fetch third_identity
+ Tensor value_11(DT_FLOAT, TensorShape({}));
+ value_11.scalar<float>()() = 11.0;
+ s = session->PRun(handle, {{first_const->name(), value_11}},
+ {third_identity->name() + ":0"}, &outputs);
+ ASSERT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("can't be computed from the feeds"));
+}
+
+TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+
+ Tensor bool_value(DT_BOOL, TensorShape({}));
+ bool_value.scalar<bool>()() = true;
+ Node* bool_const = test::graph::Constant(&g, bool_value);
+ Node* switch_node = test::graph::Switch(&g, bool_const, bool_const);
+ Node* fourth_identity = test::graph::Identity(&g, switch_node, 1);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<Tensor> outputs;
+
+ string handle;
+ Status s = session->PRunSetup({switch_node->name() + ":1"},
+ {fourth_identity->name() + ":0"}, {}, &handle);
+ ASSERT_TRUE(s.ok());
+
+ // Fetch fourth_identity without feeds.
+ s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
+ ASSERT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("can't be computed from the feeds"));
+
+ // Feed switch_node:1 and fetch fourth_identity.
+ s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
+ {fourth_identity->name() + ":0"}, &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_EQ(true, outputs[0].flat<bool>()(0));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index 2accf92503..f34ac256d1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -37,6 +37,18 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/util.h"
+// IMPLEMENTATION NOTE:
+//
+// 1. Within this module, we intentionally LOG(FATAL) if any stream
+// involved in memcpy becomes !stream->ok(), because the TF process
+// today (1/2016) cannot properly recover from such an error.
+//
+// 2. When a 0-size tensor is being copied, we should not schedule a
+// ThenMemcpy since there are no bytes to move. However, we must
+// preserve causal ordering by arranging for the copy-done callback to
+// happen after all activities already scheduled on the given stream
+// have finished.
+
// If this need to be runtime configurable, consider adding options to
// ConfigProto.
const tensorflow::int64 FLAGS_brain_gpu_util_debug_string_maxlen = 128;
@@ -50,60 +62,106 @@ namespace tensorflow {
namespace gpu = ::perftools::gputools;
+Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src,
+ const Tensor* dst,
+ const DeviceBase::GpuDeviceInfo** dev_info,
+ gpu::Stream** stream) {
+ if (device == nullptr) {
+ return errors::Internal("Unexpected null device.");
+ }
+ auto di = device->tensorflow_gpu_device_info();
+ if (di == nullptr) {
+ return errors::Internal("Unexpected null device info.");
+ }
+ *dev_info = di;
+ if (ctx == nullptr) {
+ return errors::Internal("Unexpected null device context.");
+ }
+ auto gs = static_cast<const GPUDeviceContext*>(ctx)->stream();
+ if (gs == nullptr) {
+ return errors::Internal("No gpu stream is available.");
+ }
+ *stream = gs;
+ if (dst != nullptr) {
+ if (src.dtype() != dst->dtype()) {
+ return errors::Internal("Can't copy a tensor of ",
+ DataTypeString(src.dtype()), " into a tensor of ",
+ DataTypeString(dst->dtype()));
+ }
+ if (src.TotalBytes() != dst->TotalBytes()) {
+ return errors::Internal("Can't copy ", src.TotalBytes(),
+ " bytes of a tensor into another with ",
+ dst->TotalBytes(), " bytes buffer.");
+ }
+ if ((src.TotalBytes() > 0) && !src.IsInitialized()) {
+ return errors::Internal("Src tensor is not initialized.");
+ }
+ if ((dst->TotalBytes() > 0) && !dst->IsInitialized()) {
+ return errors::Internal("Dst tensor is not initialized.");
+ }
+ }
+ if (!DMAHelper::CanUseDMA(&src)) {
+ return errors::Internal("GPU copy from non-DMA ",
+ DataTypeString(src.dtype()), "tensor");
+ }
+ return Status::OK();
+}
+
+void* GetBase(const Tensor* src) {
+ return const_cast<void*>(DMAHelper::base(src));
+}
+
+void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
+
/*static*/
void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
const DeviceContext* device_context,
TensorProto* proto, bool is_dead,
StatusCallback done) {
VLOG(1) << "SetProtoFromGPU device_context " << device_context;
+ const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
+ gpu::Stream* stream = nullptr;
+ Status s =
+ PrepareCopy(dev, device_context, tensor, nullptr, &dev_info, &stream);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
// Tensor values need to be copied from GPU to CPU ram so that
// we can build the protobuf response for a RecvTensor RPC.
// "device context" identifies the stream where the _Send op executed.
- CHECK(device_context);
- gpu::Stream* stream =
- static_cast<const GPUDeviceContext*>(device_context)->stream();
-
- if (!DMAHelper::CanUseDMA(&tensor)) {
- done(errors::Internal(strings::StrCat(
- "GPU copy from non-DMA ", DataTypeString(tensor.dtype()), "tensor")));
- return;
- }
proto->set_dtype(tensor.dtype());
tensor.shape().AsProto(proto->mutable_tensor_shape());
- // Prepare a Cord with the right data buf size, and DMA the
- // data over from the GPU buffer. Note that 0-size tensors
- // do not have a backing buffer.
- const size_t num_bytes = is_dead ? 0 : tensor.TotalBytes();
- if (num_bytes > 0) {
+
+ // Prepare a proto with the right data buf size, and DMA the data
+ // over from the GPU buffer. Note that 0-size tensors do not have a
+ // backing buffer.
+ Allocator* alloc = nullptr;
+ char* buf = nullptr;
+ const int64 total_bytes = is_dead ? 0 : tensor.TotalBytes();
+ if (total_bytes > 0) {
port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU");
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
- char* mb = alloc->Allocate<char>(num_bytes);
- const char* src_ptr =
- reinterpret_cast<const char*>(DMAHelper::base(&tensor));
- DeviceMemoryBase gpu_src_ptr(const_cast<char*>(src_ptr), num_bytes);
- stream->ThenMemcpy(mb, gpu_src_ptr, num_bytes);
- // Use of tensor may outlive stack scope, so keep a ref.
- TensorReference tensor_ref(tensor);
- dev->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
- stream, [stream, done, proto, mb, num_bytes, alloc, tensor_ref]() {
- if (!stream->ok()) {
- done(errors::Internal("SetProtoFromGPU: GPU Memcpy failed"));
- // TODO(pbar) We currently have no way to recover the
- // worker from a GPU stream in the error state. Until
- // there is a way to reset the CUDA driver, it is
- // preferable to crash the process and restart. Tracked
- // under b/23717097
- LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed";
- return;
- }
- tensor_ref.Unref();
- port::CopyFromArray(proto->mutable_tensor_content(), mb, num_bytes);
- alloc->Deallocate<char>(mb, num_bytes);
- done(Status::OK());
- });
- } else {
- done(Status::OK());
+ alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ buf = alloc->Allocate<char>(total_bytes);
+ void* src_ptr = GetBase(&tensor);
+ DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes);
+ stream->ThenMemcpy(buf, gpu_src_ptr, total_bytes);
}
+ // Use of tensor may outlive stack scope, so keep a ref.
+ TensorReference tensor_ref(tensor);
+ dev_info->event_mgr->ThenExecute(stream, [stream, done, proto, buf,
+ total_bytes, alloc, tensor_ref]() {
+ if (!stream->ok()) {
+ LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed";
+ }
+ tensor_ref.Unref();
+ if (total_bytes > 0) {
+ port::CopyFromArray(proto->mutable_tensor_content(), buf, total_bytes);
+ alloc->Deallocate<char>(buf, total_bytes);
+ }
+ done(Status::OK());
+ });
}
// static
@@ -114,67 +172,67 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context,
AllocatorAttributes dst_alloc_attr,
const Tensor* input, Tensor* output,
StatusCallback done) {
- const void* src_ptr = DMAHelper::base(input);
- void* dst_ptr = DMAHelper::base(output);
- VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
- const size_t total_bytes = input->TotalBytes();
-
- gpu::Stream* stream = send_dev_context->stream();
- if (stream == nullptr) {
- done(errors::Internal("Failed to find device stream"));
+ const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
+ gpu::Stream* stream = nullptr;
+ Status s =
+ PrepareCopy(src, send_dev_context, *input, output, &dev_info, &stream);
+ if (!s.ok()) {
+ done(s);
return;
}
- auto* src_dev_info = src->tensorflow_gpu_device_info();
- CHECK(src_dev_info);
- DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
- stream->ThenMemcpy(&gpu_dst_ptr,
- DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes},
- total_bytes);
- if (dst->attributes().device_type() == DeviceType(DEVICE_GPU).type()) {
- // Use of input may outlive stack scope, so keep a ref.
- TensorReference input_ref(*input);
- src_dev_info->event_mgr->ThenExecute(stream, [done, stream, input_ref]() {
- input_ref.Unref();
- if (!stream->ok()) {
- done(errors::Internal("GPU->GPU Memcpy failed"));
- } else {
- done(Status::OK());
- }
- });
+ const int64 total_bytes = input->TotalBytes();
+ if (total_bytes > 0) {
+ void* src_ptr = GetBase(input);
+ DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes);
+ void* dst_ptr = GetBase(output);
+ DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+ VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
+ stream->ThenMemcpy(&gpu_dst_ptr, gpu_src_ptr, total_bytes);
}
+
+ // Use of input may outlive stack scope, so keep a ref.
+ TensorReference input_ref(*input);
+ dev_info->event_mgr->ThenExecute(stream, [done, stream, input_ref]() {
+ input_ref.Unref();
+ if (!stream->ok()) {
+ LOG(FATAL) << "GPU->GPU Memcpy failed";
+ }
+ done(Status::OK());
+ });
send_dev_context->MaintainLifetimeOnStream(input, stream);
}
static CopyTensor::Registration register_gpu_gpu_copy(
DEVICE_GPU, DEVICE_GPU, GPUUtil::DeviceToDeviceCopy);
+// static
void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
const DeviceContext* device_context,
const Tensor* gpu_tensor, Tensor* cpu_tensor,
StatusCallback done) {
VLOG(1) << "CopyGPUTensorToCPU";
- size_t total_bytes = gpu_tensor->TotalBytes();
- // Note that 0-size tensors have no backing buffer.
+ const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
+ gpu::Stream* stream = nullptr;
+ Status s = PrepareCopy(gpu_device, device_context, *gpu_tensor, cpu_tensor,
+ &dev_info, &stream);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ const int64 total_bytes = gpu_tensor->TotalBytes();
if (total_bytes > 0) {
- const void* src_ptr = DMAHelper::base(gpu_tensor);
- void* dst_ptr = DMAHelper::base(cpu_tensor);
- CHECK(dst_ptr);
- auto* stream = gpu_device->tensorflow_gpu_device_info()->stream;
- if (device_context) {
- stream = static_cast<const GPUDeviceContext*>(device_context)->stream();
- }
- stream->ThenMemcpy(
- dst_ptr, DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes},
- total_bytes);
- stream->BlockHostUntilDone();
+ void* src_ptr = GetBase(gpu_tensor);
+ DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes);
+ void* dst_ptr = GetBase(cpu_tensor);
+ stream->ThenMemcpy(dst_ptr, gpu_src_ptr, total_bytes);
+ }
+ dev_info->event_mgr->ThenExecute(stream, [stream, done]() {
if (!stream->ok()) {
- done(errors::Internal("CopyGPUTensorToCPU: GPU->CPU Memcpy failed"));
- return;
+ LOG(FATAL) << "GPU->CPU Memcpy failed";
}
- }
-
- done(Status::OK());
+ done(Status::OK());
+ });
}
/* static */
@@ -183,47 +241,31 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
Device* gpu_device, Tensor* gpu_tensor,
StatusCallback done) {
VLOG(1) << "CopyCPUTensorToGPU";
- CHECK(DeviceType(gpu_device->attributes().device_type()) ==
- DeviceType(DEVICE_GPU));
-
- auto* dev_info = gpu_device->tensorflow_gpu_device_info();
- if (!dev_info) {
- done(errors::Internal("Failed to find dest device GPUDeviceInfo"));
- return;
- }
- if (cpu_tensor->TotalBytes() != gpu_tensor->TotalBytes()) {
- done(errors::Internal(
- strings::StrCat("Can't copy ", cpu_tensor->TotalBytes(),
- " bytes of a tensor into another with ",
- gpu_tensor->TotalBytes(), " bytes buffer.")));
+ const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
+ gpu::Stream* stream = nullptr;
+ Status s = PrepareCopy(gpu_device, device_context, *cpu_tensor, gpu_tensor,
+ &dev_info, &stream);
+ if (!s.ok()) {
+ done(s);
return;
}
const int64 total_bytes = cpu_tensor->TotalBytes();
// Note that 0-size tensors have no backing buffer.
if (total_bytes > 0) {
- const void* src_ptr = DMAHelper::base(cpu_tensor);
- void* dst_ptr = DMAHelper::base(gpu_tensor);
+ void* src_ptr = GetBase(cpu_tensor);
+ void* dst_ptr = GetBase(gpu_tensor);
DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
-
- CHECK(device_context);
- auto* stream =
- static_cast<const GPUDeviceContext*>(device_context)->stream();
stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
- auto* dev_info = gpu_device->tensorflow_gpu_device_info();
- // Use of cpu_tensor may outlive stack scope, so keep a ref.
- TensorReference input_ref(*cpu_tensor);
- dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() {
- input_ref.Unref();
- if (!stream->ok()) {
- done(errors::Internal("CopyCPUTensorToGPU: GPU Memcpy failed"));
- } else {
- done(Status::OK());
- }
- });
- } else {
- // empty tensor case
- done(Status::OK());
}
+ // Use of cpu_tensor may outlive stack scope, so keep a ref.
+ TensorReference input_ref(*cpu_tensor);
+ dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() {
+ input_ref.Unref();
+ if (!stream->ok()) {
+ LOG(FATAL) << "CPU->GPU Memcpy failed";
+ }
+ done(Status::OK());
+ });
}
Status GPUUtil::Sync(Device* gpu_device) {
@@ -257,7 +299,7 @@ string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) {
CHECK(tensor);
const int64 num_bytes = std::min<int64>(
FLAGS_brain_gpu_util_debug_string_maxlen, tensor->TotalBytes());
- void* ptr = (num_bytes > 0) ? DMAHelper::base(tensor) : nullptr;
+ void* ptr = (num_bytes > 0) ? GetBase(tensor) : nullptr;
strings::Appendf(&ret, "%p:", ptr);
if (num_bytes > 0) {
auto* dev_info = device->tensorflow_gpu_device_info();
@@ -295,14 +337,14 @@ uint64 GPUUtil::Checksum(Device* gpu_device,
}
uint64 GPUUtil::Checksum(const Tensor& tensor) {
- const float* fptr = reinterpret_cast<const float*>(DMAHelper::base(&tensor));
+ const float* fptr = reinterpret_cast<const float*>(GetBase(&tensor));
size_t num_bytes = tensor.TotalBytes();
size_t num_floats = num_bytes / sizeof(float);
for (size_t i = 0; i < num_floats; ++i) {
CHECK(!std::isnan(fptr[i])) << " i " << i;
}
// TODO(tucker): consider using crc32c instead.
- return Hash64(reinterpret_cast<const char*>(DMAHelper::base(&tensor)),
+ return Hash64(reinterpret_cast<const char*>(GetBase(&tensor)),
tensor.TotalBytes(), 0);
}
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index ac1e0554dd..b94207e2e8 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -553,4 +553,25 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
return Status::OK();
}
+void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+ for (int i = 0; i < op_def->input_arg_size(); ++i) {
+ op_def->mutable_input_arg(i)->clear_description();
+ }
+ for (int i = 0; i < op_def->output_arg_size(); ++i) {
+ op_def->mutable_output_arg(i)->clear_description();
+ }
+ for (int i = 0; i < op_def->attr_size(); ++i) {
+ op_def->mutable_attr(i)->clear_description();
+ }
+ op_def->clear_summary();
+ op_def->clear_description();
+}
+
+void RemoveDescriptionsFromOpList(OpList* op_list) {
+ for (int i = 0; i < op_list->op_size(); ++i) {
+ OpDef* op_def = op_list->mutable_op(i);
+ RemoveDescriptionsFromOpDef(op_def);
+ }
+}
+
} // namespace tensorflow
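A short usage sketch for the new helpers (illustrative only; assumes the standard op-registry headers are available):

```cpp
// Export all registered ops, then strip the docs: the result keeps names,
// args, and attrs, which is everything validation needs, at a fraction of
// the serialized size.
OpList op_list;
OpRegistry::Global()->Export(/*include_internal=*/false, &op_list);
RemoveDescriptionsFromOpList(&op_list);
```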
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
index 34dd4f374b..de350d9aaa 100644
--- a/tensorflow/core/framework/op_def_util.h
+++ b/tensorflow/core/framework/op_def_util.h
@@ -54,6 +54,10 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
const OpDef& penultimate_op,
const OpDef& new_op);
+// Remove all docs from *op_def / *op_list.
+void RemoveDescriptionsFromOpDef(OpDef* op_def);
+void RemoveDescriptionsFromOpList(OpList* op_list);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
index 9718e252a5..25391009f9 100644
--- a/tensorflow/core/graph/tensor_id.h
+++ b/tensorflow/core/graph/tensor_id.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -33,6 +34,13 @@ struct TensorId : public std::pair<StringPiece, int> {
using Base::pair;
string ToString() const { return strings::StrCat(first, ":", second); }
+
+ struct Hasher {
+ public:
+ std::size_t operator()(const TensorId& x) const {
+ return Hash32(x.first.data(), x.first.size(), x.second);
+ }
+ };
};
TensorId ParseTensorName(const string& name);
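With the new `Hasher`, a `TensorId` can serve directly as a hash-container key. A minimal sketch (note that `TensorId` holds `StringPiece`s, so the backing string must outlive the id):

```cpp
#include <unordered_set>

const string name = "MatMul:0";
TensorId id = ParseTensorName(name);  // -> {"MatMul", 0}
std::unordered_set<TensorId, TensorId::Hasher> visited;
visited.insert(id);
```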
diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc
index 465ca12098..faf1ea89e0 100644
--- a/tensorflow/core/graph/validate.cc
+++ b/tensorflow/core/graph/validate.cc
@@ -15,8 +15,13 @@ limitations under the License.
#include "tensorflow/core/graph/validate.h"
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace graph {
@@ -34,5 +39,47 @@ Status ValidateGraphDef(const GraphDef& graph_def,
return s;
}
+namespace {
+
+class OpListOpRegistry : public OpRegistryInterface {
+ public:
+ // Does not take ownership of op_list, *op_list must outlive *this.
+ OpListOpRegistry(const OpList* op_list) {
+ for (const OpDef& op_def : op_list->op()) {
+ index_[op_def.name()] = &op_def;
+ }
+ }
+ ~OpListOpRegistry() override {}
+
+ const OpDef* LookUp(const string& op_type_name,
+ Status* status) const override {
+ auto iter = index_.find(op_type_name);
+ if (iter == index_.end()) {
+ status->Update(
+ errors::NotFound("Op type not registered '", op_type_name, "'"));
+ return nullptr;
+ }
+ return iter->second;
+ }
+
+ private:
+ std::unordered_map<string, const OpDef*> index_;
+};
+
+} // namespace
+
+Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def,
+ const OpList& op_list) {
+ OpListOpRegistry registry(&op_list);
+ GraphDef copy(graph_def);
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy, &registry, 0));
+ return ValidateGraphDef(copy, &registry);
+}
+
+void GetOpListForValidation(OpList* op_list, const OpRegistry* op_registry) {
+ op_registry->Export(false, op_list);
+ RemoveDescriptionsFromOpList(op_list);
+}
+
} // namespace graph
} // namespace tensorflow
diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h
index af91f76511..b98f2e0f01 100644
--- a/tensorflow/core/graph/validate.h
+++ b/tensorflow/core/graph/validate.h
@@ -23,15 +23,29 @@ limitations under the License.
namespace tensorflow {
namespace graph {
-// Returns OK if 'graph_def' has the following properties:
+// Returns OK if every NodeDef in `graph_def` is valid with respect to
+// its corresponding OpDef (as defined by ValidateNodeDef()) as
+// registered in `op_registry`.
//
-// 1) Every NodeDef is valid with respect to its corresponding OpDef
-// as registered in 'op_registry'.
-//
-// REQUIRES: 'op_registry' is not nullptr.
+// REQUIRES:
+// * `op_registry` is not nullptr.
+// * `graph_def` has default attrs filled in (see AddDefaultAttrsToGraphDef()).
Status ValidateGraphDef(const GraphDef& graph_def,
const OpRegistryInterface* op_registry);
+// Like ValidateGraphDef() except:
+// * Takes an OpList instead of an OpRegistryInterface.
+// Note that the OpList need not have descriptions, which can be a big
+// space savings, see GetOpListForValidation() below.
+// * Makes a copy of `graph_def` and calls AddDefaultAttrsToGraphDef()
+// on the copy.
+Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def,
+ const OpList& op_list);
+
+// Get an OpList from `*op_registry` with all the descriptions removed.
+void GetOpListForValidation(
+ OpList* op_list, const OpRegistry* op_registry = OpRegistry::Global());
+
} // namespace graph
} // namespace tensorflow
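Putting the two new entry points together, a client would typically do something like the following (a sketch; `graph_def` is assumed to be a `GraphDef` already in scope):

```cpp
OpList op_list;
graph::GetOpListForValidation(&op_list);  // global registry, docs stripped
Status s = graph::ValidateGraphDefAgainstOpList(graph_def, op_list);
if (!s.ok()) {
  LOG(ERROR) << "Graph failed validation: " << s;
}
```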
diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc
index c1f75f9c47..17e56d7443 100644
--- a/tensorflow/core/graph/validate_test.cc
+++ b/tensorflow/core/graph/validate_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/subgraph.h"
@@ -34,9 +35,9 @@ REGISTER_OP("FloatInput").Output("o: float");
REGISTER_OP("Int32Input").Output("o: int32");
TEST(ValidateGraphDefTest, TestValidGraph) {
- string graph_def_str =
- "node { name: 'A' op: 'FloatInput'}"
- "node { name: 'B' op: 'FloatInput'}"
+ const string graph_def_str =
+ "node { name: 'A' op: 'FloatInput' }"
+ "node { name: 'B' op: 'FloatInput' }"
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }";
GraphDef graph_def;
@@ -46,9 +47,9 @@ TEST(ValidateGraphDefTest, TestValidGraph) {
}
TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) {
- string graph_def_str =
- "node { name: 'A' op: 'FloatInput'}"
- "node { name: 'B' op: 'Int32Input'}"
+ const string graph_def_str =
+ "node { name: 'A' op: 'FloatInput' }"
+ "node { name: 'B' op: 'Int32Input' }"
"node { "
" name: 'C' op: 'Sum' "
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -70,8 +71,8 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) {
TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
// "DstT" attribute is missing.
- string graph_def_str =
- "node { name: 'A' op: 'FloatInput'}"
+ const string graph_def_str =
+ "node { name: 'A' op: 'FloatInput' }"
"node { "
" name: 'B' op: 'Cast' "
" attr { key: 'SrcT' value { type: DT_FLOAT } }"
@@ -93,5 +94,53 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
}
+TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) {
+ OpList op_list;
+ TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(op_list.add_op()));
+ const string graph_def_str = "node { name: 'A' op: 'UniqueSnowflake' }";
+ GraphDef graph_def;
+ auto parser = protobuf::TextFormat::Parser();
+ CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
+ TF_ASSERT_OK(graph::ValidateGraphDefAgainstOpList(graph_def, op_list));
+}
+
+TEST(ValidateGraphDefAgainstOpListTest, GraphWithGlobalOpNotInOpList) {
+ OpList op_list;
+ TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(op_list.add_op()));
+ const string graph_def_str = "node { name: 'A' op: 'FloatInput' }";
+ GraphDef graph_def;
+ auto parser = protobuf::TextFormat::Parser();
+ CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
+ ASSERT_FALSE(graph::ValidateGraphDefAgainstOpList(graph_def, op_list).ok());
+}
+
+REGISTER_OP("HasDocs").Doc("This is in the summary.");
+
+TEST(GetOpListForValidationTest, ShouldStripDocs) {
+ bool found_float = false;
+ bool found_int32 = false;
+ bool found_has_docs = false;
+ OpList op_list;
+ graph::GetOpListForValidation(&op_list);
+ for (const OpDef& op_def : op_list.op()) {
+ if (op_def.name() == "FloatInput") {
+ EXPECT_FALSE(found_float);
+ found_float = true;
+ }
+ if (op_def.name() == "Int32Input") {
+ EXPECT_FALSE(found_int32);
+ found_int32 = true;
+ }
+ if (op_def.name() == "HasDocs") {
+ EXPECT_FALSE(found_has_docs);
+ found_has_docs = true;
+ EXPECT_TRUE(op_def.summary().empty());
+ }
+ }
+ EXPECT_TRUE(found_float);
+ EXPECT_TRUE(found_int32);
+ EXPECT_TRUE(found_has_docs);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc
index b752a7ed6e..c69c93fcb1 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc
@@ -21,11 +21,11 @@ limitations under the License.
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/binary_linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/port.h"
-#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
new file mode 100644
index 0000000000..f083c50743
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/sparse_xent_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SparseSoftmaxXentWithLogitsOp : public OpKernel {
+ public:
+ explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& logits_in = context->input(0);
+ const Tensor& labels_in = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
+                errors::InvalidArgument("logits must be 2-dimensional"));
+    OP_REQUIRES(context,
+                logits_in.shape().dim_size(0) == labels_in.NumElements(),
+                errors::InvalidArgument(
+                    "logits first dimension must match labels size. "
+                    "logits shape=", logits_in.shape().DebugString(),
+                    " labels shape=", labels_in.shape().DebugString()));
+    // The check above guarantees one label per row of logits, so no
+    // further shape check on "labels" is needed here.
+
+ // loss is 1-D (one per example), and size is batch_size.
+
+ Tensor scratch;
+ OP_REQUIRES_OK(
+ context, context->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({logits_in.dim_size(0)}),
+ &scratch));
+
+ Tensor* loss_out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+ Tensor* back_out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, logits_in.shape(), &back_out));
+
+ functor::SparseXentFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ labels_in.vec<int64>(), scratch.vec<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ }
+};
+
+// Partial specialization for a CPUDevice, that uses the Eigen implementation
+// from XentEigenImpl.
+namespace functor {
+template <typename T>
+struct SparseXentFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<int64>::ConstVec labels,
+ typename TTypes<T>::Vec scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop) {
+ SparseXentEigenImpl<CPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ backprop);
+ }
+};
+} // namespace functor
+
+REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ SparseSoftmaxXentWithLogitsOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ SparseSoftmaxXentWithLogitsOp<CPUDevice, double>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ SparseSoftmaxXentWithLogitsOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_xent_op.h b/tensorflow/core/kernels/sparse_xent_op.h
new file mode 100644
index 0000000000..5014794cb9
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op.h
@@ -0,0 +1,204 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SPARSE_XENT_OP_H_
+#define TENSORFLOW_KERNELS_SPARSE_XENT_OP_H_
+// Functor definition for SparseXentOp, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+namespace sparse_xent_helpers {
+
+template <typename T>
+typename TTypes<const T, 1>::Tensor32Bit To32BitConst(
+ typename TTypes<T>::Vec in) {
+ return To32Bit(typename TTypes<T>::ConstVec(in.data(), in.dimensions()));
+}
+
+template <typename T>
+typename TTypes<const T, 2>::Tensor32Bit To32BitConst(
+ typename TTypes<T>::Matrix in) {
+ return To32Bit(typename TTypes<T>::ConstMatrix(in.data(), in.dimensions()));
+}
+
+} // namespace sparse_xent_helpers
+
+namespace generator {
+
+// Generator for calculation of the sparse Xent loss.
+// This generator takes the logits, the sum of the exponentiated
+// logits, and the label indices. For each minibatch entry, ignoring
+// the batch index b, it calculates:
+//
+// loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label }
+//
+// for j = 0 .. num_classes - 1. Summing this value over all j yields
+// the final loss.
+template <typename T>
+class SparseXentLossGenerator {
+ public:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator(
+ typename TTypes<const T, 2>::Tensor32Bit logits,
+ typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
+ TTypes<const int64, 1>::Tensor32Bit labels)
+ : logits_(logits), sum_exp_logits_(sum_exp_logits), labels_(labels) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ operator()(const Eigen::array<int, 2>& coords) const {
+ int batch = coords[0];
+ int depth = coords[1];
+ return (labels_(batch) == depth)
+ ? (std::log(sum_exp_logits_(batch)) - logits_(coords))
+ : T(0.0);
+  }
+
+ private:
+ typename TTypes<const T, 2>::Tensor32Bit logits_;
+ typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
+ TTypes<const int64, 1>::Tensor32Bit labels_;
+};
+
+// Generator for calculation of the sparse Xent gradient.
+// This generator takes the logits, the sum of the exponentiated
+// logits, and the label indices. For each minibatch entry, ignoring
+// the batch index b, it calculates:
+//
+// exp(logits[j]) / sum_exp_logits - 1{ j == label }
+//
+// for j = 0 .. num_classes - 1.
+template <typename T>
+class SparseXentGradGenerator {
+ public:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator(
+ typename TTypes<const T, 2>::Tensor32Bit logits,
+ typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
+ TTypes<const int64, 1>::Tensor32Bit labels)
+ : logits_(logits), sum_exp_logits_(sum_exp_logits), labels_(labels) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ operator()(const Eigen::array<int, 2>& coords) const {
+ int batch = coords[0];
+ int depth = coords[1];
+ T subtract = (depth == labels_(batch)) ? T(1.0) : T(0.0);
+ return std::exp(logits_(coords)) / sum_exp_logits_(batch) - subtract;
+  }
+
+ private:
+ typename TTypes<const T, 2>::Tensor32Bit logits_;
+ typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
+ TTypes<const int64, 1>::Tensor32Bit labels_;
+};
+
+} // namespace generator
+
+namespace functor {
+
+// Functor used by SparseXentOp to do the computations.
+template <typename Device, typename T>
+struct SparseXentFunctor {
+ // Computes Cross Entropy loss and backprop.
+ //
+ // logits: batch_size, num_classes.
+  // labels: batch_size, with values in [0, num_classes).
+  // scratch: temporary tensor, dims: batch_size.
+ // loss: output tensor for the loss, dims: batch_size.
+ // backprop: output tensor for the backprop, dims: batch_size, num_classes.
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<int64>::ConstVec labels,
+ typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop);
+};
+
+// Eigen code implementing SparseXentFunctor::operator().
+// This code works for both CPU and GPU and is used by the functor
+// specializations for both device types.
+template <typename Device, typename T>
+struct SparseXentEigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<int64>::ConstVec labels,
+ typename TTypes<T>::Vec scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop) {
+ // NOTE(touts): This duplicates some of the computations in softmax_op
+    // because we need the intermediate (logits - max(logits)) values to
+ // avoid a log(exp()) in the computation of the loss.
+
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const int batch_size = logits.dimension(kBatchDim);
+ const int num_classes = logits.dimension(kClassDim);
+
+// These arrays are used to reduce along the class dimension, and broadcast
+// the resulting value to all classes.
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 1> along_class;
+ along_class[0] = kClassDim;
+ Eigen::array<int, 1> batch_only;
+ batch_only[0] = batch_size;
+ Eigen::array<int, 2> batch_by_one;
+ batch_by_one[0] = batch_size;
+ batch_by_one[1] = 1;
+ Eigen::array<int, 2> one_by_class;
+ one_by_class[0] = 1;
+ one_by_class[1] = num_classes;
+#else
+ Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
+ Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
+ batch_by_one.set(0, batch_size);
+ Eigen::IndexList<int> batch_only;
+ batch_only.set(0, batch_size);
+ Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
+ one_by_class.set(1, num_classes);
+#endif
+
+ // scratch = max_logits along classes.
+ To32Bit(scratch).device(d) = To32Bit(logits).maximum(along_class);
+
+ // backprop = logits - max_logits.
+ To32Bit(backprop).device(d) =
+ To32Bit(logits) -
+ To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class);
+
+ // scratch = sum(exp(logits - max_logits)) along classes.
+ To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class);
+
+ // sum(-labels *
+ // ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+ // along classes
+ generator::SparseXentLossGenerator<T> sparse_xent_loss_gen(
+ sparse_xent_helpers::To32BitConst<T>(backprop),
+ sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels));
+ To32Bit(loss).device(d) =
+ To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class);
+
+ // backprop: prob - labels, where
+ // prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+ generator::SparseXentGradGenerator<T> sparse_xent_grad_gen(
+ sparse_xent_helpers::To32BitConst<T>(backprop),
+ sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels));
+ To32Bit(backprop).device(d) =
+ To32Bit(backprop).generate(sparse_xent_grad_gen);
+ }
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SPARSE_XENT_OP_H_
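The max-subtraction performed by `SparseXentEigenImpl` above is the standard log-sum-exp stabilization. A scalar sketch of the same arithmetic, independent of the Eigen machinery:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// loss = log(sum_j exp(x_j)) - x_label; shifting by the max keeps exp()
// from overflowing for large logits, and the shift cancels out of the
// final result.
float StableSparseXentLoss(const std::vector<float>& logits, int label) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum_exp = 0.f;
  for (const float x : logits) sum_exp += std::exp(x - max_logit);
  return std::log(sum_exp) - (logits[label] - max_logit);
}
```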
diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
new file mode 100644
index 0000000000..296e4352a7
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/sparse_xent_op.h"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization for a GPUDevice, that uses the Eigen implementation
+// from XentEigenImpl.
+namespace functor {
+template <typename T>
+struct SparseXentFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<int64>::ConstVec labels,
+ typename TTypes<T>::Vec scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop) {
+ SparseXentEigenImpl<GPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ backprop);
+ }
+};
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::SparseXentFunctor<GPUDevice, float>;
+template class generator::SparseXentGradGenerator<float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc
new file mode 100644
index 0000000000..16335a9f1b
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <random>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/kernels/sparse_xent_op.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+static Graph* SparseXent(int batch_size, int num_classes) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
+ logits.flat<float>().setRandom();
+ Tensor labels(DT_INT64, TensorShape({batch_size}));
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::uniform_int_distribution<> dist(0, num_classes - 1);
+ auto labels_t = labels.flat<int64>();
+ for (int i = 0; i < batch_size; ++i) {
+ labels_t(i) = dist(gen);
+ }
+ test::graph::Binary(g, "SparseSoftmaxCrossEntropyWithLogits",
+ test::graph::Constant(g, logits),
+ test::graph::Constant(g, labels));
+ return g;
+}
+
+#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \
+ static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \
+ test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS)).Run(iters); \
+ } \
+ BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE);
+
+// The representative benchmarks for ptb_word on GPU.
+BM_SparseXentDev(8, 1000000, gpu);
+
+BM_SparseXentDev(16, 10000, gpu);
+BM_SparseXentDev(16, 30000, gpu);
+BM_SparseXentDev(16, 100000, gpu);
+
+BM_SparseXentDev(32, 10000, gpu);
+BM_SparseXentDev(32, 30000, gpu);
+BM_SparseXentDev(32, 100000, gpu);
+
+BM_SparseXentDev(64, 10000, gpu);
+BM_SparseXentDev(64, 30000, gpu);
+BM_SparseXentDev(64, 100000, gpu);
+
+// CPU
+BM_SparseXentDev(8, 1000000, cpu);
+
+BM_SparseXentDev(16, 10000, cpu);
+BM_SparseXentDev(16, 100000, cpu);
+
+BM_SparseXentDev(32, 10000, cpu);
+BM_SparseXentDev(32, 100000, cpu);
+
+BM_SparseXentDev(64, 10000, cpu);
+BM_SparseXentDev(64, 100000, cpu);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 83f28bdacf..a51a0b3469 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -551,6 +551,29 @@ loss: Per example loss (batch_size vector).
backprop: backpropagated gradients (batch_size x num_classes matrix).
)doc");
+REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
+ .Input("features: T")
+ .Input("labels: int64")
+ .Output("loss: T")
+ .Output("backprop: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes softmax cross entropy cost and gradients to backpropagate.
+
+Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
+a matrix of label probabilities, but rather a single label per row
+of features. This label is considered to have probability 1.0 for the
+given row.
+
+Inputs are the logits, not probabilities.
+
+features: batch_size x num_classes matrix
+labels: batch_size vector with values in [0, num_classes).
+ This is the label for the given minibatch entry.
+loss: Per example loss (batch_size vector).
+backprop: backpropagated gradients (batch_size x num_classes matrix).
+)doc");
+
// --------------------------------------------------------------------------
REGISTER_OP("InTopK")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index fbb589bf24..c9f0dda3a1 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -8509,6 +8509,41 @@ op {
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n ==> [[ 1 2 3 4]\n [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n ==> [[0 0 0 0]\n [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```"
}
op {
+ name: "SparseSoftmaxCrossEntropyWithLogits"
+ input_arg {
+ name: "features"
+ description: "batch_size x num_classes matrix"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "labels"
+ description: "batch_size vector with values in [0, num_classes).\nThis is the label for the given minibatch entry."
+ type: DT_INT64
+ }
+ output_arg {
+ name: "loss"
+ description: "Per example loss (batch_size vector)."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprop"
+ description: "backpropagated gradients (batch_size x num_classes matrix)."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ summary: "Computes softmax cross entropy cost and gradients to backpropagate."
+ description: "Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept\na matrix of label probabilities, but rather a single label per row\nof features. This label is considered to have probability 1.0 for the\ngiven row.\n\nInputs are the logits, not probabilities."
+}
+op {
name: "SparseSplit"
input_arg {
name: "split_dim"
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index a5932cf0e1..258e4b6210 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -3,6 +3,10 @@
load("/google/protobuf/protobuf", "cc_proto_library")
load("/google/protobuf/protobuf", "py_proto_library")
+# configure may change the following lines.
+CUDA_VERSION = '7.0'
+CUDNN_VERSION = '6.5'
+
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
tf_deps = []
@@ -68,3 +72,9 @@ def tf_additional_test_srcs():
def tf_kernel_tests_linkstatic():
return 0
+
+def tf_get_cuda_version():
+ return CUDA_VERSION
+
+def tf_get_cudnn_version():
+ return CUDNN_VERSION
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 44dbc47ad1..b30c3ecf3d 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -9,6 +9,7 @@ exports_files(["LICENSE"])
load("/tensorflow/tensorflow", "tf_copts")
load("/tensorflow/tensorflow", "tf_cuda_library")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_get_cuda_version")
cc_library(
name = "gtest",
@@ -74,7 +75,7 @@ filegroup(
cc_library(
name = "cuda",
data = [
- "//third_party/gpus/cuda:lib64/libcudart.so.7.0",
+ "//third_party/gpus/cuda:lib64/libcudart.so." + tf_get_cuda_version(),
],
linkopts = [
"-Wl,-rpath,third_party/gpus/cuda/lib64",
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index 6d364edef3..cbaaf602ee 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -114,6 +114,33 @@ class Session {
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) = 0;
+ /// \brief Sets up a graph for partial execution. All future feeds and
+  /// fetches are specified by 'input_names' and 'output_names'. Returns a
+  /// 'handle' that can be used to perform a sequence of partial feeds and
+  /// fetches.
+ /// NOTE: This API is still experimental and may change.
+ virtual Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ return errors::Unimplemented(
+ "Partial run is not supported for"
+ " this session.");
+ }
+
+ /// \brief Continues the pending execution specified by 'handle' with the
+ /// provided input tensors and fills `outputs` for the endpoints specified
+ /// in `output_names`.
+ /// NOTE: This API is still experimental and may change.
+ virtual Status PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ return errors::Unimplemented(
+ "Partial run is not supported for"
+ " this session.");
+ }
+
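A sketch of the intended call sequence for this C++ API (hypothetical graph with placeholders `a`, `b`, `c` and nodes `r1 = a + b`, `r2 = r1 * c`; `session`, `tensor_a`, `tensor_b`, and `tensor_c` are assumed to exist):

```cpp
string handle;
TF_CHECK_OK(session->PRunSetup({"a:0", "b:0", "c:0"},  // future feeds
                               {"r1:0", "r2:0"},       // future fetches
                               {}, &handle));
std::vector<Tensor> outputs;
// Step 1: feed a and b, fetch r1.
TF_CHECK_OK(session->PRun(handle, {{"a:0", tensor_a}, {"b:0", tensor_b}},
                          {"r1:0"}, &outputs));
// Step 2: feed only c; the runtime reuses r1 from step 1 to compute r2.
TF_CHECK_OK(session->PRun(handle, {{"c:0", tensor_c}}, {"r2:0"}, &outputs));
```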
/// \brief Closes this session.
///
/// Closing a session releases the resources used by this session
diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h
index 2c70713f14..e2210dcdcf 100644
--- a/tensorflow/core/public/tensor_c_api.h
+++ b/tensorflow/core/public/tensor_c_api.h
@@ -253,7 +253,7 @@ extern void TF_ExtendGraph(TF_Session*, const void* proto, size_t proto_len,
// implementation will eventually call TF_DeleteTensor on each input).
//
// On success, the tensors corresponding to output_names[0,noutputs-1]
-// are placed in outputs[]. and these outputs[] become the property
+// are placed in outputs[], and these outputs[] become the property
// of the caller (the caller must eventually call TF_DeleteTensor on
// them).
//
@@ -269,6 +269,40 @@ extern void TF_Run(TF_Session*,
// Output status
TF_Status*);
+// Set up the graph with the intended feeds and fetches for a sequence
+// of partial run calls.
+//
+// On success, returns a handle that is used for subsequent PRun calls.
+//
+// On failure, out_status contains a tensorflow::Status with an error
+// message.
+// NOTE: This is EXPERIMENTAL and subject to change.
+extern void TF_PRunSetup(TF_Session*,
+ // Input names
+ const char** input_names, int ninputs,
+ // Output names
+ const char** output_tensor_names, int noutputs,
+ // Target nodes
+ const char** target_node_names, int ntargets,
+ // Output handle
+ char** handle,
+ // Output status
+ TF_Status*);
+
+// Continue to run the graph with additional feeds and fetches. The
+// execution state is uniquely identified by the handle.
+// NOTE: This is EXPERIMENTAL and subject to change.
+extern void TF_PRun(TF_Session*, const char* handle,
+ // Input tensors
+ const char** input_names, TF_Tensor** inputs, int ninputs,
+ // Output tensors
+ const char** output_tensor_names, TF_Tensor** outputs,
+ int noutputs,
+ // Target nodes
+ const char** target_node_names, int ntargets,
+ // Output status
+ TF_Status*);
+
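For the C API, the same two-step flow looks roughly like this (sketch only; `session`, `tensor_a`, and `tensor_b` are assumed to exist, error checks are mostly elided, and handle cleanup is not shown):

```cpp
TF_Status* status = TF_NewStatus();
char* handle = nullptr;
const char* feeds[] = {"a:0", "b:0"};
const char* fetches[] = {"r1:0"};
TF_PRunSetup(session, feeds, 2, fetches, 1,
             /*target_node_names=*/nullptr, 0, &handle, status);
// ... check TF_GetCode(status) == TF_OK before continuing ...
TF_Tensor* inputs[] = {tensor_a, tensor_b};
TF_Tensor* outputs[1];
TF_PRun(session, handle, feeds, inputs, 2, fetches, outputs, 1,
        /*target_node_names=*/nullptr, 0, status);
TF_DeleteStatus(status);
```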
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 4f1fabfe56..5eb6d846fd 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -13,16 +13,29 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-cc_library(
- name = "tensorflow_native_libs",
- srcs = glob(["jni/**/*.cc"]) + [":libpthread.so"],
- hdrs = glob(["jni/**/*.h"]),
+cc_binary(
+ name = "libtensorflow_demo.so",
+ srcs = glob([
+ "jni/**/*.cc",
+ "jni/**/*.h",
+ ]) + [":libpthread.so"],
copts = [
"-std=c++11",
"-mfpu=neon",
"-O2",
],
- linkopts = ["-llog -landroid -lm -ljnigraphics"],
+ linkopts = [
+ "-landroid",
+ "-ljnigraphics",
+ "-llog",
+ "-lm",
+ "-z defs",
+ "-s",
+ "-Wl,--icf=all", # Identical Code Folding
+ "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
+ ],
+ linkshared = 1,
+ linkstatic = 1,
tags = [
"manual",
"notap",
@@ -43,6 +56,14 @@ cc_binary(
],
)
+cc_library(
+ name = "tensorflow_native_libs",
+ srcs = [
+ ":libpthread.so",
+ ":libtensorflow_demo.so",
+ ],
+)
+
android_binary(
name = "tensorflow_demo",
srcs = glob([
@@ -52,7 +73,6 @@ android_binary(
assets_dir = "assets",
custom_package = "org.tensorflow.demo",
inline_constants = 1,
- legacy_native_support = 0,
manifest = "AndroidManifest.xml",
resource_files = glob(["res/**"]),
tags = [
diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md
index bab75c7973..a3adc5f155 100644
--- a/tensorflow/examples/udacity/README.md
+++ b/tensorflow/examples/udacity/README.md
@@ -17,6 +17,24 @@ On mac, find the virtual machine's IP using:
Then go to: http://IP:8888 (likely http://192.168.99.100:8888)
+FAQ
+---
+
+* **I'm getting a MemoryError when loading data in the first notebook.**
+
+If you're using a Mac, Docker works by running a VM locally (which
+is controlled by `docker-machine`). It's quite likely that you'll
+need to bump up the amount of RAM allocated to the VM beyond the
+default (which is 1G).
+[This Stack Overflow question](http://stackoverflow.com/questions/32834082/how-to-increase-docker-machine-memory-mac)
+has two good suggestions; we recommend using 8G.
+
+In addition, you may need to pass `--memory=8g` as an extra argument to
+`docker run`.
+
+Notes for anyone needing to build their own containers (mostly instructors)
+===========================================================================
+
Building a local Docker container
---------------------------------
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index d224ebbe5a..6a272c0786 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -884,7 +884,7 @@ tf.transpose(b, perm=[0, 2, 1]) ==> [[[1 4]
- - -
-### `tf.gather(params, indices, name=None)` {#gather}
+### `tf.gather(params, indices, validate_indices=None, name=None)` {#gather}
Gather slices from `params` according to `indices`.
@@ -912,6 +912,7 @@ this operation will permute `params` accordingly.
* <b>`params`</b>: A `Tensor`.
* <b>`indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+* <b>`validate_indices`</b>: An optional `bool`. Defaults to `True`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 253f13f38f..b1a9cd4a14 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -332,6 +332,7 @@
* [`softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#softmax_cross_entropy_with_logits)
* [`softplus`](../../api_docs/python/nn.md#softplus)
* [`softsign`](../../api_docs/python/nn.md#softsign)
+ * [`sparse_softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#sparse_softmax_cross_entropy_with_logits)
* [`tanh`](../../api_docs/python/nn.md#tanh)
* [`top_k`](../../api_docs/python/nn.md#top_k)
* [`uniform_candidate_sampler`](../../api_docs/python/nn.md#uniform_candidate_sampler)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 6677320a8d..7e056e2f11 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -751,6 +751,12 @@ classes are mutually exclusive (each entry is in exactly one class). For
example, each CIFAR-10 image is labeled with one and only one label: an image
can be a dog or a truck, but not both.
+**NOTE:** While the classes are mutually exclusive, their probabilities
+need not be. All that is required is that each row of `labels` is
+a valid probability distribution. If using exclusive `labels`
+(wherein one and only one class is true at a time), see
+`sparse_softmax_cross_entropy_with_logits`.
+
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
output of `softmax`, as it will produce incorrect results.
@@ -771,6 +777,46 @@ and the same dtype (either `float32` or `float64`).
softmax cross entropy loss.
+- - -
+
+### `tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name=None)` {#sparse_softmax_cross_entropy_with_logits}
+
+Computes sparse softmax cross entropy between `logits` and `labels`.
+
+Measures the probability error in discrete classification tasks in which the
+classes are mutually exclusive (each entry is in exactly one class). For
+example, each CIFAR-10 image is labeled with one and only one label: an image
+can be a dog or a truck, but not both.
+
+**NOTE:** For this operation, the probability of a given label is considered
+exclusive. That is, soft classes are not allowed, and the `labels` vector
+must provide a single specific index for the true class for each row of
+`logits` (each minibatch entry). For soft softmax classification with
+a probability distribution for each entry, see
+`softmax_cross_entropy_with_logits`.
+
+**WARNING:** This op expects unscaled logits, since it performs a `softmax`
+on `logits` internally for efficiency. Do not call this op with the
+output of `softmax`, as it will produce incorrect results.
+
+`logits` must have the shape `[batch_size, num_classes]`
+and a dtype of either `float32` or `float64`.
+
+`labels` must have the shape `[batch_size]` and the dtype `int64`.
+
+##### Args:
+
+
+* <b>`logits`</b>: Unscaled log probabilities.
+* <b>`labels`</b>: Each entry `labels[i]` must be an index in `[0, num_classes)`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+ softmax cross entropy loss.
+
+
## Embeddings
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 1b442309f9..f89dd1936f 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -448,6 +448,29 @@ Setting up Cuda nvvm
Configuration finished
```
+##### Using different Cuda SDK and Cudnn versions
+TensorFlow officially supports Cuda 7.0 and Cudnn V2 (6.5) at this point. To
+use a different Cuda SDK or different Cudnn libraries, run "configure" with
+the unofficial-settings environment variable:
+
+```bash
+$ TF_UNOFFICIAL_SETTING=1 ./configure
+...
+Please specify the Cuda SDK version you want to use. [Default is 7.0]: 7.5
+Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: /usr/local/cuda-7.5
+Please specify the Cudnn version you want to use. [Default is 6.5]: 4.0.4
+Please specify the location where cuDNN 4.0.4 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda-7.5]: /usr/local/cudnn-r4-rc/
+...
+Setting up Cuda include
+Setting up Cuda lib64
+Setting up Cuda bin
+Setting up Cuda nvvm
+Configuration finished
+```
+
+For the Cudnn libraries, use '6.5' for R2, '7.0' for R3, and '4.0.4' for
+R4-RC.
+
##### Known issues
* Although it is possible to build both Cuda and non-Cuda configs under the same
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index 116d788449..119935f598 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -220,9 +220,9 @@ This asserts that the input is a vector, and returns having set the
for its `SetStatus()` method.
* The condition. For example, there are functions for validating the shape
of a tensor in
- [`tensorflow/core/public/tensor_shape.h`](https://www.tensorflow.org/code/tensorflow/core/public/tensor_shape.h)
+ [`tensorflow/core/framework/tensor_shape.h`](https://www.tensorflow.org/code/tensorflow/core/framework/tensor_shape.h)
* The error itself, which is represented by a `Status` object, see
- [`tensorflow/core/public/status.h`](https://www.tensorflow.org/code/tensorflow/core/public/status.h). A
+ [`tensorflow/core/lib/core/status.h`](https://www.tensorflow.org/code/tensorflow/core/lib/core/status.h). A
`Status` has both a type (frequently `InvalidArgument`, but see the list of
types) and a message. Functions for constructing an error may be found in
[`tensorflow/core/lib/core/errors.h`][validation-macros].
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ed9481c330..126af9f342 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -582,6 +582,7 @@ tf_gen_op_wrapper_py(
"AvgPoolGrad", # "*Grad" accessible through nn_grad instead of nn_ops.
"BatchNormWithGlobalNormalizationGrad",
"SoftmaxCrossEntropyWithLogits",
+ "SparseSoftmaxCrossEntropyWithLogits",
"LRNGrad",
"MaxPoolGrad",
"MaxPoolGradWithArgmax",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 918e7e4da6..1fda123150 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -50,8 +50,15 @@ class SessionInterface(object):
def run(self, fetches, feed_dict=None):
"""Runs operations in the session. See `Session.run()` for details."""
- raise NotImplementedError('Run')
+ raise NotImplementedError('run')
+ def partial_run_setup(self, fetches, feeds=None):
+ """Sets up the feeds and fetches for partial runs in the session."""
+ raise NotImplementedError('partial_run_setup')
+
+ def partial_run(self, handle, fetches, feed_dict=None):
+ """Continues the execution with additional feeds and fetches."""
+ raise NotImplementedError('partial_run')
def _get_indexed_slices_value_from_fetches(fetched_vals):
return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1],
@@ -213,11 +220,12 @@ class BaseSession(SessionInterface):
return ops.default_session(self)
# Eventually, this registration could be opened up to support custom
- # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn),
+ # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn1, feed_fn2),
# where the signatures are:
# fetch_fn : Type -> (list of Tensors,
# lambda: list of fetched np.ndarray -> TypeVal)
- # feed_fn : Type, TypeVal -> list of (Tensor, value)
+ # feed_fn1 : Type, TypeVal -> list of (Tensor, value)
+ # feed_fn2 : Type -> list of Tensors
# Conceptually, fetch_fn describes how to expand fetch into its
# component Tensors and how to contract the fetched results back into
# a single return value. feed_fn describes how to unpack a single fed
@@ -231,7 +239,8 @@ class BaseSession(SessionInterface):
[fetch.indices, fetch.values, fetch.shape],
lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
lambda feed, feed_val: list(zip(
- [feed.indices, feed.values, feed.shape], feed_val))),
+ [feed.indices, feed.values, feed.shape], feed_val)),
+ lambda feed: [feed.indices, feed.values, feed.shape]),
# IndexedSlices are fetched as IndexedSlicesValues. They can be fed
# IndexedSlicesValues or normal tuples.
(ops.IndexedSlices,
@@ -239,11 +248,14 @@ class BaseSession(SessionInterface):
[fetch.values, fetch.indices] if fetch.dense_shape is None
else [fetch.values, fetch.indices, fetch.dense_shape],
_get_indexed_slices_value_from_fetches),
- _get_feeds_for_indexed_slices),
+ _get_feeds_for_indexed_slices,
+ lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
+ else [feed.values, feed.indices, feed.dense_shape]),
# The default catches all types and performs no expansions.
(object,
lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
- lambda feed, feed_val: [(feed, feed_val)])]
+ lambda feed, feed_val: [(feed, feed_val)],
+ lambda feed: [feed])]
# pylint: enable=g-long-lambda
def run(self, fetches, feed_dict=None):
@@ -302,17 +314,67 @@ class BaseSession(SessionInterface):
ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
`Tensor` that doesn't exist.
"""
- def _fetch_fn(fetch):
- for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
- if isinstance(fetch, tensor_type):
- return fetch_fn(fetch)
- raise TypeError('Fetch argument %r has invalid type %r'
- % (fetch, type(fetch)))
+ return self._run(None, fetches, feed_dict)
- def _feed_fn(feed, feed_val):
- for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
+ def partial_run(self, handle, fetches, feed_dict=None):
+ """Continues the execution with more feeds and fetches.
+
+ This is EXPERIMENTAL and subject to change.
+
+    To use partial execution, a user first calls `partial_run_setup()` and
+    then a sequence of `partial_run()` calls. `partial_run_setup` specifies the
+ list of feeds and fetches that will be used in the subsequent
+ `partial_run` calls.
+
+ Below is a simple example:
+
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ res = sess.partial_run(h, r2, feed_dict={c: res})
+
+ Args:
+ handle: A handle for a sequence of partial runs.
+ fetches: A single graph element, or a list of graph elements
+ (described above).
+ feed_dict: A dictionary that maps graph elements to values
+ (described above).
+
+ Returns:
+ Either a single value if `fetches` is a single graph element, or
+ a list of values if `fetches` is a list (described above).
+ """
+ return self._run(handle, fetches, feed_dict)
+
+ def partial_run_setup(self, fetches, feeds=None):
+ """Sets up a graph with feeds and fetches for partial run.
+
+ This is EXPERIMENTAL and subject to change.
+
+    Note that, in contrast to `run`, `feeds` only specifies the graph elements.
+    The tensor values will be supplied by the subsequent `partial_run` calls.
+
+ Args:
+ fetches: A single graph element, or a list of graph elements.
+ feeds: A single graph element, or a list of graph elements.
+
+ Returns:
+ A handle for partial run.
+
+ Raises:
+ RuntimeError: If this `Session` is in an invalid state (e.g. has been
+ closed).
+      TypeError: If `fetches` or `feeds` are of an inappropriate type.
+ """
+ def _feed_fn(feed):
+ for tensor_type, _, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
if isinstance(feed, tensor_type):
- return feed_fn(feed, feed_val)
+ return feed_fn(feed)
raise TypeError('Feed argument %r has invalid type %r'
% (feed, type(feed)))
@@ -324,6 +386,46 @@ class BaseSession(SessionInterface):
'graph before calling run().')
# Validate and process fetches.
+ unique_fetches, target_list, _ = self._process_fetches(fetches)
+
+ # Create request.
+ feed_list = []
+
+ # Validate and process feed_list.
+ is_list_feed = isinstance(feeds, (list, tuple))
+ if not is_list_feed:
+ feeds = [feeds]
+ for feed in feeds:
+ for subfeed in _feed_fn(feed):
+ try:
+ subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
+ allow_operation=False)
+ feed_list.append(compat.as_bytes(subfeed_t.name))
+ except Exception as e:
+ e.message = ('Cannot interpret feed_list key as Tensor: '
+ + e.message)
+ e.args = (e.message,)
+ raise e
+
+ # Set up a graph with feeds and fetches for partial run.
+ def _setup_fn(session, feed_list, fetch_list, target_list):
+ self._extend_graph()
+ return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
+ target_list)
+
+ return self._do_call(_setup_fn, self._session, feed_list, unique_fetches,
+ target_list)
+
+ def _process_fetches(self, fetches):
+ """Validate and process fetches."""
+ def _fetch_fn(fetch):
+ for tensor_type, fetch_fn, _, _ in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(fetch, tensor_type):
+ return fetch_fn(fetch)
+ raise TypeError('Fetch argument %r has invalid type %r'
+ % (fetch, type(fetch)))
+
+ # Validate and process fetches.
is_list_fetch = isinstance(fetches, (list, tuple))
if not is_list_fetch:
fetches = [fetches]
@@ -357,6 +459,26 @@ class BaseSession(SessionInterface):
fetch_info.append((subfetch_names, fetch_contraction_fn))
unique_fetch_targets = list(unique_fetch_targets)
+ return unique_fetch_targets, target_list, fetch_info
+
+ def _run(self, handle, fetches, feed_dict):
+    """Performs either run or partial_run, depending on the existence of `handle`."""
+ def _feed_fn(feed, feed_val):
+ for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(feed, tensor_type):
+ return feed_fn(feed, feed_val)
+ raise TypeError('Feed argument %r has invalid type %r'
+ % (feed, type(feed)))
+
+ # Check session.
+ if self._closed:
+ raise RuntimeError('Attempted to use a closed Session.')
+ if self.graph.version == 0:
+ raise RuntimeError('The Session graph is empty. Add operations to the '
+ 'graph before calling run().')
+
+ # Validate and process fetches.
+ unique_fetches, target_list, fetch_info = self._process_fetches(fetches)
# Create request.
feed_dict_string = {}
@@ -386,13 +508,14 @@ class BaseSession(SessionInterface):
feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val
# Run request and get response.
- results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
+ results = self._do_run(handle, target_list, unique_fetches,
+ feed_dict_string)
# User may have fetched the same tensor multiple times, but we
# only fetch them from the runtime once. Furthermore, they may
# be wrapped as a tuple of tensors. Here we map the results back
# to what the client asked for.
- fetched_results = dict(zip(unique_fetch_targets, results))
+ fetched_results = dict(zip(unique_fetches, results))
ret = []
for fetch_names, fetch_contraction_fn in fetch_info:
if fetch_names:
@@ -401,7 +524,7 @@ class BaseSession(SessionInterface):
else:
ret.append(None)
- if is_list_fetch:
+ if isinstance(fetches, (list, tuple)):
return ret
else:
return ret[0]
@@ -409,10 +532,11 @@ class BaseSession(SessionInterface):
# Captures the name of a node in an error status.
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
- def _do_run(self, target_list, fetch_list, feed_dict):
+ def _do_run(self, handle, target_list, fetch_list, feed_dict):
"""Runs a step based on the given fetches and feeds.
Args:
+ handle: a handle for partial_run. None if this is just a call to run().
target_list: A list of byte arrays corresponding to names of tensors
or operations to be run, but not fetched.
fetch_list: A list of byte arrays corresponding to names of tensors to
@@ -426,28 +550,26 @@ class BaseSession(SessionInterface):
name of an operation, the first Tensor output of that operation
will be returned for that element.
"""
- try:
+ def _run_fn(session, feed_dict, fetch_list, target_list):
# Ensure any changes to the graph are reflected in the runtime.
- with self._extend_lock:
- if self._graph.version > self._current_version:
- graph_def = self._graph.as_graph_def(
- from_version=self._current_version)
+ self._extend_graph()
+ return tf_session.TF_Run(session, feed_dict, fetch_list, target_list)
- try:
- status = tf_session.TF_NewStatus()
- tf_session.TF_ExtendGraph(
- self._session, graph_def.SerializeToString(), status)
- if tf_session.TF_GetCode(status) != 0:
- raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
- self._opened = True
- finally:
- tf_session.TF_DeleteStatus(status)
-
- self._current_version = self._graph.version
+ def _prun_fn(session, handle, feed_dict, fetch_list):
+ if target_list:
+ raise RuntimeError('partial_run() requires empty target_list.')
+ return tf_session.TF_PRun(session, handle, feed_dict, fetch_list)
- return tf_session.TF_Run(self._session, feed_dict, fetch_list,
- target_list)
+ if handle is None:
+ return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
+ target_list)
+ else:
+ return self._do_call(_prun_fn, self._session, handle, feed_dict,
+ fetch_list)
+ def _do_call(self, fn, *args):
+ try:
+ return fn(*args)
except tf_session.StatusNotOK as e:
e_type, e_value, e_traceback = sys.exc_info()
error_message = compat.as_text(e.error_message)
@@ -466,6 +588,25 @@ class BaseSession(SessionInterface):
# pylint: enable=protected-access
six.reraise(e_type, e_value, e_traceback)
+ def _extend_graph(self):
+ # Ensure any changes to the graph are reflected in the runtime.
+ with self._extend_lock:
+ if self._graph.version > self._current_version:
+ graph_def = self._graph.as_graph_def(
+ from_version=self._current_version)
+
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_ExtendGraph(
+ self._session, graph_def.SerializeToString(), status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
+ self._opened = True
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ self._current_version = self._graph.version
+
class Session(BaseSession):
"""A class for running TensorFlow operations.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index c7d5453f57..39f720629a 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -811,5 +811,50 @@ class SessionTest(test_util.TensorFlowTestCase):
sess_2.run(c_1.op)
self.assertEqual(2.0, sess_2.run(c_2))
+ def testPartialRun(self):
+ with session.Session() as sess:
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ temp = res * 17
+ res = sess.partial_run(h, r2, feed_dict={c: temp})
+ self.assertEqual(153, res)
+
+ def testPartialRunIncomplete(self):
+ with session.Session() as sess:
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+
+ def testConcurrentPartialRun(self):
+ with session.Session() as sess:
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.mul(r1, c)
+
+ h1 = sess.partial_run_setup([r1], [a, b, c])
+ h2 = sess.partial_run_setup([r1, r2], [a, b, c])
+ res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
+ self.assertEqual(3, res)
+ temp = res * 19
+ res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
+ self.assertEqual(66, res)
+ res = sess.partial_run(h2, r2, feed_dict={c: 7})
+ self.assertEqual(462, res)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 3bf7623eb2..8e8e3132f1 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -170,6 +170,10 @@ tensorflow::ImportNumpy();
tensorflow::PyObjectVector temp) {
$1 = &temp;
}
+%typemap(in, numinputs=0) char** out_handle (
+ char* temp) {
+ $1 = &temp;
+}
// Raise a StatusNotOK exception if the out_status is not OK;
// otherwise build a Python list of outputs and return it.
@@ -196,6 +200,19 @@ tensorflow::ImportNumpy();
}
}
+// Raise a StatusNotOK exception if the out_status is not OK;
+// otherwise return the handle as a Python string object.
+%typemap(argout, fragment="StatusNotOK") (
+ tensorflow::Status* out_status, char** out_handle) {
+ if (!$1->ok()) {
+ RaiseStatusNotOK(*$1, $descriptor(tensorflow::Status*));
+ SWIG_fail;
+ } else {
+ $result = PyString_FromStringAndSize(*$2, strlen(*$2));
+ delete *$2;
+ }
+}
+
////////////////////////////////////////////////////////////////////////////////
// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////
@@ -264,6 +281,26 @@ tensorflow::ImportNumpy();
%unignore TF_Run;
%unignore EqualGraphDefWrapper;
+// Include the wrapper for TF_PRunSetup from tf_session_helper.h.
+
+// The %exception block above releases the Python GIL for the length
+// of each wrapped method. We disable this behavior for TF_PRunSetup
+// because it uses the Python allocator.
+%noexception tensorflow::TF_PRunSetup_wrapper;
+%rename(TF_PRunSetup) tensorflow::TF_PRunSetup_wrapper;
+%unignore tensorflow;
+%unignore TF_PRunSetup;
+
+// Include the wrapper for TF_PRun from tf_session_helper.h.
+
+// The %exception block above releases the Python GIL for the length
+// of each wrapped method. We disable this behavior for TF_PRun
+// because it uses the Python allocator.
+%noexception tensorflow::TF_PRun_wrapper;
+%rename(TF_PRun) tensorflow::TF_PRun_wrapper;
+%unignore tensorflow;
+%unignore TF_PRun;
+
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 54ebc2447e..2698230b11 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -425,13 +425,11 @@ Safe_PyObjectPtr make_safe(PyObject* o) {
return Safe_PyObjectPtr(o, Py_DECREF_wrapper);
}
-// Wrapper for TF_Run that converts the arguments to appropriate types.
-// If *out_status is OK, the caller becomes the owner of the PyObjects
-// in *out_values.
-void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
- const NameVector& output_names,
- const NameVector& target_nodes, Status* out_status,
- PyObjectVector* out_values) {
+void TF_Run_wrapper_helper(TF_Session* session, const char* handle,
+ const FeedVector& inputs,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ PyObjectVector* out_values) {
// 1. Convert the feed inputs to the appropriate form for TF_Run.
NameVector input_names;
Safe_PyObjectVector
@@ -514,10 +512,20 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
// 3. Actually call TF_Run().
Py_BEGIN_ALLOW_THREADS;
- TF_Run(session, input_names.data(), inputs_unsafe.data(), input_names.size(),
- const_cast<const char**>(output_names.data()), outputs.data(),
- output_names.size(), const_cast<const char**>(target_nodes.data()),
- target_nodes.size(), status.get());
+ if (handle == nullptr) {
+ TF_Run(session, input_names.data(), inputs_unsafe.data(),
+ input_names.size(), const_cast<const char**>(output_names.data()),
+ outputs.data(), output_names.size(),
+ const_cast<const char**>(target_nodes.data()), target_nodes.size(),
+ status.get());
+ } else {
+ TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(),
+ input_names.size(), const_cast<const char**>(output_names.data()),
+ outputs.data(), output_names.size(),
+ const_cast<const char**>(target_nodes.data()), target_nodes.size(),
+ status.get());
+ }
+
Py_END_ALLOW_THREADS;
// 4. The TensorFlow runtime has taken ownership of the fed tensors,
@@ -558,6 +566,49 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
*out_status = Status::OK();
}
+// Wrapper for TF_Run that converts the arguments to appropriate types.
+// If *out_status is OK, the caller becomes the owner of the PyObjects
+// in *out_values.
+void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ PyObjectVector* out_values) {
+ TF_Run_wrapper_helper(session, nullptr, inputs, output_names, target_nodes,
+ out_status, out_values);
+}
+
+// Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
+// If *out_status is OK, the caller becomes the owner of *out_handle.
+void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ char** out_handle) {
+ Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
+ Py_BEGIN_ALLOW_THREADS;
+ TF_PRunSetup(
+ session, const_cast<const char**>(input_names.data()), input_names.size(),
+ const_cast<const char**>(output_names.data()), output_names.size(),
+ const_cast<const char**>(target_nodes.data()), target_nodes.size(),
+ out_handle, status.get());
+ Py_END_ALLOW_THREADS;
+
+ if (TF_GetCode(status.get()) != TF_OK) {
+ *out_status = TF_Status_to_Status(status.get());
+ return;
+ }
+ *out_status = Status::OK();
+}
+
+// Wrapper for TF_PRun that converts the arguments to appropriate types.
+// If *out_status is OK, the caller becomes the owner of the PyObjects
+// in *out_values.
+void TF_PRun_wrapper(TF_Session* session, const char* handle,
+ const FeedVector& inputs, const NameVector& output_names,
+ Status* out_status, PyObjectVector* out_values) {
+ TF_Run_wrapper_helper(session, handle, inputs, output_names, NameVector(),
+ out_status, out_values);
+}
+
void ImportNumpy() { import_array1(); }
string EqualGraphDefWrapper(const string& actual, const string& expected) {
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 60bbab69f3..a9a38da6e5 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -65,7 +65,7 @@ typedef std::vector<Safe_PyObjectPtr> Safe_PyObjectVector;
Safe_PyObjectPtr make_safe(PyObject* o);
// Run the graph associated with the session starting with the
-// supplied inputs[]. Regardless of success of failure, inputs[] are
+// supplied inputs[]. Regardless of success or failure, inputs[] are
// stolen by the implementation (i.e. the implementation will
// eventually call Py_DECREF on each array input).
//
@@ -80,6 +80,35 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
const NameVector& target_nodes, Status* out_status,
PyObjectVector* out_values);
+// Set up the graph with the intended feeds and fetches for partial run.
+// *out_handle is owned by the caller.
+//
+// On success, returns a handle that is used for subsequent PRun calls.
+//
+// On failure, out_status contains a tensorflow::Status with an error
+// message.
+//
+// NOTE: This is EXPERIMENTAL and subject to change.
+void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ char** out_handle);
+
+// Continue to run the graph with additional feeds and fetches. The
+// execution state is uniquely identified by the handle.
+//
+// On success, the tensors corresponding to output_names[0,noutputs-1]
+// are placed in out_values[], and these outputs[] become the property
+// of the caller (the caller must eventually call Py_DECREF on them).
+//
+// On failure, out_status contains a tensorflow::Status with an error
+// message.
+//
+// NOTE: This is EXPERIMENTAL and subject to change.
+void TF_PRun_wrapper(TF_Session* session, const char* handle,
+ const FeedVector& inputs, const NameVector& output_names,
+ Status* out_status, PyObjectVector* out_values);
+
// Import numpy. This wrapper function exists so that the
// PY_ARRAY_UNIQUE_SYMBOL can be safely defined in a .cc file to
// avoid weird linking issues.
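
Through the typemaps added in tf_session.i, these wrappers surface in Python as tf_session.TF_PRunSetup and tf_session.TF_PRun: the char** out_handle out-parameter becomes the string return value, and a non-OK out_status is raised as StatusNotOK. A sketch of the resulting Python-level call sequence (the TF_PRun signature matches _prun_fn in session.py above; the TF_PRunSetup signature is inferred from the wrapper declaration, so treat it as an assumption):

    # Presumed generated API: returns a handle string or raises StatusNotOK.
    handle = tf_session.TF_PRunSetup(session, feed_names, fetch_names,
                                     target_names)
    # Subsequent steps feed/fetch against the handle; targets must be empty.
    outputs = tf_session.TF_PRun(session, handle, feed_dict, fetch_names)
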
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 4845cadbd3..f4389e58b2 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3338,6 +3338,8 @@ class GraphKeys(object):
MOVING_AVERAGE_VARIABLES = "moving_average_variables"
# Key to collected regularization losses at graph construction.
REGULARIZATION_LOSSES = "regularization_losses"
+ # Key to collect concatenated sharded variables.
+ CONCATENATED_VARIABLES = "concatenated_variables"
def add_to_collection(name, value):
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 5f593bce3d..c3bcfed900 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -114,20 +114,6 @@ void AppendWithinWidth(string* dest, StringPiece append, int width) {
}
}
-void RemoveDescriptionsFromOpDef(OpDef* op_def) {
- for (int i = 0; i < op_def->input_arg_size(); ++i) {
- op_def->mutable_input_arg(i)->clear_description();
- }
- for (int i = 0; i < op_def->output_arg_size(); ++i) {
- op_def->mutable_output_arg(i)->clear_description();
- }
- for (int i = 0; i < op_def->attr_size(); ++i) {
- op_def->mutable_attr(i)->clear_description();
- }
- op_def->clear_summary();
- op_def->clear_description();
-}
-
// Like DataTypeString() but uses the Python names for the
// float types.
string PythonDataTypeString(DataType dtype) {
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index c9344d63ff..eac25fb2c6 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -255,7 +255,7 @@ class LSTMTest(tf.test.TestCase):
input_size = 5
batch_size = 2
num_proj = 4
- num_proj_shards = 4
+ num_proj_shards = 3
num_unit_shards = 2
max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -281,12 +281,37 @@ class LSTMTest(tf.test.TestCase):
input_value = np.random.randn(batch_size, input_size)
sess.run(outputs, feed_dict={inputs[0]: input_value})
+ def _testTooManyShards(self, use_gpu):
+ num_units = 3
+ input_size = 5
+ num_proj = 4
+ num_proj_shards = 4
+ num_unit_shards = 2
+ max_length = 8
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(None, input_size))]
+
+ cell = tf.nn.rnn_cell.LSTMCell(
+ num_units,
+ input_size=input_size,
+ use_peepholes=True,
+ num_proj=num_proj,
+ num_unit_shards=num_unit_shards,
+ num_proj_shards=num_proj_shards,
+ initializer=initializer)
+
+ with self.assertRaises(ValueError):
+ tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
def _testDoubleInput(self, use_gpu):
num_units = 3
input_size = 5
batch_size = 2
num_proj = 4
- num_proj_shards = 4
+ num_proj_shards = 3
num_unit_shards = 2
max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -318,7 +343,7 @@ class LSTMTest(tf.test.TestCase):
input_size = 5
batch_size = 2
num_proj = 4
- num_proj_shards = 4
+ num_proj_shards = 3
num_unit_shards = 2
max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -369,7 +394,7 @@ class LSTMTest(tf.test.TestCase):
input_size = 5
batch_size = 2
num_proj = 4
- num_proj_shards = 4
+ num_proj_shards = 3
num_unit_shards = 2
max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
@@ -494,6 +519,10 @@ class LSTMTest(tf.test.TestCase):
self._testProjSharding(use_gpu=False)
self._testProjSharding(use_gpu=True)
+ def testTooManyShards(self):
+ self._testTooManyShards(use_gpu=False)
+ self._testTooManyShards(use_gpu=True)
+
def testShardNoShardEquivalentOutput(self):
self._testShardNoShardEquivalentOutput(use_gpu=False)
self._testShardNoShardEquivalentOutput(use_gpu=True)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
new file mode 100644
index 0000000000..13c519857d
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -0,0 +1,256 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for SparseSoftmaxCrossEntropyWithLogits op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-bad-import-order
+
+# pylint: disable=unused-import
+import tensorflow.python.platform
+# pylint: enable=unused-import
+
+import numpy as np
+import tensorflow as tf
+import sys
+import time
+
+from tensorflow.python.client import graph_util
+from tensorflow.python.ops import sparse_ops
+
+# pylint: enable=g-bad-import-order
+
+
+class SparseXentTest(tf.test.TestCase):
+
+ def _npXent(self, features, labels):
+ batch_dim = 0
+ class_dim = 1
+ batch_size = features.shape[batch_dim]
+ e = np.exp(features -
+ np.reshape(np.amax(features, axis=class_dim), [batch_size, 1]))
+ probs = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
+ labels_mat = np.zeros_like(probs).astype(probs.dtype)
+ labels_mat[np.arange(batch_size), labels] = 1.0
+ bp = (probs - labels_mat)
+ l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1)
+ return l, bp
+
+ def _testXent(self, np_features, np_labels, use_gpu=False):
+ np_loss, np_backprop = self._npXent(np_features, np_labels)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ np_features, np_labels)
+ backprop = loss.op.outputs[1]
+ tf_loss, tf_backprop = sess.run([loss, backprop])
+ self.assertAllClose(np_loss, tf_loss)
+ self.assertAllClose(np_backprop, tf_backprop)
+
+ def _testAll(self, features, labels):
+ self._testXent(features, labels, use_gpu=False)
+ self._testXent(features, labels, use_gpu=True)
+
+ def _testSingleClass(self, use_gpu=False):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ np.array([[1.], [-1.], [0.]]).astype(np.float32),
+ np.array([1, 1, 1]).astype(np.int64))
+ backprop = loss.op.outputs[1]
+ tf_loss, tf_backprop = sess.run([loss, backprop])
+ # With a single class, the softmax probability is 1.0 per row, so the loss is -log(1.0) = 0.0.
+ self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
+ self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop)
+
+ # def testSingleClass(self):
+ # self._testSingleClass(True)
+ # self._testSingleClass(False)
+
+ def testRankTooLarge(self):
+ np_features = np.array(
+ [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(np.float32)
+ np_labels = np.array([1, 2]).astype(np.int64)
+ self.assertRaisesRegexp(
+ ValueError, "must have rank 2",
+ tf.nn.sparse_softmax_cross_entropy_with_logits, np_features, np_labels)
+
+ def testNpXent(self):
+ # We create 2 batches of logits for testing.
+ # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
+ # batch 1 has a bit of difference: 1, 2, 3, 4, with target 0.
+ features = [[1., 1., 1., 1.], [1., 2., 3., 4.]]
+ labels = [3, 0]
+
+ # For batch 0, we expect the uniform distribution: 0.25, 0.25, 0.25, 0.25
+ # With a hard target 3, the backprop is [0.25, 0.25, 0.25, -0.75]
+ # The loss for this batch is -log(0.25) = 1.386
+ #
+ # For batch 1, we have:
+ # exp(0) = 1
+ # exp(1) = 2.718
+ # exp(2) = 7.389
+ # exp(3) = 20.085
+ # SUM = 31.192
+ # So we have as probabilities:
+ # exp(0) / SUM = 0.032
+ # exp(1) / SUM = 0.087
+ # exp(2) / SUM = 0.237
+ # exp(3) / SUM = 0.644
+ # With a hard target 0, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644]
+ # The losses for the two batches are [1.0 * -log(0.25), 1.0 * -log(0.032)]
+ # = [1.3862, 3.4420]
+ np_loss, np_backprop = self._npXent(
+ np.array(features), np.array(labels, dtype=np.int64))
+ self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75],
+ [-0.968, 0.087, 0.237, 0.6439]]),
+ np_backprop,
+ rtol=1.e-3, atol=1.e-3)
+ self.assertAllClose(np.array([1.3862, 3.4420]), np_loss,
+ rtol=1.e-3, atol=1.e-3)
+
+ def testShapeMismatch(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.nn.sparse_softmax_cross_entropy_with_logits(
+ [[0., 1.], [2., 3.]], [[0, 2]])
+
+ def testNotMatrix(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.nn.sparse_softmax_cross_entropy_with_logits(
+ [0., 1., 2., 3.], [0, 2])
+
+ def testFloat(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32),
+ np.array([3, 0]).astype(np.int64))
+
+ def testDouble(self):
+ self._testXent(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
+ np.array([0, 3]).astype(np.int64),
+ use_gpu=False)
+
+ def testGradient(self):
+ with self.test_session():
+ l = tf.constant([3, 0, 1], dtype=tf.int64, name="l")
+ f = tf.constant([0.1, 0.2, 0.3, 0.4,
+ 0.1, 0.4, 0.9, 1.6,
+ 0.1, 0.8, 2.7, 6.4], shape=[3, 4],
+ dtype=tf.float64, name="f")
+ x = tf.nn.sparse_softmax_cross_entropy_with_logits(f, l, name="xent")
+ err = tf.test.compute_gradient_error(f, [3, 4], x, [3])
+ print("cross entropy gradient err = ", err)
+ self.assertLess(err, 5e-8)
+
+
+def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
+ labels = tf.identity(labels)
+ logits = tf.identity(logits)
+ with tf.device("/cpu:0"): # Sparse-to-dense must be on CPU
+ batch_size = tf.shape(logits)[0]
+ num_entries = tf.shape(logits)[1]
+ length = batch_size * num_entries
+ labels += num_entries * tf.range(batch_size)
+ target = sparse_ops.sparse_to_dense(
+ labels, tf.pack([length]), 1.0, 0.0)
+ target = tf.reshape(target, tf.pack([-1, num_entries]))
+ crossent = tf.nn.softmax_cross_entropy_with_logits(
+ logits, target, name="SequenceLoss/CrossEntropy")
+ crossent_sum = tf.reduce_sum(crossent)
+ grads = tf.gradients([crossent_sum], [logits])[0]
+
+ return (crossent_sum, grads)
+
+
+def _sparse_vs_dense_xent_benchmark_sparse(labels, logits):
+ # Using sparse_softmax_cross_entropy_with_logits
+ labels = labels.astype(np.int64)
+ labels = tf.identity(labels)
+ logits = tf.identity(logits)
+ crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name="SequenceLoss/CrossEntropy")
+ crossent_sum = tf.reduce_sum(crossent)
+ grads = tf.gradients([crossent_sum], [logits])[0]
+
+ return (crossent_sum, grads)
+
+
+def sparse_vs_dense_xent_benchmark(batch_size, num_entries, use_gpu):
+ config = tf.ConfigProto()
+ config.allow_soft_placement = True
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ labels = np.random.randint(num_entries, size=batch_size).astype(np.int32)
+ logits = np.random.randn(batch_size, num_entries).astype(np.float32)
+
+ def _timer(sess, ops):
+ # Warm up
+ for _ in range(20):
+ sess.run(ops)
+
+ # Timing run
+ start = time.time()
+ for _ in range(20):
+ sess.run(ops)
+ end = time.time()
+
+ return (end - start)/20.0 # Average runtime per iteration
+
+ # Using sparse_to_dense and softmax_cross_entropy_with_logits
+ with tf.Session(config=config) as sess:
+ if not use_gpu:
+ with tf.device(graph_util.pin_to_cpu):
+ ops = _sparse_vs_dense_xent_benchmark_dense(labels, logits)
+ else:
+ ops = _sparse_vs_dense_xent_benchmark_dense(labels, logits)
+ delta_dense = _timer(sess, ops)
+
+ # Using sparse_softmax_cross_entropy_with_logits
+ with tf.Session(config=config) as sess:
+ if not use_gpu:
+ with tf.device(graph_util.pin_to_cpu):
+ ops = _sparse_vs_dense_xent_benchmark_sparse(labels, logits)
+ else:
+ ops = _sparse_vs_dense_xent_benchmark_sparse(labels, logits)
+ delta_sparse = _timer(sess, ops)
+
+ print(
+ "%d \t %d \t %s \t %f \t %f \t %f"
+ % (batch_size, num_entries, use_gpu, delta_dense, delta_sparse,
+ delta_sparse/delta_dense))
+
+
+def main(_):
+ print("Sparse Xent vs. SparseToDense + Xent")
+ print("batch \t depth \t gpu \t dt(dense) \t dt(sparse) "
+ "\t dt(sparse)/dt(dense)")
+ for use_gpu in (False, True):
+ for batch_size in (32, 64, 128):
+ for num_entries in (100, 1000, 10000):
+ sparse_vs_dense_xent_benchmark(
+ batch_size, num_entries, use_gpu)
+ sparse_vs_dense_xent_benchmark(
+ 32, 100000, use_gpu)
+ sparse_vs_dense_xent_benchmark(
+ 8, 1000000, use_gpu)
+
+
+if __name__ == "__main__":
+ if "--benchmarks" in sys.argv:
+ sys.argv.remove("--benchmarks")
+ tf.app.run()
+ else:
+ tf.test.main()
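
The hand-computed values in testNpXent follow directly from the softmax definition. A quick numpy check of the batch-1 case (logits 1, 2, 3, 4 with target class 0), reproducing the numbers in the comments above:

    import numpy as np

    logits = np.array([1., 2., 3., 4.])
    e = np.exp(logits - logits.max())   # stabilized: exp(0), ..., exp(3)
    probs = e / e.sum()                 # [0.032, 0.087, 0.237, 0.644]
    loss = -np.log(probs[0])            # target class 0 -> 3.4420
    backprop = probs.copy()
    backprop[0] -= 1.0                  # [-0.968, 0.087, 0.237, 0.644]
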
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index bc8c8b07eb..196d2d0931 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -416,23 +416,20 @@ class TensorArrayTest(tf.test.TestCase):
dtype=tf.float32, tensor_array_name="foo", size=3)
time_0 = tf.identity(0)
- def body(time, flow, state):
+ def body(time, h_t, state):
sliced = tf.slice(v0, begin=tf.pack([time, 0]), size=[1, -1])
sliced = tf.squeeze(sliced)
out = sliced + var + state
state += sliced
- h_n = h
- h_n._flow = flow
- h_n = h_n.write(time, out)
- return (time+1, h_n.flow, state)
+ h_t = h_t.write(time, out)
+ return (time+1, h_t, state)
- (unused_0, final_flow, unused_2) = control_flow_ops.While(
+ (unused_0, h_final, unused_2) = control_flow_ops.While(
cond=lambda time, unused_1, unused_2: time < 3,
body=body,
- loop_vars=(time_0, h.flow, state0),
+ loop_vars=(time_0, h, state0),
parallel_iterations=3)
- h._flow = final_flow
- vout = h.pack()
+ vout = h_final.pack()
grad_val = -np.arange(3*5, dtype=np.float32).reshape(3, 5)
v0_grad = tf.gradients([vout], [v0], [grad_val])[0]
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index d9caa55d5a..51455a9ee7 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -68,6 +68,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import six
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -81,6 +82,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import tensor_array_ops
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
from tensorflow.python.platform import logging
@@ -283,6 +285,25 @@ def _SwitchRefOrTensor(data, pred, name="Switch"):
return switch(data, pred, name=name)
+def _convert_tensorarrays_to_flows(tensors_or_tensor_arrays):
+ return [ta.flow if isinstance(ta, tensor_array_ops.TensorArray)
+ else ta
+ for ta in tensors_or_tensor_arrays]
+
+
+def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
+ if len(tensors_or_tensorarrays) != len(tensors_or_flows):
+ raise ValueError(
+ "Lengths of original Tensor list and new list do not match: %d vs. %d"
+ % (len(tensors_or_tensorarrays), len(tensors_or_flows)))
+ return [
+ tensor_array_ops.TensorArray(
+ dtype=ta.dtype, handle=ta.handle, flow=t_or_flow)
+ if isinstance(ta, tensor_array_ops.TensorArray)
+ else t_or_flow
+ for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)]
+
+
class ControlFlowOpWrapper(object):
"""A wrapper class for Operation.
@@ -1402,6 +1423,10 @@ class WhileContext(ControlFlowContext):
def BuildLoop(self, pred, body, loop_vars):
"""Add the loop termination condition and body to the graph."""
+ # Keep original_loop_vars to identify which are TensorArrays
+ original_loop_vars = loop_vars
+ # Convert TensorArrays to their flow variables
+ loop_vars = _convert_tensorarrays_to_flows(loop_vars)
loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
# Let the context know the loop variables so they are added to
# the outer contexts properly.
@@ -1421,18 +1446,28 @@ class WhileContext(ControlFlowContext):
self._pivot_for_pred = merge_vars[0]
# Build the graph for pred.
- c = ops.convert_to_tensor(pred(*merge_vars))
+ merge_vars_with_tensor_arrays = (
+ _convert_flows_to_tensorarrays(original_loop_vars, merge_vars))
+ c = ops.convert_to_tensor(pred(*merge_vars_with_tensor_arrays))
self._pivot = loop_cond(c, name="LoopCond")
switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
# Build the graph for body.
vars_for_body = [_Identity(x[1]) for x in switch_vars]
self._pivot_for_body = vars_for_body[0]
+ # Convert TensorArray flow variables inside the context back into
+ # their associated TensorArrays for calling the body.
+ vars_for_body_with_tensor_arrays = (
+ _convert_flows_to_tensorarrays(original_loop_vars, vars_for_body))
- body_result = body(*vars_for_body)
- if not isinstance(body_result, (list, _basetuple)):
+ body_result = body(*vars_for_body_with_tensor_arrays)
+ if not isinstance(body_result, collections.Sequence):
body_result = [body_result]
- result = ops.convert_n_to_tensor_or_indexed_slices(body_result)
+ # Store body_result to keep track of TensorArrays returned by body
+ original_body_result = body_result
+ # Convert TensorArrays returned by body into their flow variables
+ result = _convert_tensorarrays_to_flows(body_result)
+ result = ops.convert_n_to_tensor_or_indexed_slices(result)
next_vars = [_NextIteration(x) for x in result]
# Add the back edges to complete the loop.
@@ -1450,7 +1485,14 @@ class WhileContext(ControlFlowContext):
# Exit the loop.
self.ExitResult(exit_vars)
- return exit_vars[0] if len(exit_vars) == 1 else exit_vars
+
+ # Convert TensorArray flow variables outside the context back into
+ # their associated TensorArrays for returning to caller.
+ exit_vars_with_tensor_arrays = (
+ _convert_flows_to_tensorarrays(original_body_result, exit_vars))
+ return (exit_vars_with_tensor_arrays[0]
+ if len(exit_vars) == 1
+ else exit_vars_with_tensor_arrays)
def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
@@ -1462,6 +1504,10 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
tensors of the same length and with the same types as the input. `loop_vars`
is a list of tensors that is passed to both `cond` and `body`.
+ In addition to regular Tensors or IndexedSlices, the body may accept and
+ return TensorArray objects. The flows of the TensorArray objects will
+ be appropriately forwarded between loop iterations and during gradient calculations.
+
While `cond` evaluates to true, `body` is executed.
Args:
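
With the conversion helpers above, a While body can take and return a TensorArray directly; only its scalar flow tensor is threaded around the loop. A minimal sketch, modeled on the tensor_array_ops_test.py change earlier in this diff:

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import tensor_array_ops

    ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=3)

    def body(t, ta):
      # The TensorArray is carried as a loop variable; BuildLoop converts it
      # to/from its flow tensor internally.
      return t + 1, ta.write(t, tf.to_float(t) * 2.0)

    _, ta_final = control_flow_ops.While(
        cond=lambda t, _: t < 3, body=body, loop_vars=(tf.constant(0), ta))
    vout = ta_final.pack()   # -> [0., 2., 4.]
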
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index c3d966527e..2dc3ab5095 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -151,6 +151,7 @@ TensorFlow provides several operations that help you perform classification.
@@sigmoid_cross_entropy_with_logits
@@softmax
@@softmax_cross_entropy_with_logits
+@@sparse_softmax_cross_entropy_with_logits
## Embeddings
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 0a7e3b214a..72ec25b96c 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -158,6 +158,14 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
return _BroadcastMul(grad_0, op.outputs[1]), None
+@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
+def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
+ # grad_0 is the backprop for cost, and we multiply it with the gradients
+ # (which is output[1])
+ # There is no gradient for the labels
+ return _BroadcastMul(grad_0, op.outputs[1]), None
+
+
@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
return [nn_ops.conv2d_backprop_input(array_ops.shape(op.inputs[0]),
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index bdc32360a3..6c0cd339e6 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -157,6 +157,12 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
example, each CIFAR-10 image is labeled with one and only one label: an image
can be a dog or a truck, but not both.
+ **NOTE:** While the classes are mutually exclusive, their probabilities
+ need not be. All that is required is that each row of `labels` is
+ a valid probability distribution. If using exclusive `labels`
+ (wherein one and only one class is true at a time), see
+ `sparse_softmax_cross_entropy_with_logits`.
+
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
output of `softmax`, as it will produce incorrect results.
@@ -180,6 +186,57 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None):
return cost
+def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
+ """Computes sparse softmax cross entropy between `logits` and `labels`.
+
+ Measures the probability error in discrete classification tasks in which the
+ classes are mutually exclusive (each entry is in exactly one class). For
+ example, each CIFAR-10 image is labeled with one and only one label: an image
+ can be a dog or a truck, but not both.
+
+ **NOTE:** For this operation, the probability of a given label is considered
+ exclusive. That is, soft classes are not allowed, and the `labels` vector
+ must provide a single specific index for the true class for each row of
+ `logits` (each minibatch entry). For soft softmax classification with
+ a probability distribution for each entry, see
+ `softmax_cross_entropy_with_logits`.
+
+ **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+ on `logits` internally for efficiency. Do not call this op with the
+ output of `softmax`, as it will produce incorrect results.
+
+ `logits` must have the shape `[batch_size, num_classes]`
+ and the dtype (either `float32` or `float64`).
+
+ `labels` must have the shape `[batch_size]` and the dtype `int64`.
+
+ Args:
+ logits: Unscaled log probabilities.
+ labels: Each entry `labels[i]` must be an index in `[0, num_classes)`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+ softmax cross entropy loss.
+ """
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
+ return cost
+
+
+@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
+def _SparseSoftmaxCrossEntropyWithLogitsShape(op):
+ """Shape function for SparseSoftmaxCrossEntropyWithLogits op."""
+ logits_shape = op.inputs[0].get_shape()
+ input_shape = logits_shape.with_rank(2)
+ batch_size = input_shape[0]
+ # labels_shape
+ op.inputs[1].get_shape().merge_with(tensor_shape.vector(batch_size))
+ return [tensor_shape.vector(batch_size.value), input_shape]
+
+
@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsShape(op):
"""Shape function for SoftmaxCrossEntropyWithLogits op."""
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index 28f33b1e1a..820af12626 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -210,24 +210,42 @@ class BasicLSTMCell(RNNCell):
return new_h, array_ops.concat(1, [new_c, new_h])
-def _get_sharded_variable(name, shape, initializer, dtype, num_shards):
- """Get a list of sharded variables with the given dtype and initializer."""
- unit_shard_size = int(math.ceil(shape[1] / num_shards))
+def _get_concat_variable(name, shape, dtype, num_shards):
+ """Get a sharded variable concatenated into one tensor."""
+ sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
+ if len(sharded_variable) == 1:
+ return sharded_variable[0]
+
+ concat_name = name + "/concat"
+ concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
+ for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
+ if value.name == concat_full_name:
+ return value
+
+ concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
+ ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
+ concat_variable)
+ return concat_variable
+
+
+def _get_sharded_variable(name, shape, dtype, num_shards):
+ """Get a list of sharded variables with the given dtype."""
+ if num_shards > shape[0]:
+ raise ValueError("Too many shards: shape=%s, num_shards=%d" %
+ (shape, num_shards))
+ unit_shard_size = int(math.floor(shape[0] / num_shards))
+ remaining_rows = shape[0] - unit_shard_size * num_shards
shards = []
for i in range(num_shards):
- current_size = min(unit_shard_size, shape[1] - unit_shard_size * i)
- shards.append(vs.get_variable(name + "_%d" % i, [shape[0], current_size],
- initializer=initializer, dtype=dtype))
+ current_size = unit_shard_size
+ if i < remaining_rows:
+ current_size += 1
+ shards.append(vs.get_variable(name + "_%d" % i, [current_size, shape[1]],
+ dtype=dtype))
return shards
-def _matmul_with_sharded_variable(tensor, sharded_tensor):
- """Multiply tensor with each tensor in sharded_tensor, column-concatenated."""
- return array_ops.concat(1, [math_ops.matmul(tensor, shard)
- for shard in sharded_tensor])
-
-
class LSTMCell(RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
@@ -317,10 +335,11 @@ class LSTMCell(RNNCell):
dtype = input_.dtype
- with vs.variable_scope(scope or type(self).__name__): # "LSTMCell"
- sharded_w = _get_sharded_variable(
+ with vs.variable_scope(scope or type(self).__name__,
+ initializer=self._initializer): # "LSTMCell"
+ concat_w = _get_concat_variable(
"W", [self.input_size + num_proj, 4 * self._num_units],
- self._initializer, dtype, self._num_unit_shards)
+ dtype, self._num_unit_shards)
b = vs.get_variable(
"B", shape=[4 * self._num_units],
@@ -328,24 +347,17 @@ class LSTMCell(RNNCell):
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat(1, [input_, m_prev])
- lstm_matrix = nn_ops.bias_add(
- _matmul_with_sharded_variable(cell_inputs, sharded_w), b)
+ lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
i, j, f, o = array_ops.split(1, 4, lstm_matrix)
# Diagonal connections
if self._use_peepholes:
w_f_diag = vs.get_variable(
- "W_F_diag", shape=[self._num_units],
- initializer=self._initializer,
- dtype=dtype)
+ "W_F_diag", shape=[self._num_units], dtype=dtype)
w_i_diag = vs.get_variable(
- "W_I_diag", shape=[self._num_units],
- initializer=self._initializer,
- dtype=dtype)
+ "W_I_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = vs.get_variable(
- "W_O_diag", shape=[self._num_units],
- initializer=self._initializer,
- dtype=dtype)
+ "W_O_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes:
c = (sigmoid(f + 1 + w_f_diag * c_prev) * c_prev +
@@ -362,11 +374,11 @@ class LSTMCell(RNNCell):
m = sigmoid(o) * tanh(c)
if self._num_proj is not None:
- sharded_w_proj = _get_sharded_variable(
- "W_P", [self._num_units, self._num_proj], self._initializer,
+ concat_w_proj = _get_concat_variable(
+ "W_P", [self._num_units, self._num_proj],
dtype, self._num_proj_shards)
- m = _matmul_with_sharded_variable(m, sharded_w_proj)
+ m = math_ops.matmul(m, concat_w_proj)
return m, array_ops.concat(1, [c, m])
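
_get_sharded_variable now shards along dimension 0: each shard gets floor(shape[0] / num_shards) rows, with the remainder handed out one extra row to the leading shards, and _get_concat_variable caches the concatenation in the CONCATENATED_VARIABLES collection so repeated cell invocations reuse one tensor. The shard-size arithmetic, e.g. for shape=[10, 8] and num_shards=3:

    # Mirrors the size computation in _get_sharded_variable above.
    def shard_sizes(rows, num_shards):
      base = rows // num_shards
      remainder = rows - base * num_shards
      return [base + (1 if i < remainder else 0) for i in range(num_shards)]

    print(shard_sizes(10, 3))  # [4, 3, 3]; the sizes sum back to 10
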
diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py
index 23eb02c39c..05cb4f12ad 100644
--- a/tensorflow/python/ops/tensor_array_grad.py
+++ b/tensorflow/python/ops/tensor_array_grad.py
@@ -102,9 +102,8 @@ def _TensorArrayWriteGrad(op, flow):
dtype = op.get_attr("T")
grad_source = _GetGradSource(flow)
g = tensor_array_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad(
- source=grad_source)
- with ops.control_dependencies([flow]):
- grad = g.read(index)
+ source=grad_source, flow=flow)
+ grad = g.read(index)
return [None, None, grad, flow]
@@ -144,8 +143,7 @@ def _TensorArrayUnpackGrad(op, flow):
dtype = op.get_attr("T")
grad_source = _GetGradSource(flow)
g = tensor_array_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad(
- source=grad_source)
- with ops.control_dependencies([flow]):
- grad = g.pack()
+ source=grad_source, flow=flow)
+ grad = g.pack()
return [None, grad, flow]
# pylint: enable=protected-access
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index c12506694f..901dfbe913 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -46,8 +46,8 @@ class TensorArray(object):
@@grad
"""
- def __init__(
- self, dtype, size=None, tensor_array_name=None, handle=None, name=None):
+ def __init__(self, dtype, size=None, tensor_array_name=None,
+ handle=None, flow=None, name=None):
"""Construct a new TensorArray or wrap an existing TensorArray handle.
Args:
@@ -59,6 +59,8 @@ class TensorArray(object):
set, handle should be None.
handle: (optional) A `Tensor` handle to an existing TensorArray. If this
is set, tensor_array_name should be None.
+ flow: (optional) A float `Tensor` scalar coming from an existing
+ TensorArray.flow.
name: A name for the operation (optional).
Raises:
@@ -73,16 +75,15 @@ class TensorArray(object):
if handle is None and size is None:
raise ValueError("Size must be provided if handle is not provided")
- with ops.op_scope([handle, size], name, "TensorArray") as scope:
+ self._dtype = dtype
+ with ops.op_scope([handle, size, flow], name, "TensorArray") as scope:
if handle:
self._handle = handle
else:
self._handle = gen_data_flow_ops._tensor_array(
dtype=dtype, size=size, tensor_array_name=tensor_array_name,
name=scope)
-
- self._flow = constant_op.constant(0, dtype=_dtypes.float32)
- self._dtype = dtype
+ self._flow = flow or constant_op.constant(0, dtype=_dtypes.float32)
@property
def flow(self):
@@ -90,14 +91,19 @@ class TensorArray(object):
return self._flow
@property
+ def dtype(self):
+ """The data type of this TensorArray."""
+ return self._dtype
+
+ @property
def handle(self):
"""The reference to the TensorArray."""
return self._handle
- def grad(self, source):
+ def grad(self, source, flow=None):
g_handle = gen_data_flow_ops._tensor_array_grad(
handle=self._handle, source=source)
- g = TensorArray(dtype=self._dtype, size=None, handle=g_handle)
+ g = TensorArray(dtype=self._dtype, size=None, handle=g_handle, flow=flow)
return g
def read(self, index, name=None):
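
The new flow argument lets callers rewrap an existing TensorArray handle at a chosen point in the dataflow instead of poking the private _flow attribute, which is exactly what _convert_flows_to_tensorarrays and the gradient functions above do. A minimal sketch:

    import tensorflow as tf
    from tensorflow.python.ops import tensor_array_ops

    ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=2)
    ta = ta.write(0, tf.constant(1.0))

    # Rewrap the same underlying array at the current flow:
    ta2 = tensor_array_ops.TensorArray(
        dtype=ta.dtype, handle=ta.handle, flow=ta.flow)
    value = ta2.read(0)
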
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index 600f083840..8a7d0925ce 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -35,8 +35,13 @@ namespace perftools {
namespace gputools {
namespace internal {
+// TensorFlow OSS configure uses the following lines to configure versions. If
+// the format is modified, please make sure the configure script still works.
+string GetCudaVersion() { return "7.0"; }
+string GetCudnnVersion() { return "6.5"; }
+
/* static */ port::Status DsoLoader::GetCublasDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath("libcublas.so.7.0",
+ return GetDsoHandle(FindDsoPath("libcublas.so." + GetCudaVersion(),
"third_party/gpus/cuda/lib64"),
dso_handle);
}
@@ -46,18 +51,19 @@ namespace internal {
// different version number than other CUDA libraries. See b/22397368 for
// some details about the complications surrounding this.
return GetDsoHandle(
- FindDsoPath("libcudnn.so.6.5", "third_party/gpus/cuda/lib64"),
+ FindDsoPath("libcudnn.so." + GetCudnnVersion(),
+ "third_party/gpus/cuda/lib64"),
dso_handle);
}
/* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath("libcufft.so.7.0",
+ return GetDsoHandle(FindDsoPath("libcufft.so." + GetCudaVersion(),
"third_party/gpus/cuda/lib64"),
dso_handle);
}
/* static */ port::Status DsoLoader::GetCurandDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath("libcurand.so.7.0",
+ return GetDsoHandle(FindDsoPath("libcurand.so." + GetCudaVersion(),
"third_party/gpus/cuda/lib64"),
dso_handle);
}
@@ -70,7 +76,7 @@ namespace internal {
/* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) {
return GetDsoHandle(
- FindDsoPath("libcupti.so.7.0",
+ FindDsoPath("libcupti.so." + GetCudaVersion(),
"third_party/gpus/cuda/extras/CUPTI/lib64"),
dso_handle);
}
@@ -92,8 +98,6 @@ namespace internal {
if (*dso_handle == nullptr) {
LOG(INFO) << "Couldn't open CUDA library " << path
<< ". LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
- // TODO(b/22689637): Eliminate unnecessary ToString once StrCat has been
- // moved to the open-sourceable version.
return port::Status(
port::error::FAILED_PRECONDITION,
port::StrCat("could not dlopen DSO: ", path, "; dlerror: ", dlerror()));
diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD
index bd4229b5f4..52518130f6 100644
--- a/tensorflow/tensorboard/BUILD
+++ b/tensorflow/tensorboard/BUILD
@@ -15,7 +15,7 @@ filegroup(
py_library(
name = "tensorboard_handler",
- srcs = ["tensorboard_handler.py"],
+ srcs = ["backend/tensorboard_handler.py"],
deps = [
":float_wrapper",
"//tensorflow/python:platform",
@@ -26,14 +26,14 @@ py_library(
py_library(
name = "float_wrapper",
- srcs = ["float_wrapper.py"],
+ srcs = ["backend/float_wrapper.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "float_wrapper_test",
size = "small",
- srcs = ["float_wrapper_test.py"],
+ srcs = ["backend/float_wrapper_test.py"],
deps = [
":float_wrapper",
"//tensorflow/python:platform_test",
@@ -43,7 +43,7 @@ py_test(
py_binary(
name = "tensorboard",
- srcs = ["tensorboard.py"],
+ srcs = ["backend/tensorboard.py"],
data = [":tensorboard_frontend"],
deps = [
":tensorboard_handler",
diff --git a/tensorflow/tensorboard/backend/__init__.py b/tensorflow/tensorboard/backend/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/tensorboard/backend/__init__.py
diff --git a/tensorflow/tensorboard/float_wrapper.py b/tensorflow/tensorboard/backend/float_wrapper.py
index 5d5f0ea6a1..5d5f0ea6a1 100644
--- a/tensorflow/tensorboard/float_wrapper.py
+++ b/tensorflow/tensorboard/backend/float_wrapper.py
diff --git a/tensorflow/tensorboard/float_wrapper_test.py b/tensorflow/tensorboard/backend/float_wrapper_test.py
index 7f731b911b..ae29880285 100644
--- a/tensorflow/tensorboard/float_wrapper_test.py
+++ b/tensorflow/tensorboard/backend/float_wrapper_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow.python.platform
from tensorflow.python.platform import googletest
-from tensorflow.tensorboard import float_wrapper
+from tensorflow.tensorboard.backend import float_wrapper
_INFINITY = float('inf')
diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/backend/tensorboard.py
index bbb98ff4aa..30a31c6468 100644
--- a/tensorflow/tensorboard/tensorboard.py
+++ b/tensorflow/tensorboard/backend/tensorboard.py
@@ -38,7 +38,7 @@ from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import status_bar
from tensorflow.python.summary import event_accumulator
from tensorflow.python.summary import event_multiplexer
-from tensorflow.tensorboard import tensorboard_handler
+from tensorflow.tensorboard.backend import tensorboard_handler
flags.DEFINE_string('logdir', None, """logdir specifies the directory where
TensorBoard will look to find TensorFlow event files that it can display.
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/backend/tensorboard_handler.py
index c5ab674121..46c29e3d9b 100644
--- a/tensorflow/tensorboard/tensorboard_handler.py
+++ b/tensorflow/tensorboard/backend/tensorboard_handler.py
@@ -45,7 +45,7 @@ from tensorflow.python.platform import logging
from tensorflow.python.platform import resource_loader
from tensorflow.python.summary import event_accumulator
from tensorflow.python.util import compat
-from tensorflow.tensorboard import float_wrapper
+from tensorflow.tensorboard.backend import float_wrapper
DATA_PREFIX = '/data'
diff --git a/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html b/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html
index 10f03b0006..b782e3bc65 100644
--- a/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html
+++ b/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html
@@ -15,6 +15,7 @@ future for loading older images.
img {
width: 100%;
height: 100%;
+ image-rendering: pixelated;
}
</style>
<template>
diff --git a/third_party/gpus/cuda/BUILD b/third_party/gpus/cuda/BUILD
index 690728c5d2..47adb1d9d0 100644
--- a/third_party/gpus/cuda/BUILD
+++ b/third_party/gpus/cuda/BUILD
@@ -1,6 +1,10 @@
licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
load("/tensorflow/tensorflow", "if_cuda")
+load("//tensorflow/core:platform/default/build_config.bzl",
+ "tf_get_cuda_version",
+ "tf_get_cudnn_version",
+ )
package(default_visibility = ["//visibility:public"])
@@ -53,10 +57,10 @@ cc_library(
cc_library(
name = "cudart",
srcs = [
- "lib64/libcudart.so.7.0",
+ "lib64/libcudart.so." + tf_get_cuda_version(),
],
data = [
- "lib64/libcudart.so.7.0",
+ "lib64/libcudart.so." + tf_get_cuda_version(),
],
includes = ["include/"],
visibility = ["//visibility:public"],
@@ -66,10 +70,10 @@ cc_library(
cc_library(
name = "cublas",
srcs = [
- "lib64/libcublas.so.7.0",
+ "lib64/libcublas.so." + tf_get_cuda_version(),
],
data = [
- "lib64/libcublas.so.7.0",
+ "lib64/libcublas.so." + tf_get_cuda_version(),
],
includes = ["include/"],
visibility = ["//visibility:public"],
@@ -79,10 +83,10 @@ cc_library(
cc_library(
name = "cudnn",
srcs = [
- "lib64/libcudnn.so.6.5",
+ "lib64/libcudnn.so." + tf_get_cudnn_version(),
],
data = [
- "lib64/libcudnn.so.6.5",
+ "lib64/libcudnn.so." + tf_get_cudnn_version(),
],
includes = ["include/"],
visibility = ["//visibility:public"],
@@ -92,10 +96,10 @@ cc_library(
cc_library(
name = "cufft",
srcs = [
- "lib64/libcufft.so.7.0",
+ "lib64/libcufft.so." + tf_get_cuda_version(),
],
data = [
- "lib64/libcufft.so.7.0",
+ "lib64/libcufft.so." + tf_get_cuda_version(),
],
includes = ["include/"],
visibility = ["//visibility:public"],
@@ -130,10 +134,10 @@ genrule(
"include/cublas.h",
"include/cudnn.h",
"lib64/libcudart_static.a",
- "lib64/libcublas.so.7.0",
- "lib64/libcudnn.so.6.5",
- "lib64/libcudart.so.7.0",
- "lib64/libcufft.so.7.0",
+ "lib64/libcublas.so." + tf_get_cuda_version(),
+ "lib64/libcudnn.so." + tf_get_cudnn_version(),
+ "lib64/libcudart.so." + tf_get_cuda_version(),
+ "lib64/libcufft.so." + tf_get_cuda_version(),
],
cmd = if_cuda(
# Under cuda config, create all the symbolic links to the actual cuda files
@@ -147,10 +151,10 @@ genrule(
"touch $(@D)/include/cublas.h",
"touch $(@D)/include/cudnn.h",
"touch $(@D)/lib64/libcudart_static.a",
- "touch $(@D)/lib64/libcublas.so.7.0",
- "touch $(@D)/lib64/libcudnn.so.6.5",
- "touch $(@D)/lib64/libcudart.so.7.0",
- "touch $(@D)/lib64/libcufft.so.7.0"
+ "touch $(@D)/lib64/libcublas.so." + tf_get_cuda_version(),
+ "touch $(@D)/lib64/libcudnn.so." + tf_get_cudnn_version(),
+ "touch $(@D)/lib64/libcudart.so." + tf_get_cuda_version(),
+ "touch $(@D)/lib64/libcufft.so." + tf_get_cuda_version(),
]),
),
local = 1,
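
tf_get_cuda_version and tf_get_cudnn_version are loaded from build_config.bzl, which is not shown in this diff; presumably they simply return the configured version strings, mirroring GetCudaVersion()/GetCudnnVersion() in dso_loader.cc. A hypothetical sketch of such definitions:

    # Hypothetical sketch only -- the real macros live in
    # tensorflow/core/platform/default/build_config.bzl, in a format the
    # configure script can rewrite.
    def tf_get_cuda_version():
      return "7.0"

    def tf_get_cudnn_version():
      return "6.5"
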
diff --git a/third_party/gpus/cuda/cuda_config.sh b/third_party/gpus/cuda/cuda_config.sh
index 44e07433a0..87c35349c0 100755
--- a/third_party/gpus/cuda/cuda_config.sh
+++ b/third_party/gpus/cuda/cuda_config.sh
@@ -16,7 +16,7 @@
# A simple script to configure the Cuda tree needed for the TensorFlow GPU
-# build. We need both Cuda toolkit 7.0 and Cudnn 6.5.
+# build. We need both Cuda toolkit $TF_CUDA_VERSION and Cudnn $TF_CUDNN_VERSION.
# Usage:
# * User edits cuda.config to point both the Cuda toolkit and Cudnn libraries to their local paths
# * Run cuda_config.sh to generate symbolic links in the source tree to reflect
@@ -62,8 +62,8 @@ function CudaError {
cat << EOF
##############################################################################
##############################################################################
-Cuda 7.0 toolkit is missing.
-1. Download and install the CUDA 7.0 toolkit and CUDNN 6.5 library;
+Cuda $TF_CUDA_VERSION toolkit is missing.
+1. Download and install the CUDA $TF_CUDA_VERSION toolkit and CUDNN $TF_CUDNN_VERSION library;
2. Run configure from the root of the source tree, before rerunning bazel;
Please refer to README.md for more details.
##############################################################################
@@ -78,8 +78,8 @@ function CudnnError {
cat << EOF
##############################################################################
##############################################################################
-Cudnn 6.5 is missing.
-1. Download and install the CUDA 7.0 toolkit and CUDNN 6.5 library;
+Cudnn $TF_CUDNN_VERSION is missing.
+1. Download and install the CUDA $TF_CUDA_VERSION toolkit and CUDNN $TF_CUDNN_VERSION library;
2. Run configure from the root of the source tree, before rerunning bazel;
Please refer to README.md for more details.
##############################################################################
@@ -110,18 +110,18 @@ if [ "$CHECK_ONLY" == "1" ]; then
CheckAndLinkToSrcTree CudaError include/cublas.h
CheckAndLinkToSrcTree CudnnError include/cudnn.h
CheckAndLinkToSrcTree CudaError lib64/libcudart_static.a
- CheckAndLinkToSrcTree CudaError lib64/libcublas.so.7.0
- CheckAndLinkToSrcTree CudnnError lib64/libcudnn.so.6.5
- CheckAndLinkToSrcTree CudaError lib64/libcudart.so.7.0
- CheckAndLinkToSrcTree CudaError lib64/libcufft.so.7.0
+ CheckAndLinkToSrcTree CudaError lib64/libcublas.so.$TF_CUDA_VERSION
+ CheckAndLinkToSrcTree CudnnError lib64/libcudnn.so.$TF_CUDNN_VERSION
+ CheckAndLinkToSrcTree CudaError lib64/libcudart.so.$TF_CUDA_VERSION
+ CheckAndLinkToSrcTree CudaError lib64/libcufft.so.$TF_CUDA_VERSION
exit 0
fi
# Actually configure the source tree for TensorFlow's canonical view of Cuda
# libraries.
-if test ! -e ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.7.0; then
- CudaError "cannot find ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.7.0"
+if test ! -e ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.$TF_CUDA_VERSION; then
+ CudaError "cannot find ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.$TF_CUDA_VERSION"
fi
if test ! -d ${CUDNN_INSTALL_PATH}; then
@@ -137,13 +137,13 @@ else
CudnnError "cannot find cudnn.h under: ${CUDNN_INSTALL_PATH}"
fi
-# Locate libcudnn.so.6.5
-if test -e ${CUDNN_INSTALL_PATH}/libcudnn.so.6.5; then
+# Locate libcudnn.so.${TF_CUDNN_VERSION}
+if test -e ${CUDNN_INSTALL_PATH}/libcudnn.so.$TF_CUDNN_VERSION; then
CUDNN_LIB_PATH=${CUDNN_INSTALL_PATH}
-elif test -e ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so.6.5; then
+elif test -e ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so.$TF_CUDNN_VERSION; then
CUDNN_LIB_PATH=${CUDNN_INSTALL_PATH}/lib64
else
- CudnnError "cannot find libcudnn.so.6.5 under: ${CUDNN_INSTALL_PATH}"
+ CudnnError "cannot find libcudnn.so.$TF_CUDNN_VERSION under: ${CUDNN_INSTALL_PATH}"
fi
# Helper function to build symbolic links for all files under a directory.
@@ -182,4 +182,4 @@ LinkAllFiles ${CUDA_TOOLKIT_PATH}/nvvm $OUTPUTDIR/third_party/gpus/cuda/nvvm ||
# Set up symbolic link for cudnn
ln -sf $CUDNN_HEADER_PATH/cudnn.h $OUTPUTDIR/third_party/gpus/cuda/include/cudnn.h || exit -1
-ln -sf $CUDNN_LIB_PATH/libcudnn.so.6.5 $OUTPUTDIR/third_party/gpus/cuda/lib64/libcudnn.so.6.5 || exit -1
+ln -sf $CUDNN_LIB_PATH/libcudnn.so.$TF_CUDNN_VERSION $OUTPUTDIR/third_party/gpus/cuda/lib64/libcudnn.so.$TF_CUDNN_VERSION || exit -1