author | 2016-01-29 09:34:18 -0800
committer | 2016-01-29 20:15:13 -0800
commit | 8a59748c087a2fee535c0d5067dbabb01920e812 (patch)
tree | 179f23b84fb0c47cf17d9551f62e9a6c11c32f79
parent | faf747a15d4efc8fff03a10a3fdb37393197c2d3 (diff)
Use cc_binary rather than cc_library to reduce the size of the native library in the APK from 5.5 MB to 3.2 MB (compressed).
Change: 113369407
60 files changed, 2576 insertions(+), 438 deletions(-)
@@ -1,5 +1,9 @@ #!/bin/bash +if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then + echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" 1>&2 +fi + ## Set up python-related environment settings while true; do fromuser="" @@ -44,32 +48,55 @@ fi # Find out where the CUDA toolkit is installed while true; do + # Configure the Cuda SDK version to use. + default_cuda_version="7.0" + if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then + if [ -z "$TF_CUDA_VERSION" ]; then + read -p "Please specify the Cuda SDK version you want to use. [Default is $default_cuda_version]: " TF_CUDA_VERSION + fi + fi + if [ -z "$TF_CUDA_VERSION" ]; then + TF_CUDA_VERSION=$default_cuda_version + fi + fromuser="" if [ -z "$CUDA_TOOLKIT_PATH" ]; then default_cuda_path=/usr/local/cuda - read -p "Please specify the location where CUDA 7.0 toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH + read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH fromuser="1" if [ -z "$CUDA_TOOLKIT_PATH" ]; then CUDA_TOOLKIT_PATH=$default_cuda_path fi fi - if [ -e "$CUDA_TOOLKIT_PATH/lib64/libcudart.so.7.0" ]; then + if [ -e "$CUDA_TOOLKIT_PATH/lib64/libcudart.so.$TF_CUDA_VERSION" ]; then break fi - echo "Invalid path to CUDA 7.0 toolkit. ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.7.0 cannot be found" + echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.$TF_CUDA_VERSION cannot be found" if [ -z "$fromuser" ]; then exit 1 fi + TF_CUDA_VERSION="" CUDA_TOOLKIT_PATH="" # Retry done # Find out where the cuDNN library is installed while true; do + # Configure the Cudnn version to use. + default_cudnn_version="6.5" + if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then + if [ -z "$TF_CUDNN_VERSION" ]; then + read -p "Please specify the Cudnn version you want to use. [Default is $default_cudnn_version]: " TF_CUDNN_VERSION + fi + fi + if [ -z "$TF_CUDNN_VERSION" ]; then + TF_CUDNN_VERSION=$default_cudnn_version + fi + fromuser="" if [ -z "$CUDNN_INSTALL_PATH" ]; then default_cudnn_path=${CUDA_TOOLKIT_PATH} - read -p "Please specify the location where cuDNN v2 library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH + read -p "Please specify the location where cuDNN $TF_CUDNN_VERSION library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH fromuser="1" if [ -z "$CUDNN_INSTALL_PATH" ]; then CUDNN_INSTALL_PATH=$default_cudnn_path @@ -78,21 +105,22 @@ while true; do # Going through one more level of expansion to handle that. CUDNN_INSTALL_PATH=$(bash -c "readlink -f $CUDNN_INSTALL_PATH") fi - if [ -e "$CUDNN_INSTALL_PATH/libcudnn.so.6.5" -o -e "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.6.5" ]; then + if [ -e "$CUDNN_INSTALL_PATH/libcudnn.so.${TF_CUDNN_VERSION}" -o -e "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.${TF_CUDNN_VERSION}" ]; then break fi - echo "Invalid path to cuDNN v2 toolkit. Neither of the following two files can be found:" - echo "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.6.5" - echo "$CUDNN_INSTALL_PATH/libcudnn.so.6.5" + echo "Invalid path to cuDNN ${TF_CUDNN_VERSION} toolkit. 
Neither of the following two files can be found:" + echo "$CUDNN_INSTALL_PATH/lib64/libcudnn.so.${TF_CUDNN_VERSION}" + echo "$CUDNN_INSTALL_PATH/libcudnn.so.${TF_CUDNN_VERSION}" if [ -z "$fromuser" ]; then exit 1 fi + TF_CUDNN_VERSION="" CUDNN_INSTALL_PATH="" # Retry done cat > third_party/gpus/cuda/cuda.config <<EOF -# CUDA_TOOLKIT_PATH refers to the CUDA toolkit. Tensorflow requires Cuda 7.0 +# CUDA_TOOLKIT_PATH refers to the CUDA toolkit. Tensorflow requires Cuda $TF_CUDA_VERSION # at the moment. CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH" @@ -100,10 +128,23 @@ CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH" # files can be either in this directory, or under include/ and lib64/ # directories separately. CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH" + +# The Cuda SDK version that should be used in this build +TF_CUDA_VERSION=$TF_CUDA_VERSION + +# The Cudnn version that should be used in this build +TF_CUDNN_VERSION=$TF_CUDNN_VERSION + EOF function UnofficialSetting() { - echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n" 1>&2 + # Configure the Cuda toolkit version to work with. + perl -pi -e "s,CUDA_VERSION = '[0-9\.]*',CUDA_VERSION = '$TF_CUDA_VERSION',s" tensorflow/core/platform/default/build_config.bzl + perl -pi -e "s,(GetCudaVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDA_VERSION\",s" tensorflow/stream_executor/dso_loader.cc + + # Configure the Cudnn version to work with. + perl -pi -e "s,CUDNN_VERSION = '[0-9\.]*',CUDNN_VERSION = '$TF_CUDNN_VERSION',s" tensorflow/core/platform/default/build_config.bzl + perl -pi -e "s,(GetCudnnVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDNN_VERSION\",s" tensorflow/stream_executor/dso_loader.cc # Configure the compute capabilities that TensorFlow builds for. # Since Cuda toolkit is not backward-compatible, this is not guaranteed to work. 
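For reference, the configure changes above consult three new environment variables and only prompt for the Cuda/Cudnn versions interactively when TF_UNOFFICIAL_SETTING=1, falling back to the defaults (Cuda 7.0, Cudnn 6.5) otherwise. A minimal sketch of an invocation exercising the new path — the version numbers and paths here are illustrative, not values the script endorses:

  TF_UNOFFICIAL_SETTING=1 \
  TF_CUDA_VERSION=7.5 \
  TF_CUDNN_VERSION=4 \
  CUDA_TOOLKIT_PATH=/usr/local/cuda \
  CUDNN_INSTALL_PATH=/usr/local/cuda \
  ./configure

The script validates the choices by probing for $CUDA_TOOLKIT_PATH/lib64/libcudart.so.$TF_CUDA_VERSION and for libcudnn.so.$TF_CUDNN_VERSION under $CUDNN_INSTALL_PATH (or its lib64/), writes all four values into third_party/gpus/cuda/cuda.config, and rewrites the version strings in build_config.bzl and dso_loader.cc via the perl substitutions in UnofficialSetting().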
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 95d5531876..cb4d5e723c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -298,6 +298,7 @@ tf_cuda_library( "graph/graph_constructor.h", "graph/graph_def_builder.h", "graph/node_builder.h", + "graph/validate.h", "public/session.h", "public/session_options.h", "public/tensor_c_api.h", diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc index 2140fc8afd..853c309091 100644 --- a/tensorflow/core/client/tensor_c_api.cc +++ b/tensorflow/core/client/tensor_c_api.cc @@ -316,16 +316,16 @@ Status LoadLibrary(const char* library_filename, void** result, } // namespace tensorflow -extern "C" { - -void TF_Run(TF_Session* s, - // Input tensors - const char** c_input_names, TF_Tensor** c_inputs, int ninputs, - // Output tensors - const char** c_output_tensor_names, TF_Tensor** c_outputs, - int noutputs, - // Target nodes - const char** c_target_node_names, int ntargets, TF_Status* status) { +void TF_Run_Helper(TF_Session* s, const char* handle, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, + int ninputs, + // Output tensors + const char** c_output_tensor_names, TF_Tensor** c_outputs, + int noutputs, + // Target nodes + const char** c_target_node_names, int ntargets, + TF_Status* status) { status->status = Status::OK(); for (int i = 0; i < noutputs; i++) { c_outputs[i] = NULL; @@ -365,8 +365,13 @@ void TF_Run(TF_Session* s, for (int i = 0; i < ntargets; i++) { target_node_names[i] = c_target_node_names[i]; } - Status result = - s->session->Run(inputs, output_tensor_names, target_node_names, &outputs); + Status result; + if (handle == nullptr) { + result = s->session->Run(inputs, output_tensor_names, target_node_names, + &outputs); + } else { + result = s->session->PRun(handle, inputs, output_tensor_names, &outputs); + } if (!result.ok()) { status->status = result; return; @@ -392,6 +397,69 @@ void TF_Run(TF_Session* s, } } +extern "C" { + +void TF_Run(TF_Session* s, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_tensor_names, TF_Tensor** c_outputs, + int noutputs, + // Target nodes + const char** c_target_node_names, int ntargets, TF_Status* status) { + TF_Run_Helper(s, nullptr, c_input_names, c_inputs, ninputs, + c_output_tensor_names, c_outputs, noutputs, c_target_node_names, + ntargets, status); +} + +void TF_PRunSetup(TF_Session* s, + // Input names + const char** c_input_names, int ninputs, + // Output names + const char** c_output_tensor_names, int noutputs, + // Target nodes + const char** c_target_node_names, int ntargets, char** handle, + TF_Status* status) { + status->status = Status::OK(); + + std::vector<tensorflow::string> input_names(ninputs); + std::vector<tensorflow::string> output_tensor_names(noutputs); + std::vector<tensorflow::string> target_node_names(ntargets); + for (int i = 0; i < ninputs; i++) { + input_names[i] = c_input_names[i]; + } + for (int i = 0; i < noutputs; i++) { + output_tensor_names[i] = c_output_tensor_names[i]; + } + for (int i = 0; i < ntargets; i++) { + target_node_names[i] = c_target_node_names[i]; + } + tensorflow::string new_handle; + Status result; + result = s->session->PRunSetup(input_names, output_tensor_names, + target_node_names, &new_handle); + if (result.ok()) { + *handle = new char[new_handle.size() + 1]; + memcpy(*handle, new_handle.c_str(), new_handle.size() + 1); + } else { + status->status = result; + } +} + +void 
TF_PRun(TF_Session* s, const char* handle, + // Input tensors + const char** c_input_names, TF_Tensor** c_inputs, int ninputs, + // Output tensors + const char** c_output_tensor_names, TF_Tensor** c_outputs, + int noutputs, + // Target nodes + const char** c_target_node_names, int ntargets, + TF_Status* status) { + TF_Run_Helper(s, handle, c_input_names, c_inputs, ninputs, + c_output_tensor_names, c_outputs, noutputs, c_target_node_names, + ntargets, status); +} + const void* TF_BufferData(TF_Buffer* buffer) { return buffer->data; } size_t TF_BufferLength(TF_Buffer* buffer) { return buffer->length; } diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index e0a5fef6d0..00f3f17d78 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -53,58 +53,57 @@ void CopyTensor::ViaDMA(const string& edge_name, StatusCallback done) { initialization_done = true; port::Tracing::ScopedAnnotation annotation(edge_name); - VLOG(1) << "CopyViaDMA " << edge_name; - const size_t total_bytes = input->TotalBytes(); - - // Note that 0-size tensors have no backing buffer. - if (total_bytes > 0) { - const DeviceType src_device_type(src_alloc_attr.on_host() - ? DEVICE_CPU - : src->attributes().device_type()); - const DeviceType dst_device_type(dst_alloc_attr.on_host() - ? DEVICE_CPU - : dst->attributes().device_type()); - const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU); - const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU); - - if (non_cpu_src) { - if (non_cpu_dst) { - // Device to device copy. Look through registry for an appropriate - // CopyFunction. - std::vector<RegistrationInfo>* registry = MutableRegistry(); - for (const RegistrationInfo& ri : *registry) { - if (ri.sender_device_type == src_device_type && - ri.receiver_device_type == dst_device_type) { - ri.copy_function(send_dev_context, recv_dev_context, src, dst, - src_alloc_attr, dst_alloc_attr, input, output, - done); - return; - } - } - - // TODO(josh11b): If no CopyFunction is found, we currently fail - // but we could copy between devices via CPU. - done(errors::Unimplemented( - "No function registered to copy from devices of type ", - src_device_type.type(), " to devices of type ", - dst_device_type.type())); - } else { - // Device to host copy. - return send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, - output, done); + VLOG(1) << "Copy " << edge_name; + + const DeviceType src_device_type( + src_alloc_attr.on_host() ? DEVICE_CPU : src->attributes().device_type()); + const DeviceType dst_device_type( + dst_alloc_attr.on_host() ? DEVICE_CPU : dst->attributes().device_type()); + const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU); + const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU); + + // E.g., gpu -> gpu + if (non_cpu_src && non_cpu_dst) { + // Device to device copy. Look through registry for an appropriate + // CopyFunction. + std::vector<RegistrationInfo>* registry = MutableRegistry(); + for (const RegistrationInfo& ri : *registry) { + if (ri.sender_device_type == src_device_type && + ri.receiver_device_type == dst_device_type) { + ri.copy_function(send_dev_context, recv_dev_context, src, dst, + src_alloc_attr, dst_alloc_attr, input, output, done); + return; } - } else if (non_cpu_dst) { - // Host to Device copy. - // Note that this is already an async copy. 
- recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done); - } else { - *output = *input; - done(Status::OK()); } - } else { - // buffer is empty - done(Status::OK()); + + // TODO(josh11b): If no CopyFunction is found, we currently fail + // but we could copy between devices via CPU. + done(errors::Unimplemented( + "No function registered to copy from devices of type ", + src_device_type.type(), " to devices of type ", + dst_device_type.type())); + return; + } + + // E.g., gpu -> cpu + if (non_cpu_src && !non_cpu_dst) { + // Device to host copy. + send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output, + done); + return; } + + // E.g., cpu -> gpu + if (!non_cpu_src && non_cpu_dst) { + // Host to Device copy. + recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done); + return; + } + + // cpu -> cpu + CHECK(!non_cpu_src && !non_cpu_dst); + *output = *input; + done(Status::OK()); } // static diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 199916b3c4..ae8f934c20 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/session_factory.h" #include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/framework/function.h" @@ -156,6 +155,9 @@ DirectSession::DirectSession(const SessionOptions& options, } DirectSession::~DirectSession() { + for (auto& it : partial_runs_) { + delete it.second; + } for (auto d : device_mgr_->ListDevices()) { d->op_segment()->RemoveHold(session_handle_); } @@ -240,7 +242,8 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) { return Status::OK(); } -Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, +// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()". +Status DirectSession::Run(const NamedTensorList& inputs, const std::vector<string>& output_names, const std::vector<string>& target_nodes, std::vector<Tensor>* outputs) { @@ -261,49 +264,213 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, // Check if we already have an executor for these arguments. ExecutorsAndKeys* executors_and_keys; - Status s = GetOrCreateExecutors(input_tensor_names, output_names, - target_nodes, &executors_and_keys); - if (!s.ok()) { - return s; + RunStateArgs run_state_args; + TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, + target_nodes, &executors_and_keys, + &run_state_args)); + + // Create a run state and start execution. + RunState run_state(input_tensor_names, output_names); + run_state.graph = run_state_args.graph; + run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); + + // Send inputs. + TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez)); + + // Start parallel Executors. 
+ const int num_executors = executors_and_keys->items.size(); + ExecutorBarrier* barrier = new ExecutorBarrier( + num_executors, run_state.rendez, [&run_state](const Status& ret) { + run_state.status = ret; + run_state.executors_done.Notify(); + }); + + Executor::Args args; + args.rendezvous = run_state.rendez; + args.cancellation_manager = cancellation_manager_; + args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); }; + + for (const auto& item : executors_and_keys->items) { + item.executor->RunAsync(args, barrier->Get()); } - IntraProcessRendezvous* rendez = - new IntraProcessRendezvous(device_mgr_.get()); - core::ScopedUnref rendez_unref(rendez); + run_state.executors_done.WaitForNotification(); + TF_RETURN_IF_ERROR(run_state.status); - // Insert the input tensors into the local rendezvous by their - // rendezvous key. - for (const auto& input : inputs) { - const string& input_key = executors_and_keys->input_keys[input.first]; - s = rendez->Send(input_key, Rendezvous::Args(), input.second, false); - if (!s.ok()) { - rendez->StartAbort(s); - return s; + // Receive outputs. + TF_RETURN_IF_ERROR( + RecvOutputs(output_names, executors_and_keys, &run_state, outputs)); + return Status::OK(); +} + +Status DirectSession::PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) { + { + mutex_lock l(graph_def_lock_); + if (!graph_created_) { + return errors::InvalidArgument( + "Session was not created with a graph before PRunSetup()!"); + } + } + + // Check if we already have an executor for these arguments. + ExecutorsAndKeys* executors_and_keys; + RunStateArgs run_state_args; + run_state_args.is_partial_run = true; + Status s = GetOrCreateExecutors(input_names, output_names, target_nodes, + &executors_and_keys, &run_state_args); + TF_RETURN_IF_ERROR(s); + + // Create the run state and save it for future PRun calls. + RunState* run_state = new RunState(input_names, output_names); + run_state->graph = run_state_args.graph; + run_state->rendez = new IntraProcessRendezvous(device_mgr_.get()); + { + mutex_lock l(executor_lock_); + if (!partial_runs_.insert({run_state_args.handle, run_state}).second) { + return errors::Internal("The handle ", run_state_args.handle, + " created for this partial" + " run is not unique."); } } // Start parallel Executors. 
- Notification executors_done; + Notification& executors_done = run_state->executors_done; + Status* run_status = &run_state->status; const int num_executors = executors_and_keys->items.size(); ExecutorBarrier* barrier = new ExecutorBarrier( - num_executors, rendez, [&executors_done, &s](const Status& ret) { - s = ret; + num_executors, run_state->rendez, + [&executors_done, run_status, this](const Status& ret) { + if (!ret.ok()) { + mutex_lock l(executor_lock_); + *run_status = ret; + } executors_done.Notify(); }); Executor::Args args; - args.rendezvous = rendez; + args.rendezvous = run_state->rendez; args.cancellation_manager = cancellation_manager_; args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); }; - for (const auto& item : executors_and_keys->items) { - item.executor->RunAsync(args, barrier->Get()); + for (auto& item : executors_and_keys->items) { + Executor* exec = item.executor; + exec->RunAsync(args, barrier->Get()); } - executors_done.WaitForNotification(); + *handle = run_state_args.handle; + return Status::OK(); +} - TF_RETURN_IF_ERROR(s); +Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) { + std::vector<string> parts = str_util::Split(handle, ';'); + const string& key = parts[0]; + // Get the executors for this partial run. + ExecutorsAndKeys* executors_and_keys; + RunState* run_state; + { + mutex_lock l(executor_lock_); // could use reader lock + auto exc_it = executors_.find(key); + if (exc_it == executors_.end()) { + return errors::InvalidArgument( + "Must run 'setup' before performing partial runs!"); + } + executors_and_keys = exc_it->second; + + auto prun_it = partial_runs_.find(handle); + if (prun_it == partial_runs_.end()) { + return errors::InvalidArgument( + "Must run 'setup' before performing partial runs!"); + } + run_state = prun_it->second; + + // Make sure that this is a new set of feeds that are still pending. + for (const auto& input : inputs) { + auto it = run_state->pending_inputs.find(input.first); + if (it == run_state->pending_inputs.end()) { + return errors::InvalidArgument("The feed ", input.first, + " had already been fed."); + } + } + // Check that this is a new set of fetches that are still pending. + for (const auto& output : output_names) { + auto it = run_state->pending_outputs.find(output); + if (it == run_state->pending_outputs.end()) { + return errors::InvalidArgument("The fetch ", output, + " had already been fetched."); + } + } + } + // Check that this new set of fetches can be computed from all the + // feeds we have supplied. + TF_RETURN_IF_ERROR(CheckFetch(inputs, output_names, run_state)); + + // Send inputs. + Status s = SendInputs(inputs, executors_and_keys, run_state->rendez); + + // Receive outputs. + if (s.ok()) { + s = RecvOutputs(output_names, executors_and_keys, run_state, outputs); + } + + // Delete the run state if there is an error or all fetches are done. + { + mutex_lock l(executor_lock_); + bool done = true; + if (s.ok()) { + if (!run_state->status.ok()) { + LOG(WARNING) << "An error unrelated to this prun has been detected. 
" + << run_state->status; + } + for (const auto& it : inputs) { + run_state->pending_inputs.erase(it.first); + } + for (const auto& name : output_names) { + run_state->pending_outputs.erase(name); + } + done = run_state->pending_outputs.size() == 0; + } + if (done) { + run_state->executors_done.WaitForNotification(); + partial_runs_.erase(handle); + delete run_state; + } + } + return s; +} + +Status DirectSession::SendInputs(const NamedTensorList& inputs, + const ExecutorsAndKeys* executors_and_keys, + IntraProcessRendezvous* rendez) { + Status s; + // Insert the input tensors into the local rendezvous by their + // rendezvous key. + for (const auto& input : inputs) { + auto it = executors_and_keys->input_keys.find(input.first); + if (it == executors_and_keys->input_keys.end()) { + return errors::InvalidArgument("'", input.first, + "' is not a pre-defined feed!"); + } + const string& input_key = it->second; + s = rendez->Send(input_key, Rendezvous::Args(), input.second, false); + if (!s.ok()) { + rendez->StartAbort(s); + return s; + } + } + return Status::OK(); +} + +Status DirectSession::RecvOutputs(const std::vector<string>& output_names, + const ExecutorsAndKeys* executors_and_keys, + RunState* run_state, + std::vector<Tensor>* outputs) { + Status s; if (!output_names.empty()) { outputs->resize(output_names.size()); } @@ -311,14 +478,21 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, // Get the outputs from the rendezvous for (size_t output_offset = 0; output_offset < output_names.size(); ++output_offset) { - const string& output_key = - executors_and_keys->output_keys[output_names[output_offset]]; + const string& output_name = output_names[output_offset]; + auto it = executors_and_keys->output_keys.find(output_name); + if (it == executors_and_keys->output_keys.end()) { + return errors::InvalidArgument("'", output_name, + "' was not defined as a fetch" + " target in PRunSetup."); + } + const string& output_key = it->second; Tensor output_tensor; bool is_dead; // Fetch data from the Rendezvous. + IntraProcessRendezvous* rendez = run_state->rendez; s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead); - if (is_dead) { + if (is_dead && s.ok()) { s = errors::InvalidArgument("The tensor returned for ", output_names[output_offset], " was not valid."); @@ -331,14 +505,74 @@ Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, (*outputs)[output_offset] = output_tensor; } + return Status::OK(); +} - return s; +Status DirectSession::CheckFetch(const NamedTensorList& feeds, + const std::vector<string>& fetches, + const RunState* run_state) { + const Graph* g = run_state->graph; + std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node; + for (Node* n : g->nodes()) { + name_to_node[n->name()] = n; + } + + // Build the set of pending feeds that we haven't seen. + std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; + { + mutex_lock l(executor_lock_); + for (const string& feed : run_state->pending_inputs) { + TensorId id(ParseTensorName(feed)); + auto it = name_to_node.find(id.first); + if (it == name_to_node.end()) { + return errors::NotFound("Feed ", feed, ": not found"); + } + pending_feeds.insert(id); + } + } + for (const auto& it : feeds) { + TensorId id(ParseTensorName(it.first)); + pending_feeds.erase(id); + } + + // Initialize the stack with the fecth nodes. 
+ std::vector<const Node*> stack; + for (const string& fetch : fetches) { + TensorId id(ParseTensorName(fetch)); + auto it = name_to_node.find(id.first); + if (it == name_to_node.end()) { + return errors::NotFound("Fetch ", fetch, ": not found"); + } + stack.push_back(it->second); + } + + // Any tensor needed for fetches can't be in pending_feeds. + std::vector<bool> visited(g->num_node_ids(), false); + while (!stack.empty()) { + const Node* n = stack.back(); + stack.pop_back(); + + for (const Edge* in_edge : n->in_edges()) { + const Node* in_node = in_edge->src(); + if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) { + return errors::InvalidArgument("Fetch ", in_node->name(), ":", + in_edge->src_output(), + " can't be computed from the feeds" + " that have been fed so far."); + } + if (!visited[in_node->id()]) { + visited[in_node->id()] = true; + stack.push_back(in_node); + } + } + } + return Status::OK(); } Status DirectSession::GetOrCreateExecutors( gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, - gtl::ArraySlice<string> target_nodes, - ExecutorsAndKeys** executors_and_keys) { + gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys, + RunStateArgs* run_state_args) { // Sort the inputs and outputs, so we don't create separate // executors when a user passes in the same inputs/outputs in // different orders. @@ -370,10 +604,9 @@ Status DirectSession::GetOrCreateExecutors( // being created. FunctionLibraryDefinition* fdefs; std::unordered_map<string, Graph*> graphs; - Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs); - if (!s.ok()) { - return s; - } + Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs, + run_state_args); + TF_RETURN_IF_ERROR(s); std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); ek->func_defs = fdefs; @@ -386,9 +619,7 @@ Status DirectSession::GetOrCreateExecutors( Device* device; s = device_mgr_->LookupDevice(partition_name, &device); - if (!s.ok()) { - return s; - } + TF_RETURN_IF_ERROR(s); ek->items.resize(ek->items.size() + 1); auto* item = &(ek->items.back()); @@ -434,6 +665,12 @@ Status DirectSession::GetOrCreateExecutors( output, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); } + // Set the handle. + { + mutex_lock l(mu_); + run_state_args->handle = strings::StrCat(key, ";", name_counter_++); + } + // Reacquire the lock, try to insert into the map. mutex_lock l(executor_lock_); const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second; @@ -470,10 +707,12 @@ void DirectSession::RestoreStatefulNodes(Graph* graph) { } } -Status DirectSession::CreateGraphs( - gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches, - gtl::ArraySlice<string> target_nodes, FunctionLibraryDefinition** func_defs, - std::unordered_map<string, Graph*>* outputs) { +Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds, + gtl::ArraySlice<string> fetches, + gtl::ArraySlice<string> target_nodes, + FunctionLibraryDefinition** func_defs, + std::unordered_map<string, Graph*>* outputs, + RunStateArgs* run_state_args) { std::unique_ptr<FunctionLibraryDefinition> fdefs; std::unique_ptr<Graph> graph; GraphConstructorOptions opts{ @@ -511,6 +750,12 @@ Status DirectSession::CreateGraphs( TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, graph.get())); } + // Remember the graph in run state if this is a partial run. 
+ if (run_state_args->is_partial_run) { + run_state_args->graph = new Graph(fdefs.get()); + CopyGraph(*graph.get(), run_state_args->graph); + } + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( graph.get(), feeds, fetches, target_nodes, device_set_.client_device()->attributes())); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 178dfe8851..0a7430ff04 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -19,11 +19,13 @@ limitations under the License. #include <memory> #include <string> #include <unordered_map> +#include <unordered_set> #include <vector> #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -46,12 +48,24 @@ class DirectSession : public Session { DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr); ~DirectSession() override; + typedef std::vector<std::pair<string, Tensor>> NamedTensorList; + ::tensorflow::Status Create(const GraphDef& graph) override; ::tensorflow::Status Extend(const GraphDef& graph) override; - ::tensorflow::Status Run(const std::vector<std::pair<string, Tensor>>& inputs, + ::tensorflow::Status Run(const NamedTensorList& inputs, const std::vector<string>& output_names, const std::vector<string>& target_nodes, std::vector<Tensor>* outputs) override; + + // NOTE: PRunSetup and PRun are added to support partial execution. This + // feature is experimental and subject to change. + ::tensorflow::Status PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) override; + ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) override; ::tensorflow::Status Close() override; private: @@ -85,24 +99,86 @@ class DirectSession : public Session { } }; + // For each live partial execution, the session maintains a RunState. + // 'status' is the current status of this partial execution. 'executor_done' + // is "notified" when all executors are done. 'graph' is the graph being + // executed. 'pending_inputs' are the set of pending feeds and + // 'pending_outputs' are the set of pending fetches. + struct RunState { + Status status; + IntraProcessRendezvous* rendez = nullptr; + Notification executors_done; + Graph* graph = nullptr; + std::unordered_set<string> pending_inputs; + std::unordered_set<string> pending_outputs; + + RunState(const std::vector<string>& input_names, + const std::vector<string>& output_names) { + // Initially all the feeds and fetches are pending. 
+ for (auto& name : input_names) { + pending_inputs.emplace(name); + } + for (auto& name : output_names) { + pending_outputs.emplace(name); + } + } + + ~RunState() { + if (rendez != nullptr) { + if (!executors_done.HasBeenNotified()) { + rendez->StartAbort(errors::Cancelled("PRun cancellation")); + executors_done.WaitForNotification(); + } + rendez->Unref(); + } + delete graph; + } + }; + + struct RunStateArgs { + bool is_partial_run = false; + string handle; + Graph* graph = nullptr; + }; + // Retrieves an already existing set of executors to run 'inputs' and // 'outputs', or creates and caches them for future use. ::tensorflow::Status GetOrCreateExecutors( gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes, - ExecutorsAndKeys** executors_and_keys); + ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args); // Creates several graphs given the existing graph_def_ and the // input feeds and fetches, given 'devices'. - ::tensorflow::Status CreateGraphs( - gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches, - gtl::ArraySlice<string> target_nodes, - FunctionLibraryDefinition** func_defs, - std::unordered_map<string, Graph*>* outputs); + ::tensorflow::Status CreateGraphs(gtl::ArraySlice<string> feeds, + gtl::ArraySlice<string> fetches, + gtl::ArraySlice<string> target_nodes, + FunctionLibraryDefinition** func_defs, + std::unordered_map<string, Graph*>* outputs, + RunStateArgs* run_state_args); ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); + // Feeds more inputs to the executors, triggering further execution. + ::tensorflow::Status SendInputs( + const std::vector<std::pair<string, Tensor>>& inputs, + const ExecutorsAndKeys* executors_and_keys, + IntraProcessRendezvous* rendez); + + // Fetches more outputs from the executors. It waits until the output + // tensors are computed. + ::tensorflow::Status RecvOutputs(const std::vector<string>& output_names, + const ExecutorsAndKeys* executors_and_keys, + RunState* run_state, + std::vector<Tensor>* outputs); + + // Check if the specified fetches can be computed from the feeds + // that we have already provided. + ::tensorflow::Status CheckFetch( + const std::vector<std::pair<string, Tensor>>& feeds, + const std::vector<string>& fetches, const RunState* run_state); + const SessionOptions options_; // Device structures. @@ -129,6 +205,10 @@ class DirectSession : public Session { std::unordered_map<string, ExecutorsAndKeys*> executors_ GUARDED_BY(executor_lock_); + // Holds mappings from handle to partial run state. + std::unordered_map<string, RunState*> partial_runs_ + GUARDED_BY(executor_lock_); + CancellationManager* cancellation_manager_; // Saves and restores device placements for stateful nodes. 
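Taken together, the header above splits Session::Run() into a one-time PRunSetup() followed by any number of incremental PRun() steps sharing one RunState. A minimal sketch of the intended calling pattern — not part of the patch; it assumes a session whose graph defines float nodes "a", "b", and "sum" = a + b, mirroring the PartialRunTest that follows:

  // Declare up front every feed and fetch that later PRun() calls may use.
  string handle;
  Status s = session->PRunSetup({"a", "b"}, {"sum:0"}, {}, &handle);
  CHECK(s.ok());

  Tensor a_val(DT_FLOAT, TensorShape({}));
  a_val.scalar<float>()() = 1.0;
  Tensor b_val(DT_FLOAT, TensorShape({}));
  b_val.scalar<float>()() = 2.0;

  std::vector<Tensor> outputs;
  // Feed "a" alone. Fetching "sum:0" at this point would be rejected by
  // CheckFetch(), since it still depends on the pending feed "b".
  s = session->PRun(handle, {{"a", a_val}}, {}, &outputs);
  CHECK(s.ok());

  // Feed "b" and fetch the now-computable "sum:0". Re-feeding "a" here
  // would fail with "had already been fed". Once the last declared fetch
  // is consumed, the per-handle RunState is deleted, so the handle must
  // not be reused.
  s = session->PRun(handle, {{"b", b_val}}, {"sum:0"}, &outputs);
  CHECK(s.ok());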
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 9109da76dd..68009428f8 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -298,8 +298,6 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) { TEST(DirectSessionTest, MultipleFeedTest) { GraphDef def; Graph g(OpRegistry::Global()); - Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10})); - var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); Tensor first_value(DT_FLOAT, TensorShape({})); first_value.scalar<float>()() = 1.0; @@ -396,5 +394,138 @@ TEST(DirectSessionTest, DarthKernel) { delete sess; } +TEST(DirectSessionTest, PartialRunTest) { + GraphDef def; + Graph g(OpRegistry::Global()); + + Tensor first_value(DT_FLOAT, TensorShape({})); + first_value.scalar<float>()() = 1.0; + Node* first_const = test::graph::Constant(&g, first_value); + Node* first_identity = test::graph::Identity(&g, first_const); + + Tensor second_value(DT_FLOAT, TensorShape({})); + second_value.scalar<float>()() = 2.0; + Node* second_const = test::graph::Constant(&g, second_value); + Node* second_identity = test::graph::Identity(&g, second_const); + + Node* third = test::graph::Add(&g, first_identity, second_identity); + Node* third_identity = test::graph::Identity(&g, third); + + test::graph::ToGraphDef(&g, &def); + + std::unique_ptr<Session> session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def)); + + std::vector<Tensor> outputs; + + string handle; + Status s = session->PRunSetup( + {first_const->name(), second_const->name()}, + {first_identity->name() + ":0", second_identity->name() + ":0", + third_identity->name() + ":0"}, + {}, &handle); + ASSERT_TRUE(s.ok()); + + Tensor value_11(DT_FLOAT, TensorShape({})); + value_11.scalar<float>()() = 11.0; + Tensor value_22(DT_FLOAT, TensorShape({})); + value_22.scalar<float>()() = 22.0; + + // Feed first_const, fetch first_identity + s = session->PRun(handle, {{first_const->name(), value_11}}, + {first_identity->name() + ":0"}, &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(1, outputs.size()); + ASSERT_EQ(11.0, outputs[0].flat<float>()(0)); + + // Feed second_const, fetch second_identity and third_identity + s = session->PRun( + handle, {{second_const->name(), value_22}}, + {second_identity->name() + ":0", third_identity->name() + ":0"}, + &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(2, outputs.size()); + ASSERT_EQ(22.0, outputs[0].flat<float>()(0)); + ASSERT_EQ(11.0 + 22.0, outputs[1].flat<float>()(0)); +} + +TEST(DirectSessionTest, PartialRunMissingFeed) { + GraphDef def; + Graph g(OpRegistry::Global()); + + Tensor first_value(DT_FLOAT, TensorShape({})); + first_value.scalar<float>()() = 1.0; + Node* first_const = test::graph::Constant(&g, first_value); + Node* first_identity = test::graph::Identity(&g, first_const); + + Tensor second_value(DT_FLOAT, TensorShape({})); + second_value.scalar<float>()() = 2.0; + Node* second_const = test::graph::Constant(&g, second_value); + Node* second_identity = test::graph::Identity(&g, second_const); + + Node* third = test::graph::Add(&g, first_identity, second_identity); + Node* third_identity = test::graph::Identity(&g, third); + + test::graph::ToGraphDef(&g, &def); + + std::unique_ptr<Session> session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def)); + + std::vector<Tensor> outputs; + + string handle; + Status s = 
session->PRunSetup({first_const->name(), second_const->name()}, + {third_identity->name() + ":0"}, {}, &handle); + ASSERT_TRUE(s.ok()); + + // Feed first_const, fetch third_identity + Tensor value_11(DT_FLOAT, TensorShape({})); + value_11.scalar<float>()() = 11.0; + s = session->PRun(handle, {{first_const->name(), value_11}}, + {third_identity->name() + ":0"}, &outputs); + ASSERT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("can't be computed from the feeds")); +} + +TEST(DirectSessionTest, PartialRunMultiOutputFeed) { + GraphDef def; + Graph g(OpRegistry::Global()); + + Tensor bool_value(DT_BOOL, TensorShape({})); + bool_value.scalar<bool>()() = true; + Node* bool_const = test::graph::Constant(&g, bool_value); + Node* switch_node = test::graph::Switch(&g, bool_const, bool_const); + Node* fourth_identity = test::graph::Identity(&g, switch_node, 1); + + test::graph::ToGraphDef(&g, &def); + + std::unique_ptr<Session> session(CreateSession()); + ASSERT_TRUE(session != nullptr); + ASSERT_OK(session->Create(def)); + + std::vector<Tensor> outputs; + + string handle; + Status s = session->PRunSetup({switch_node->name() + ":1"}, + {fourth_identity->name() + ":0"}, {}, &handle); + ASSERT_TRUE(s.ok()); + + // Fetch fourth_identity without feeds. + s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs); + ASSERT_TRUE(errors::IsInvalidArgument(s)); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("can't be computed from the feeds")); + + // Feed switch_node:1 and fetch fourth_identity. + s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}}, + {fourth_identity->name() + ":0"}, &outputs); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(1, outputs.size()); + ASSERT_EQ(true, outputs[0].flat<bool>()(0)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index 2accf92503..f34ac256d1 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -37,6 +37,18 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/util/util.h" +// IMPLEMENTATION NOTE: +// +// 1. Within this module, we intentionally LOG(FATAL) if any stream +// involved in memcpy becomes !stream->ok(), because TF process +// today (1/2016) can not properly recover from such an error. +// +// 2. When 0-size tensor is being copied, we should not schedule a +// copy ThenMemcpy since there is no byte to move. However, we must +// ensure the causal ordering by arranging the copy done callback +// happens-after all activities scheduled on the given stream being +// finished. + // If this need to be runtime configurable, consider adding options to // ConfigProto. 
const tensorflow::int64 FLAGS_brain_gpu_util_debug_string_maxlen = 128; @@ -50,60 +62,106 @@ namespace tensorflow { namespace gpu = ::perftools::gputools; +Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src, + const Tensor* dst, + const DeviceBase::GpuDeviceInfo** dev_info, + gpu::Stream** stream) { + if (device == nullptr) { + return errors::Internal("Unexpected null device."); + } + auto di = device->tensorflow_gpu_device_info(); + if (di == nullptr) { + return errors::Internal("Unexpected null device info."); + } + *dev_info = di; + if (ctx == nullptr) { + return errors::Internal("Unexpected null device context."); + } + auto gs = static_cast<const GPUDeviceContext*>(ctx)->stream(); + if (gs == nullptr) { + return errors::Internal("No gpu stream is available."); + } + *stream = gs; + if (dst != nullptr) { + if (src.dtype() != dst->dtype()) { + return errors::Internal("Can't copy a tensor of ", + DataTypeString(src.dtype()), " into a tensor of ", + DataTypeString(dst->dtype())); + } + if (src.TotalBytes() != dst->TotalBytes()) { + return errors::Internal("Can't copy ", src.TotalBytes(), + " bytes of a tensor into another with ", + dst->TotalBytes(), " bytes buffer."); + } + if ((src.TotalBytes() > 0) && !src.IsInitialized()) { + return errors::Internal("Src tensor is not initialized."); + } + if ((dst->TotalBytes() > 0) && !dst->IsInitialized()) { + return errors::Internal("Dst tensor is not initialized."); + } + } + if (!DMAHelper::CanUseDMA(&src)) { + return errors::Internal("GPU copy from non-DMA ", + DataTypeString(src.dtype()), "tensor"); + } + return Status::OK(); +} + +void* GetBase(const Tensor* src) { + return const_cast<void*>(DMAHelper::base(src)); +} + +void* GetBase(Tensor* dst) { return DMAHelper::base(dst); } + /*static*/ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev, const DeviceContext* device_context, TensorProto* proto, bool is_dead, StatusCallback done) { VLOG(1) << "SetProtoFromGPU device_context " << device_context; + const DeviceBase::GpuDeviceInfo* dev_info = nullptr; + gpu::Stream* stream = nullptr; + Status s = + PrepareCopy(dev, device_context, tensor, nullptr, &dev_info, &stream); + if (!s.ok()) { + done(s); + return; + } + // Tensor values need to be copied from GPU to CPU ram so that // we can build the protobuf response for a RecvTensor RPC. // "device context" identifies the stream where the _Send op executed. - CHECK(device_context); - gpu::Stream* stream = - static_cast<const GPUDeviceContext*>(device_context)->stream(); - - if (!DMAHelper::CanUseDMA(&tensor)) { - done(errors::Internal(strings::StrCat( - "GPU copy from non-DMA ", DataTypeString(tensor.dtype()), "tensor"))); - return; - } proto->set_dtype(tensor.dtype()); tensor.shape().AsProto(proto->mutable_tensor_shape()); - // Prepare a Cord with the right data buf size, and DMA the - // data over from the GPU buffer. Note that 0-size tensors - // do not have a backing buffer. - const size_t num_bytes = is_dead ? 0 : tensor.TotalBytes(); - if (num_bytes > 0) { + + // Prepare a proto with the right data buf size, and DMA the data + // over from the GPU buffer. Note that 0-size tensors do not have a + // backing buffer. + Allocator* alloc = nullptr; + char* buf = nullptr; + const int64 total_bytes = is_dead ? 
0 : tensor.TotalBytes(); + if (total_bytes > 0) { port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU"); - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); - char* mb = alloc->Allocate<char>(num_bytes); - const char* src_ptr = - reinterpret_cast<const char*>(DMAHelper::base(&tensor)); - DeviceMemoryBase gpu_src_ptr(const_cast<char*>(src_ptr), num_bytes); - stream->ThenMemcpy(mb, gpu_src_ptr, num_bytes); - // Use of tensor may outlive stack scope, so keep a ref. - TensorReference tensor_ref(tensor); - dev->tensorflow_gpu_device_info()->event_mgr->ThenExecute( - stream, [stream, done, proto, mb, num_bytes, alloc, tensor_ref]() { - if (!stream->ok()) { - done(errors::Internal("SetProtoFromGPU: GPU Memcpy failed")); - // TODO(pbar) We currently have no way to recover the - // worker from a GPU stream in the error state. Until - // there is a way to reset the CUDA driver, it is - // preferable to crash the process and restart. Tracked - // under b/23717097 - LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed"; - return; - } - tensor_ref.Unref(); - port::CopyFromArray(proto->mutable_tensor_content(), mb, num_bytes); - alloc->Deallocate<char>(mb, num_bytes); - done(Status::OK()); - }); - } else { - done(Status::OK()); + alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + buf = alloc->Allocate<char>(total_bytes); + void* src_ptr = GetBase(&tensor); + DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes); + stream->ThenMemcpy(buf, gpu_src_ptr, total_bytes); } + // Use of tensor may outlive stack scope, so keep a ref. + TensorReference tensor_ref(tensor); + dev_info->event_mgr->ThenExecute(stream, [stream, done, proto, buf, + total_bytes, alloc, tensor_ref]() { + if (!stream->ok()) { + LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed"; + } + tensor_ref.Unref(); + if (total_bytes > 0) { + port::CopyFromArray(proto->mutable_tensor_content(), buf, total_bytes); + alloc->Deallocate<char>(buf, total_bytes); + } + done(Status::OK()); + }); } // static @@ -114,67 +172,67 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context, AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output, StatusCallback done) { - const void* src_ptr = DMAHelper::base(input); - void* dst_ptr = DMAHelper::base(output); - VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr; - const size_t total_bytes = input->TotalBytes(); - - gpu::Stream* stream = send_dev_context->stream(); - if (stream == nullptr) { - done(errors::Internal("Failed to find device stream")); + const DeviceBase::GpuDeviceInfo* dev_info = nullptr; + gpu::Stream* stream = nullptr; + Status s = + PrepareCopy(src, send_dev_context, *input, output, &dev_info, &stream); + if (!s.ok()) { + done(s); return; } - auto* src_dev_info = src->tensorflow_gpu_device_info(); - CHECK(src_dev_info); - DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - stream->ThenMemcpy(&gpu_dst_ptr, - DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes}, - total_bytes); - if (dst->attributes().device_type() == DeviceType(DEVICE_GPU).type()) { - // Use of input may outlive stack scope, so keep a ref. 
- TensorReference input_ref(*input); - src_dev_info->event_mgr->ThenExecute(stream, [done, stream, input_ref]() { - input_ref.Unref(); - if (!stream->ok()) { - done(errors::Internal("GPU->GPU Memcpy failed")); - } else { - done(Status::OK()); - } - }); + const int64 total_bytes = input->TotalBytes(); + if (total_bytes > 0) { + void* src_ptr = GetBase(input); + DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes); + void* dst_ptr = GetBase(output); + DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); + VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr; + stream->ThenMemcpy(&gpu_dst_ptr, gpu_src_ptr, total_bytes); } + + // Use of input may outlive stack scope, so keep a ref. + TensorReference input_ref(*input); + dev_info->event_mgr->ThenExecute(stream, [done, stream, input_ref]() { + input_ref.Unref(); + if (!stream->ok()) { + LOG(FATAL) << "GPU->GPU Memcpy failed"; + } + done(Status::OK()); + }); send_dev_context->MaintainLifetimeOnStream(input, stream); } static CopyTensor::Registration register_gpu_gpu_copy( DEVICE_GPU, DEVICE_GPU, GPUUtil::DeviceToDeviceCopy); +// static void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, const DeviceContext* device_context, const Tensor* gpu_tensor, Tensor* cpu_tensor, StatusCallback done) { VLOG(1) << "CopyGPUTensorToCPU"; - size_t total_bytes = gpu_tensor->TotalBytes(); - // Note that 0-size tensors have no backing buffer. + const DeviceBase::GpuDeviceInfo* dev_info = nullptr; + gpu::Stream* stream = nullptr; + Status s = PrepareCopy(gpu_device, device_context, *gpu_tensor, cpu_tensor, + &dev_info, &stream); + if (!s.ok()) { + done(s); + return; + } + const int64 total_bytes = gpu_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = DMAHelper::base(gpu_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); - CHECK(dst_ptr); - auto* stream = gpu_device->tensorflow_gpu_device_info()->stream; - if (device_context) { - stream = static_cast<const GPUDeviceContext*>(device_context)->stream(); - } - stream->ThenMemcpy( - dst_ptr, DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes}, - total_bytes); - stream->BlockHostUntilDone(); + void* src_ptr = GetBase(gpu_tensor); + DeviceMemoryBase gpu_src_ptr(src_ptr, total_bytes); + void* dst_ptr = GetBase(cpu_tensor); + stream->ThenMemcpy(dst_ptr, gpu_src_ptr, total_bytes); + } + dev_info->event_mgr->ThenExecute(stream, [stream, done]() { if (!stream->ok()) { - done(errors::Internal("CopyGPUTensorToCPU: GPU->CPU Memcpy failed")); - return; + LOG(FATAL) << "GPU->CPU Memcpy failed"; } - } - - done(Status::OK()); + done(Status::OK()); + }); } /* static */ @@ -183,47 +241,31 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, Device* gpu_device, Tensor* gpu_tensor, StatusCallback done) { VLOG(1) << "CopyCPUTensorToGPU"; - CHECK(DeviceType(gpu_device->attributes().device_type()) == - DeviceType(DEVICE_GPU)); - - auto* dev_info = gpu_device->tensorflow_gpu_device_info(); - if (!dev_info) { - done(errors::Internal("Failed to find dest device GPUDeviceInfo")); - return; - } - if (cpu_tensor->TotalBytes() != gpu_tensor->TotalBytes()) { - done(errors::Internal( - strings::StrCat("Can't copy ", cpu_tensor->TotalBytes(), - " bytes of a tensor into another with ", - gpu_tensor->TotalBytes(), " bytes buffer."))); + const DeviceBase::GpuDeviceInfo* dev_info = nullptr; + gpu::Stream* stream = nullptr; + Status s = PrepareCopy(gpu_device, device_context, *cpu_tensor, gpu_tensor, + &dev_info, &stream); + if (!s.ok()) { + done(s); return; } const int64 total_bytes = cpu_tensor->TotalBytes(); // 
Note that 0-size tensors have no backing buffer. if (total_bytes > 0) { - const void* src_ptr = DMAHelper::base(cpu_tensor); - void* dst_ptr = DMAHelper::base(gpu_tensor); + void* src_ptr = GetBase(cpu_tensor); + void* dst_ptr = GetBase(gpu_tensor); DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - - CHECK(device_context); - auto* stream = - static_cast<const GPUDeviceContext*>(device_context)->stream(); stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); - auto* dev_info = gpu_device->tensorflow_gpu_device_info(); - // Use of cpu_tensor may outlive stack scope, so keep a ref. - TensorReference input_ref(*cpu_tensor); - dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() { - input_ref.Unref(); - if (!stream->ok()) { - done(errors::Internal("CopyCPUTensorToGPU: GPU Memcpy failed")); - } else { - done(Status::OK()); - } - }); - } else { - // empty tensor case - done(Status::OK()); } + // Use of cpu_tensor may outlive stack scope, so keep a ref. + TensorReference input_ref(*cpu_tensor); + dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() { + input_ref.Unref(); + if (!stream->ok()) { + LOG(FATAL) << "CPU->GPU Memcpy failed"; + } + done(Status::OK()); + }); } Status GPUUtil::Sync(Device* gpu_device) { @@ -257,7 +299,7 @@ string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) { CHECK(tensor); const int64 num_bytes = std::min<int64>( FLAGS_brain_gpu_util_debug_string_maxlen, tensor->TotalBytes()); - void* ptr = (num_bytes > 0) ? DMAHelper::base(tensor) : nullptr; + void* ptr = (num_bytes > 0) ? GetBase(tensor) : nullptr; strings::Appendf(&ret, "%p:", ptr); if (num_bytes > 0) { auto* dev_info = device->tensorflow_gpu_device_info(); @@ -295,14 +337,14 @@ uint64 GPUUtil::Checksum(Device* gpu_device, } uint64 GPUUtil::Checksum(const Tensor& tensor) { - const float* fptr = reinterpret_cast<const float*>(DMAHelper::base(&tensor)); + const float* fptr = reinterpret_cast<const float*>(GetBase(&tensor)); size_t num_bytes = tensor.TotalBytes(); size_t num_floats = num_bytes / sizeof(float); for (size_t i = 0; i < num_floats; ++i) { CHECK(!std::isnan(fptr[i])) << " i " << i; } // TODO(tucker): consider using crc32c instead. 
- return Hash64(reinterpret_cast<const char*>(DMAHelper::base(&tensor)), + return Hash64(reinterpret_cast<const char*>(GetBase(&tensor)), tensor.TotalBytes(), 0); } diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index ac1e0554dd..b94207e2e8 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -553,4 +553,25 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, return Status::OK(); } +void RemoveDescriptionsFromOpDef(OpDef* op_def) { + for (int i = 0; i < op_def->input_arg_size(); ++i) { + op_def->mutable_input_arg(i)->clear_description(); + } + for (int i = 0; i < op_def->output_arg_size(); ++i) { + op_def->mutable_output_arg(i)->clear_description(); + } + for (int i = 0; i < op_def->attr_size(); ++i) { + op_def->mutable_attr(i)->clear_description(); + } + op_def->clear_summary(); + op_def->clear_description(); +} + +void RemoveDescriptionsFromOpList(OpList* op_list) { + for (int i = 0; i < op_list->op_size(); ++i) { + OpDef* op_def = op_list->mutable_op(i); + RemoveDescriptionsFromOpDef(op_def); + } +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index 34dd4f374b..de350d9aaa 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -54,6 +54,10 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, const OpDef& penultimate_op, const OpDef& new_op); +// Remove all docs from *op_def / *op_list. +void RemoveDescriptionsFromOpDef(OpDef* op_def); +void RemoveDescriptionsFromOpList(OpList* op_list); + } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h index 9718e252a5..25391009f9 100644 --- a/tensorflow/core/graph/tensor_id.h +++ b/tensorflow/core/graph/tensor_id.h @@ -19,6 +19,7 @@ limitations under the License. #include <string> #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -33,6 +34,13 @@ struct TensorId : public std::pair<StringPiece, int> { using Base::pair; string ToString() const { return strings::StrCat(first, ":", second); } + + struct Hasher { + public: + std::size_t operator()(const TensorId& x) const { + return Hash32(x.first.data(), x.first.size(), x.second); + } + }; }; TensorId ParseTensorName(const string& name); diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc index 465ca12098..faf1ea89e0 100644 --- a/tensorflow/core/graph/validate.cc +++ b/tensorflow/core/graph/validate.cc @@ -15,8 +15,13 @@ limitations under the License. #include "tensorflow/core/graph/validate.h" +#include <unordered_map> + +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace graph { @@ -34,5 +39,47 @@ Status ValidateGraphDef(const GraphDef& graph_def, return s; } +namespace { + +class OpListOpRegistry : public OpRegistryInterface { + public: + // Does not take ownership of op_list, *op_list must outlive *this. 
+ OpListOpRegistry(const OpList* op_list) { + for (const OpDef& op_def : op_list->op()) { + index_[op_def.name()] = &op_def; + } + } + ~OpListOpRegistry() override {} + + const OpDef* LookUp(const string& op_type_name, + Status* status) const override { + auto iter = index_.find(op_type_name); + if (iter == index_.end()) { + status->Update( + errors::NotFound("Op type not registered '", op_type_name, "'")); + return nullptr; + } + return iter->second; + } + + private: + std::unordered_map<string, const OpDef*> index_; +}; + +} // namespace + +Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, + const OpList& op_list) { + OpListOpRegistry registry(&op_list); + GraphDef copy(graph_def); + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©, ®istry, 0)); + return ValidateGraphDef(copy, ®istry); +} + +void GetOpListForValidation(OpList* op_list, const OpRegistry* op_registry) { + op_registry->Export(false, op_list); + RemoveDescriptionsFromOpList(op_list); +} + } // namespace graph } // namespace tensorflow diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h index af91f76511..b98f2e0f01 100644 --- a/tensorflow/core/graph/validate.h +++ b/tensorflow/core/graph/validate.h @@ -23,15 +23,29 @@ limitations under the License. namespace tensorflow { namespace graph { -// Returns OK if 'graph_def' has the following properties: +// Returns OK if every NodeDef in `graph_def` is valid with respect to +// its corresponding OpDef (as defined by ValidateNodeDef()) as +// registered in `op_registry`. // -// 1) Every NodeDef is valid with respect to its corresponding OpDef -// as registered in 'op_registry'. -// -// REQUIRES: 'op_registry' is not nullptr. +// REQUIRES: +// * `op_registry` is not nullptr. +// * `graph_def` has default attrs filled in (see AddDefaultAttrsToGraphDef()). Status ValidateGraphDef(const GraphDef& graph_def, const OpRegistryInterface* op_registry); +// Like ValidateGraphDef() except: +// * Takes an OpList instead of an OpRegistryInterface. +// Note that the OpList need not have descriptions, which can be a big +// space savings, see GetOpListForValidation() below. +// * Makes a copy of `graph_def` and calls AddDefaultAttrsToGraphDef() +// on the copy. +Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, + const OpList& op_list); + +// Get an OpList from `*op_registry` with all the descriptions removed. +void GetOpListForValidation( + OpList* op_list, const OpRegistry* op_registry = OpRegistry::Global()); + } // namespace graph } // namespace tensorflow diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc index c1f75f9c47..17e56d7443 100644 --- a/tensorflow/core/graph/validate_test.cc +++ b/tensorflow/core/graph/validate_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/subgraph.h" @@ -34,9 +35,9 @@ REGISTER_OP("FloatInput").Output("o: float"); REGISTER_OP("Int32Input").Output("o: int32"); TEST(ValidateGraphDefTest, TestValidGraph) { - string graph_def_str = - "node { name: 'A' op: 'FloatInput'}" - "node { name: 'B' op: 'FloatInput'}" + const string graph_def_str = + "node { name: 'A' op: 'FloatInput' }" + "node { name: 'B' op: 'FloatInput' }" "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"; GraphDef graph_def; @@ -46,9 +47,9 @@ TEST(ValidateGraphDefTest, TestValidGraph) { } TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) { - string graph_def_str = - "node { name: 'A' op: 'FloatInput'}" - "node { name: 'B' op: 'Int32Input'}" + const string graph_def_str = + "node { name: 'A' op: 'FloatInput' }" + "node { name: 'B' op: 'Int32Input' }" "node { " " name: 'C' op: 'Sum' " " attr { key: 'T' value { type: DT_FLOAT } }" @@ -70,8 +71,8 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) { TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) { // "DstT" attribute is missing. - string graph_def_str = - "node { name: 'A' op: 'FloatInput'}" + const string graph_def_str = + "node { name: 'A' op: 'FloatInput' }" "node { " " name: 'B' op: 'Cast' " " attr { key: 'SrcT' value { type: DT_FLOAT } }" @@ -93,5 +94,53 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) { EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr")); } +TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) { + OpList op_list; + TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(op_list.add_op())); + const string graph_def_str = "node { name: 'A' op: 'UniqueSnowflake' }"; + GraphDef graph_def; + auto parser = protobuf::TextFormat::Parser(); + CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str; + TF_ASSERT_OK(graph::ValidateGraphDefAgainstOpList(graph_def, op_list)); +} + +TEST(ValidateGraphDefAgainstOpListTest, GraphWithGlobalOpNotInOpList) { + OpList op_list; + TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(op_list.add_op())); + const string graph_def_str = "node { name: 'A' op: 'FloatInput' }"; + GraphDef graph_def; + auto parser = protobuf::TextFormat::Parser(); + CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str; + ASSERT_FALSE(graph::ValidateGraphDefAgainstOpList(graph_def, op_list).ok()); +} + +REGISTER_OP("HasDocs").Doc("This is in the summary."); + +TEST(GetOpListForValidationTest, ShouldStripDocs) { + bool found_float = false; + bool found_int32 = false; + bool found_has_docs = false; + OpList op_list; + graph::GetOpListForValidation(&op_list); + for (const OpDef& op_def : op_list.op()) { + if (op_def.name() == "FloatInput") { + EXPECT_FALSE(found_float); + found_float = true; + } + if (op_def.name() == "Int32Input") { + EXPECT_FALSE(found_int32); + found_int32 = true; + } + if (op_def.name() == "HasDocs") { + EXPECT_FALSE(found_has_docs); + found_has_docs = true; + EXPECT_TRUE(op_def.summary().empty()); + } + } + EXPECT_TRUE(found_float); + EXPECT_TRUE(found_int32); + EXPECT_TRUE(found_has_docs); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc 
index b752a7ed6e..c69c93fcb1 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc
@@ -21,11 +21,11 @@ limitations under the License.
 #include "third_party/eigen3/Eigen/QR"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/binary_linalg_ops_common.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/port.h"
-#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
new file mode 100644
index 0000000000..f083c50743
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/sparse_xent_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SparseSoftmaxXentWithLogitsOp : public OpKernel {
+ public:
+  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& logits_in = context->input(0);
+    const Tensor& labels_in = context->input(1);
+    OP_REQUIRES(context,
+                logits_in.shape().dim_size(0) == labels_in.NumElements(),
+                errors::InvalidArgument(
+                    "logits first dimension must match labels size. logits shape=",
+                    logits_in.shape().DebugString(), " labels shape=",
+                    labels_in.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
+                errors::InvalidArgument("logits must be 2-dimensional"));
+    // The checks above already tie the number of labels to the logits'
+    // first dimension, so the shape of "labels" needs no further
+    // validation here.
+
+    // loss is 1-D (one per example), and size is batch_size.
+
+    Tensor scratch;
+    OP_REQUIRES_OK(
+        context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                        TensorShape({logits_in.dim_size(0)}),
+                                        &scratch));
+
+    Tensor* loss_out = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+    Tensor* back_out = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, logits_in.shape(), &back_out));
+
+    functor::SparseXentFunctor<Device, T> functor;
+    functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+            labels_in.vec<int64>(), scratch.vec<T>(), loss_out->vec<T>(),
+            back_out->matrix<T>());
+  }
+};
+
+// Partial specialization for a CPUDevice that uses the Eigen implementation
+// from SparseXentEigenImpl.
+namespace functor {
+template <typename T>
+struct SparseXentFunctor<CPUDevice, T> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+                  typename TTypes<int64>::ConstVec labels,
+                  typename TTypes<T>::Vec scratch,
+                  typename TTypes<T>::Vec loss,
+                  typename TTypes<T>::Matrix backprop) {
+    SparseXentEigenImpl<CPUDevice, T>::Compute(d, logits, labels, scratch,
+                                               loss, backprop);
+  }
+};
+}  // namespace functor
+
+REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T"),
+                        SparseSoftmaxXentWithLogitsOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<double>("T"),
+                        SparseSoftmaxXentWithLogitsOp<CPUDevice, double>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<float>("T"),
+                        SparseSoftmaxXentWithLogitsOp<GPUDevice, float>);
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_xent_op.h b/tensorflow/core/kernels/sparse_xent_op.h
new file mode 100644
index 0000000000..5014794cb9
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op.h
@@ -0,0 +1,204 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SPARSE_XENT_OP_H_
+#define TENSORFLOW_KERNELS_SPARSE_XENT_OP_H_
+// Functor definition for SparseXentOp, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+namespace sparse_xent_helpers {
+
+template <typename T>
+typename TTypes<const T, 1>::Tensor32Bit To32BitConst(
+    typename TTypes<T>::Vec in) {
+  return To32Bit(typename TTypes<T>::ConstVec(in.data(), in.dimensions()));
+}
+
+template <typename T>
+typename TTypes<const T, 2>::Tensor32Bit To32BitConst(
+    typename TTypes<T>::Matrix in) {
+  return To32Bit(typename TTypes<T>::ConstMatrix(in.data(), in.dimensions()));
+}
+
+}  // namespace sparse_xent_helpers
+
+namespace generator {
+
+// Generator for calculation of the sparse Xent loss.
+// This generator takes the logits, the sum of the exponentiated
+// logits, and the label indices.  For each minibatch entry, ignoring
+// the batch index b, it calculates:
+//
+//   loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label }
+//
+// for j = 0 .. num_classes - 1.  This value must be summed over all j for
+// the final loss.
+template <typename T>
+class SparseXentLossGenerator {
+ public:
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator(
+      typename TTypes<const T, 2>::Tensor32Bit logits,
+      typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
+      TTypes<const int64, 1>::Tensor32Bit labels)
+      : logits_(logits), sum_exp_logits_(sum_exp_logits), labels_(labels) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+  operator()(const Eigen::array<int, 2>& coords) const {
+    int batch = coords[0];
+    int depth = coords[1];
+    return (labels_(batch) == depth)
+               ? (std::log(sum_exp_logits_(batch)) - logits_(coords))
+               : T(0.0);
+  };
+
+ private:
+  typename TTypes<const T, 2>::Tensor32Bit logits_;
+  typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
+  TTypes<const int64, 1>::Tensor32Bit labels_;
+};
+
+// Generator for calculation of the sparse Xent gradient.
+// This generator takes the logits, the sum of the exponentiated
+// logits, and the label indices.  For each minibatch entry, ignoring
+// the batch index b, it calculates:
+//
+//   exp(logits[j]) / sum_exp_logits - 1{ j == label }
+//
+// for j = 0 .. num_classes - 1.
+template <typename T>
+class SparseXentGradGenerator {
+ public:
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator(
+      typename TTypes<const T, 2>::Tensor32Bit logits,
+      typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
+      TTypes<const int64, 1>::Tensor32Bit labels)
+      : logits_(logits), sum_exp_logits_(sum_exp_logits), labels_(labels) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+  operator()(const Eigen::array<int, 2>& coords) const {
+    int batch = coords[0];
+    int depth = coords[1];
+    T subtract = (depth == labels_(batch)) ? T(1.0) : T(0.0);
+    return std::exp(logits_(coords)) / sum_exp_logits_(batch) - subtract;
+  };
+
+ private:
+  typename TTypes<const T, 2>::Tensor32Bit logits_;
+  typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
+  TTypes<const int64, 1>::Tensor32Bit labels_;
+};
+
+}  // namespace generator
+
+namespace functor {
+
+// Functor used by SparseXentOp to do the computations.
+template <typename Device, typename T>
+struct SparseXentFunctor {
+  // Computes Cross Entropy loss and backprop.
+  //
+  // logits: batch_size, num_classes.
+  // labels: batch_size.
+  // scratch: temporary tensor, dims: batch_size.
+  // loss: output tensor for the loss, dims: batch_size.
+  // backprop: output tensor for the backprop, dims: batch_size, num_classes.
+  void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+                  typename TTypes<int64>::ConstVec labels,
+                  typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
+                  typename TTypes<T>::Matrix backprop);
+};
+
+// Eigen code implementing SparseXentFunctor::operator().
+// This code works for both CPU and GPU and is used by the functor
+// specializations for both device types.
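Editor's aside before the Eigen implementation: the two generators above fully determine the op's math. The following NumPy sketch (illustrative only, not part of the patch; the function name and test values are invented) mirrors the max-shifted computation the kernel performs:

```python
import numpy as np

def sparse_xent_reference(logits, labels):
    """Illustrative reference for the kernel's math; not TensorFlow code.

    logits: float array of shape [batch_size, num_classes].
    labels: int array of shape [batch_size], values in [0, num_classes).
    """
    # backprop = logits - max_logits (the stability shift noted in the header).
    shifted = logits - logits.max(axis=1, keepdims=True)
    # scratch = sum(exp(logits - max_logits)) along classes.
    sum_exp = np.exp(shifted).sum(axis=1)
    batch = np.arange(logits.shape[0])
    # SparseXentLossGenerator, summed over j: log(sum_exp_logits) - logits[label].
    loss = np.log(sum_exp) - shifted[batch, labels]
    # SparseXentGradGenerator: exp(logits[j]) / sum_exp_logits - 1{j == label}.
    grad = np.exp(shifted) / sum_exp[:, None]
    grad[batch, labels] -= 1.0
    return loss, grad

loss, grad = sparse_xent_reference(
    np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), np.array([0, 2]))
# loss[1] < loss[0]: the second row's label sits on the largest logit,
# and each row of grad sums to zero (softmax minus one-hot).
```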
+template <typename Device, typename T>
+struct SparseXentEigenImpl {
+  static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+                      typename TTypes<int64>::ConstVec labels,
+                      typename TTypes<T>::Vec scratch,
+                      typename TTypes<T>::Vec loss,
+                      typename TTypes<T>::Matrix backprop) {
+    // NOTE(touts): This duplicates some of the computations in softmax_op
+    // because we need the intermediate (logits - max(logits)) values to
+    // avoid a log(exp()) in the computation of the loss.
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+
+// These arrays are used to reduce along the class dimension, and broadcast
+// the resulting value to all classes.
+#if !defined(EIGEN_HAS_INDEX_LIST)
+    Eigen::array<int, 1> along_class;
+    along_class[0] = kClassDim;
+    Eigen::array<int, 1> batch_only;
+    batch_only[0] = batch_size;
+    Eigen::array<int, 2> batch_by_one;
+    batch_by_one[0] = batch_size;
+    batch_by_one[1] = 1;
+    Eigen::array<int, 2> one_by_class;
+    one_by_class[0] = 1;
+    one_by_class[1] = num_classes;
+#else
+    Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
+    Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
+    batch_by_one.set(0, batch_size);
+    Eigen::IndexList<int> batch_only;
+    batch_only.set(0, batch_size);
+    Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
+    one_by_class.set(1, num_classes);
+#endif
+
+    // scratch = max_logits along classes.
+    To32Bit(scratch).device(d) = To32Bit(logits).maximum(along_class);
+
+    // backprop = logits - max_logits.
+    To32Bit(backprop).device(d) =
+        To32Bit(logits) -
+        To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class);
+
+    // scratch = sum(exp(logits - max_logits)) along classes.
+    To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class);
+
+    // loss = sum(-1{j == label} *
+    //            ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+    //        along classes
+    generator::SparseXentLossGenerator<T> sparse_xent_loss_gen(
+        sparse_xent_helpers::To32BitConst<T>(backprop),
+        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels));
+    To32Bit(loss).device(d) =
+        To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class);
+
+    // backprop: prob - labels, where
+    //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+    generator::SparseXentGradGenerator<T> sparse_xent_grad_gen(
+        sparse_xent_helpers::To32BitConst<T>(backprop),
+        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels));
+    To32Bit(backprop).device(d) =
+        To32Bit(backprop).generate(sparse_xent_grad_gen);
+  }
+};
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_SPARSE_XENT_OP_H_
diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
new file mode 100644
index 0000000000..296e4352a7
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
@@ -0,0 +1,51 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/sparse_xent_op.h"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization for a GPUDevice that uses the Eigen implementation
+// from SparseXentEigenImpl.
+namespace functor {
+template <typename T>
+struct SparseXentFunctor<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+                  typename TTypes<int64>::ConstVec labels,
+                  typename TTypes<T>::Vec scratch,
+                  typename TTypes<T>::Vec loss,
+                  typename TTypes<T>::Matrix backprop) {
+    SparseXentEigenImpl<GPUDevice, T>::Compute(d, logits, labels, scratch,
+                                               loss, backprop);
+  }
+};
+}  // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::SparseXentFunctor<GPUDevice, float>;
+template class generator::SparseXentGradGenerator<float>;
+
+}  // end namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc
new file mode 100644
index 0000000000..16335a9f1b
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_xent_op_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include <random> + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/kernels/xent_op.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +static Graph* SparseXent(int batch_size, int num_classes) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes})); + logits.flat<float>().setRandom(); + Tensor labels(DT_INT64, TensorShape({batch_size})); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dist(0, num_classes - 1); + auto labels_t = labels.flat<int64>(); + for (int i = 0; i < batch_size; ++i) { + labels_t(i) = dist(gen); + } + test::graph::Binary(g, "SparseSoftmaxCrossEntropyWithLogits", + test::graph::Constant(g, logits), + test::graph::Constant(g, labels)); + return g; +} + +#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \ + static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \ + testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \ + test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS)).Run(iters); \ + } \ + BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE); + +/// The representative tests for ptb_word on GPU +BM_SparseXentDev(8, 1000000, gpu); + +BM_SparseXentDev(16, 10000, gpu); +BM_SparseXentDev(16, 30000, gpu); +BM_SparseXentDev(16, 100000, gpu); + +BM_SparseXentDev(32, 10000, gpu); +BM_SparseXentDev(32, 30000, gpu); +BM_SparseXentDev(32, 100000, gpu); + +BM_SparseXentDev(64, 10000, gpu); +BM_SparseXentDev(64, 30000, gpu); +BM_SparseXentDev(64, 100000, gpu); + +// CPU +BM_SparseXentDev(8, 1000000, cpu); + +BM_SparseXentDev(16, 10000, cpu); +BM_SparseXentDev(16, 100000, cpu); + +BM_SparseXentDev(32, 10000, cpu); +BM_SparseXentDev(32, 100000, cpu); + +BM_SparseXentDev(64, 10000, cpu); +BM_SparseXentDev(64, 100000, cpu); + +} // end namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 83f28bdacf..a51a0b3469 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -551,6 +551,29 @@ loss: Per example loss (batch_size vector). backprop: backpropagated gradients (batch_size x num_classes matrix). )doc"); +REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits") + .Input("features: T") + .Input("labels: int64") + .Output("loss: T") + .Output("backprop: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Computes softmax cross entropy cost and gradients to backpropagate. + +Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +a matrix of label probabilities, but rather a single label per row +of features. This label is considered to have probability 1.0 for the +given row. + +Inputs are the logits, not probabilities. + +features: batch_size x num_classes matrix +labels: batch_size vector with values in [0, num_classes). + This is the label for the given minibatch entry. +loss: Per example loss (batch_size vector). +backprop: backpropagated gradients (batch_size x num_classes matrix). 
+)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("InTopK") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index fbb589bf24..c9f0dda3a1 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -8509,6 +8509,41 @@ op { description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n ==> [[ 1 2 3 4]\n [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n ==> [[0 0 0 0]\n [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```" } op { + name: "SparseSoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + description: "batch_size x num_classes matrix" + type_attr: "T" + } + input_arg { + name: "labels" + description: "batch_size vector with values in [0, num_classes).\nThis is the label for the given minibatch entry." + type: DT_INT64 + } + output_arg { + name: "loss" + description: "Per example loss (batch_size vector)." + type_attr: "T" + } + output_arg { + name: "backprop" + description: "backpropagated gradients (batch_size x num_classes matrix)." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Computes softmax cross entropy cost and gradients to backpropagate." + description: "Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept\na matrix of label probabilities, but rather a single label per row\nof features. This label is considered to have probability 1.0 for the\ngiven row.\n\nInputs are the logits, not probabilities." +} +op { name: "SparseSplit" input_arg { name: "split_dim" diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index a5932cf0e1..258e4b6210 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -3,6 +3,10 @@ load("/google/protobuf/protobuf", "cc_proto_library") load("/google/protobuf/protobuf", "py_proto_library") +# configure may change the following lines. +CUDA_VERSION = '7.0' +CUDNN_VERSION = '6.5' + # Appends a suffix to a list of deps. 
def tf_deps(deps, suffix): tf_deps = [] @@ -68,3 +72,9 @@ def tf_additional_test_srcs(): def tf_kernel_tests_linkstatic(): return 0 + +def tf_get_cuda_version(): + return CUDA_VERSION + +def tf_get_cudnn_version(): + return CUDNN_VERSION diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index 44dbc47ad1..b30c3ecf3d 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -9,6 +9,7 @@ exports_files(["LICENSE"]) load("/tensorflow/tensorflow", "tf_copts") load("/tensorflow/tensorflow", "tf_cuda_library") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_get_cuda_version") cc_library( name = "gtest", @@ -74,7 +75,7 @@ filegroup( cc_library( name = "cuda", data = [ - "//third_party/gpus/cuda:lib64/libcudart.so.7.0", + "//third_party/gpus/cuda:lib64/libcudart.so." + tf_get_cuda_version(), ], linkopts = [ "-Wl,-rpath,third_party/gpus/cuda/lib64", diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h index 6d364edef3..cbaaf602ee 100644 --- a/tensorflow/core/public/session.h +++ b/tensorflow/core/public/session.h @@ -114,6 +114,33 @@ class Session { const std::vector<string>& target_node_names, std::vector<Tensor>* outputs) = 0; + /// \brief Sets up a graph for partial execution. All future feeds and + /// fetches are specified by 'input_names' and 'output_names'. Returns + /// 'handle' that can be used to perform a sequence of partial feeds and + /// fetches. + /// NOTE: This API is still experimental and may change. + virtual Status PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) { + return errors::Unimplemented( + "Partial run is not supported for" + " this session."); + } + + /// \brief Continues the pending execution specified by 'handle' with the + /// provided input tensors and fills `outputs` for the endpoints specified + /// in `output_names`. + /// NOTE: This API is still experimental and may change. + virtual Status PRun(const string& handle, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) { + return errors::Unimplemented( + "Partial run is not supported for" + " this session."); + } + /// \brief Closes this session. /// /// Closing a session releases the resources used by this session diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h index 2c70713f14..e2210dcdcf 100644 --- a/tensorflow/core/public/tensor_c_api.h +++ b/tensorflow/core/public/tensor_c_api.h @@ -253,7 +253,7 @@ extern void TF_ExtendGraph(TF_Session*, const void* proto, size_t proto_len, // implementation will eventually call TF_DeleteTensor on each input). // // On success, the tensors corresponding to output_names[0,noutputs-1] -// are placed in outputs[]. and these outputs[] become the property +// are placed in outputs[], and these outputs[] become the property // of the caller (the caller must eventually call TF_DeleteTensor on // them). // @@ -269,6 +269,40 @@ extern void TF_Run(TF_Session*, // Output status TF_Status*); +// Set up the graph with the intended feeds and fetches for a sequence +// of partial run calls. +// +// On success, returns a handle that is used for subsequent PRun calls. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. 
+// NOTE: This is EXPERIMENTAL and subject to change. +extern void TF_PRunSetup(TF_Session*, + // Input names + const char** input_names, int ninputs, + // Output names + const char** output_tensor_names, int noutputs, + // Target nodes + const char** target_node_names, int ntargets, + // Output handle + char** handle, + // Output status + TF_Status*); + +// Continue to run the graph with additional feeds and fetches. The +// execution state is uniquely identified by the handle. +// NOTE: This is EXPERIMENTAL and subject to change. +extern void TF_PRun(TF_Session*, const char* handle, + // Input tensors + const char** input_names, TF_Tensor** inputs, int ninputs, + // Output tensors + const char** output_tensor_names, TF_Tensor** outputs, + int noutputs, + // Target nodes + const char** target_node_names, int ntargets, + // Output status + TF_Status*); + // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 4f1fabfe56..5eb6d846fd 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -13,16 +13,29 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -cc_library( - name = "tensorflow_native_libs", - srcs = glob(["jni/**/*.cc"]) + [":libpthread.so"], - hdrs = glob(["jni/**/*.h"]), +cc_binary( + name = "libtensorflow_demo.so", + srcs = glob([ + "jni/**/*.cc", + "jni/**/*.h", + ]) + [":libpthread.so"], copts = [ "-std=c++11", "-mfpu=neon", "-O2", ], - linkopts = ["-llog -landroid -lm -ljnigraphics"], + linkopts = [ + "-landroid", + "-ljnigraphics", + "-llog", + "-lm", + "-z defs", + "-s", + "-Wl,--icf=all", # Identical Code Folding + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export + ], + linkshared = 1, + linkstatic = 1, tags = [ "manual", "notap", @@ -43,6 +56,14 @@ cc_binary( ], ) +cc_library( + name = "tensorflow_native_libs", + srcs = [ + ":libpthread.so", + ":libtensorflow_demo.so", + ], +) + android_binary( name = "tensorflow_demo", srcs = glob([ @@ -52,7 +73,6 @@ android_binary( assets_dir = "assets", custom_package = "org.tensorflow.demo", inline_constants = 1, - legacy_native_support = 0, manifest = "AndroidManifest.xml", resource_files = glob(["res/**"]), tags = [ diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md index bab75c7973..a3adc5f155 100644 --- a/tensorflow/examples/udacity/README.md +++ b/tensorflow/examples/udacity/README.md @@ -17,6 +17,24 @@ On mac, find the virtual machine's IP using: Then go to: http://IP:8888 (likely http://192.168.99.100:8888) +FAQ +--- + +* **I'm getting a MemoryError when loading data in the first notebook.** + +If you're using a Mac, Docker works by running a VM locally (which +is controlled by `docker-machine`). It's quite likely that you'll +need to bump up the amount of RAM allocated to the VM beyond the +default (which is 1G). +[This Stack Overflow question](http://stackoverflow.com/questions/32834082/how-to-increase-docker-machine-memory-mac) +has two good suggestions; we recommend using 8G. + +In addition, you may need to pass `--memory=8g` as an extra argument to +`docker run`. 
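(Editor's note on the FAQ entry above: a concrete invocation, with the image name left as a placeholder since the README does not repeat it here, would be `docker run --memory=8g -p 8888:8888 <assignments-image>`; the `-p 8888:8888` mapping matches the `http://IP:8888` instructions earlier in this README.)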
+
+Notes for anyone needing to build their own containers (mostly instructors)
+===========================================================================
+
 Building a local Docker container
 ---------------------------------
 
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index d224ebbe5a..6a272c0786 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -884,7 +884,7 @@ tf.transpose(b, perm=[0, 2, 1]) ==> [[[1 4]
 
 - - -
 
-### `tf.gather(params, indices, name=None)` {#gather}
+### `tf.gather(params, indices, validate_indices=None, name=None)` {#gather}
 
 Gather slices from `params` according to `indices`.
 
@@ -912,6 +912,7 @@ this operation will permute `params` accordingly.
 
 *  <b>`params`</b>: A `Tensor`.
 *  <b>`indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+*  <b>`validate_indices`</b>: An optional `bool`. Defaults to `True`.
 *  <b>`name`</b>: A name for the operation (optional).
 
 ##### Returns:
 
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 253f13f38f..b1a9cd4a14 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -332,6 +332,7 @@
   * [`softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#softmax_cross_entropy_with_logits)
   * [`softplus`](../../api_docs/python/nn.md#softplus)
   * [`softsign`](../../api_docs/python/nn.md#softsign)
+  * [`sparse_softmax_cross_entropy_with_logits`](../../api_docs/python/nn.md#sparse_softmax_cross_entropy_with_logits)
   * [`tanh`](../../api_docs/python/nn.md#tanh)
   * [`top_k`](../../api_docs/python/nn.md#top_k)
   * [`uniform_candidate_sampler`](../../api_docs/python/nn.md#uniform_candidate_sampler)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 6677320a8d..7e056e2f11 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -751,6 +751,12 @@ classes are mutually exclusive (each entry is in exactly one class).  For
 example, each CIFAR-10 image is labeled with one and only one label: an image
 can be a dog or a truck, but not both.
 
+**NOTE:** While the classes are mutually exclusive, their probabilities
+need not be.  All that is required is that each row of `labels` is
+a valid probability distribution.  If using exclusive `labels`
+(wherein one and only one class is true at a time), see
+`sparse_softmax_cross_entropy_with_logits`.
+
 **WARNING:** This op expects unscaled logits, since it performs a `softmax`
 on `logits` internally for efficiency.  Do not call this op with the
 output of `softmax`, as it will produce incorrect results.
@@ -771,6 +777,46 @@ and the same dtype (either `float32` or `float64`).
   softmax cross entropy loss.
 
 
+- - -
+
+### `tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name=None)` {#sparse_softmax_cross_entropy_with_logits}
+
+Computes sparse softmax cross entropy between `logits` and `labels`.
+
+Measures the probability error in discrete classification tasks in which the
+classes are mutually exclusive (each entry is in exactly one class).  For
+example, each CIFAR-10 image is labeled with one and only one label: an image
+can be a dog or a truck, but not both.
+
+**NOTE:** For this operation, the probability of a given label is considered
+exclusive.  That is, soft classes are not allowed, and the `labels` vector
+must provide a single specific index for the true class for each row of
+`logits` (each minibatch entry).  For soft softmax classification with
+a probability distribution for each entry, see
+`softmax_cross_entropy_with_logits`.
+
+**WARNING:** This op expects unscaled logits, since it performs a `softmax`
+on `logits` internally for efficiency.  Do not call this op with the
+output of `softmax`, as it will produce incorrect results.
+
+`logits` must have the shape `[batch_size, num_classes]`
+and the dtype (either `float32` or `float64`).
+
+`labels` must have the shape `[batch_size]` and the dtype `int64`.
+
+##### Args:
+
+
+*  <b>`logits`</b>: Unscaled log probabilities.
+*  <b>`labels`</b>: Each entry `labels[i]` must be an index in `[0, num_classes)`.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+  softmax cross entropy loss.
+
+
 
 ## Embeddings
 
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 1b442309f9..f89dd1936f 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -448,6 +448,29 @@ Setting up Cuda nvvm
 Configuration finished
 ```
 
+##### Using different Cuda SDK and Cudnn versions
+
+TensorFlow officially supports Cuda 7.0 and Cudnn V2 (6.5) at this point.  To
+use a different Cuda SDK or Cudnn library, run `configure` with the unofficial
+setting:
+
+```bash
+$ TF_UNOFFICIAL_SETTING=1 ./configure
+...
+Please specify the Cuda SDK version you want to use. [Default is 7.0]: 7.5
+Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: /usr/local/cuda-7.5
+Please specify the Cudnn version you want to use. [Default is 6.5]: 4.0.4
+Please specify the location where cuDNN 4.0.4 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda-7.5]: /usr/local/cudnn-r4-rc/
+...
+Setting up Cuda include
+Setting up Cuda lib64
+Setting up Cuda bin
+Setting up Cuda nvvm
+Configuration finished
+```
+
+For the Cudnn libraries, use '6.5' for R2, '7.0' for R3, and '4.0.4' for
+R4-RC.
+
 ##### Known issues
 
 * Although it is possible to build both Cuda and non-Cuda configs under the same
 
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index 116d788449..119935f598 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -220,9 +220,9 @@ This asserts that the input is a vector, and returns having set the
   for its `SetStatus()` method.
 * The condition.  For example, there are functions for validating the shape
   of a tensor in
-  [`tensorflow/core/public/tensor_shape.h`](https://www.tensorflow.org/code/tensorflow/core/public/tensor_shape.h)
+  [`tensorflow/core/framework/tensor_shape.h`](https://www.tensorflow.org/code/tensorflow/core/framework/tensor_shape.h)
 * The error itself, which is represented by a `Status` object, see
-  [`tensorflow/core/public/status.h`](https://www.tensorflow.org/code/tensorflow/core/public/status.h). A
+  [`tensorflow/core/lib/core/status.h`](https://www.tensorflow.org/code/tensorflow/core/lib/core/status.h). A
   `Status` has both a type (frequently `InvalidArgument`, but see the list of
   types) and a message.
Functions for constructing an error may be found in
  [`tensorflow/core/lib/core/errors.h`][validation-macros].
 
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ed9481c330..126af9f342 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -582,6 +582,7 @@ tf_gen_op_wrapper_py(
         "AvgPoolGrad",  # "*Grad" accessible through nn_grad instead of nn_ops.
         "BatchNormWithGlobalNormalizationGrad",
         "SoftmaxCrossEntropyWithLogits",
+        "SparseSoftmaxCrossEntropyWithLogits",
         "LRNGrad",
         "MaxPoolGrad",
         "MaxPoolGradWithArgmax",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 918e7e4da6..1fda123150 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -50,8 +50,15 @@ class SessionInterface(object):
 
   def run(self, fetches, feed_dict=None):
     """Runs operations in the session. See `Session.run()` for details."""
-    raise NotImplementedError('Run')
+    raise NotImplementedError('run')
 
+  def partial_run_setup(self, fetches, feeds=None):
+    """Sets up the feeds and fetches for partial runs in the session."""
+    raise NotImplementedError('partial_run_setup')
+
+  def partial_run(self, handle, fetches, feed_dict=None):
+    """Continues the execution with additional feeds and fetches."""
+    raise NotImplementedError('partial_run')
 
 def _get_indexed_slices_value_from_fetches(fetched_vals):
   return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1],
@@ -213,11 +220,12 @@ class BaseSession(SessionInterface):
     return ops.default_session(self)
 
   # Eventually, this registration could be opened up to support custom
-  # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn),
+  # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn1, feed_fn2),
   # where the signatures are:
   #   fetch_fn : Type -> (list of Tensors,
   #                       lambda: list of fetched np.ndarray -> TypeVal)
-  #   feed_fn : Type, TypeVal -> list of (Tensor, value)
+  #   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
+  #   feed_fn2 : Type -> list of Tensors
   # Conceptually, fetch_fn describes how to expand fetch into its
   # component Tensors and how to contract the fetched results back into
   # a single return value. feed_fn describes how to unpack a single fed
@@ -231,7 +239,8 @@ class BaseSession(SessionInterface):
          [fetch.indices, fetch.values, fetch.shape],
          lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
       lambda feed, feed_val: list(zip(
-          [feed.indices, feed.values, feed.shape], feed_val))),
+          [feed.indices, feed.values, feed.shape], feed_val)),
+      lambda feed: [feed.indices, feed.values, feed.shape]),
      # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
      # IndexedSlicesValues or normal tuples.
      (ops.IndexedSlices,
@@ -239,11 +248,14 @@ class BaseSession(SessionInterface):
       lambda fetch: (
          [fetch.values, fetch.indices] if fetch.dense_shape is None
          else [fetch.values, fetch.indices, fetch.dense_shape],
          _get_indexed_slices_value_from_fetches),
-      _get_feeds_for_indexed_slices),
+      _get_feeds_for_indexed_slices,
+      lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
+      else [feed.values, feed.indices, feed.dense_shape]),
      # The default catches all types and performs no expansions.
      (object,
       lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
-      lambda feed, feed_val: [(feed, feed_val)])]
+      lambda feed, feed_val: [(feed, feed_val)],
+      lambda feed: [feed])]
   # pylint: enable=g-long-lambda
 
   def run(self, fetches, feed_dict=None):
@@ -302,17 +314,67 @@ class BaseSession(SessionInterface):
       ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
         `Tensor` that doesn't exist.
     """
-    def _fetch_fn(fetch):
-      for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
-        if isinstance(fetch, tensor_type):
-          return fetch_fn(fetch)
-      raise TypeError('Fetch argument %r has invalid type %r'
-                      % (fetch, type(fetch)))
+    return self._run(None, fetches, feed_dict)
 
-    def _feed_fn(feed, feed_val):
-      for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
+  def partial_run(self, handle, fetches, feed_dict=None):
+    """Continues the execution with more feeds and fetches.
+
+    This is EXPERIMENTAL and subject to change.
+
+    To use partial execution, a user first calls `partial_run_setup()` and
+    then a sequence of `partial_run()`. `partial_run_setup` specifies the
+    list of feeds and fetches that will be used in the subsequent
+    `partial_run` calls.
+
+    Below is a simple example:
+
+      a = array_ops.placeholder(dtypes.float32, shape=[])
+      b = array_ops.placeholder(dtypes.float32, shape=[])
+      c = array_ops.placeholder(dtypes.float32, shape=[])
+      r1 = math_ops.add(a, b)
+      r2 = math_ops.mul(r1, c)
+
+      h = sess.partial_run_setup([r1, r2], [a, b, c])
+      res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+      res = sess.partial_run(h, r2, feed_dict={c: res})
+
+    Args:
+      handle: A handle for a sequence of partial runs.
+      fetches: A single graph element, or a list of graph elements
+        (described above).
+      feed_dict: A dictionary that maps graph elements to values
+        (described above).
+
+    Returns:
+      Either a single value if `fetches` is a single graph element, or
+      a list of values if `fetches` is a list (described above).
+    """
+    return self._run(handle, fetches, feed_dict)
+
+  def partial_run_setup(self, fetches, feeds=None):
+    """Sets up a graph with feeds and fetches for partial run.
+
+    This is EXPERIMENTAL and subject to change.
+
+    Note that unlike `run`, `feeds` only specifies the graph elements.
+    The tensors will be supplied by the subsequent `partial_run` calls.
+
+    Args:
+      fetches: A single graph element, or a list of graph elements.
+      feeds: A single graph element, or a list of graph elements.
+
+    Returns:
+      A handle for partial run.
+
+    Raises:
+      RuntimeError: If this `Session` is in an invalid state (e.g. has been
+        closed).
+      TypeError: If `fetches` or `feeds` elements are of an inappropriate type.
+    """
+    def _feed_fn(feed):
+      for tensor_type, _, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
         if isinstance(feed, tensor_type):
-          return feed_fn(feed, feed_val)
+          return feed_fn(feed)
       raise TypeError('Feed argument %r has invalid type %r'
                       % (feed, type(feed)))
 
@@ -324,6 +386,46 @@ class BaseSession(SessionInterface):
                          'graph before calling run().')
 
     # Validate and process fetches.
+    unique_fetches, target_list, _ = self._process_fetches(fetches)
+
+    # Create request.
+    feed_list = []
+
+    # Validate and process feed_list.
+    is_list_feed = isinstance(feeds, (list, tuple))
+    if not is_list_feed:
+      feeds = [feeds]
+    for feed in feeds:
+      for subfeed in _feed_fn(feed):
+        try:
+          subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
+                                                  allow_operation=False)
+          feed_list.append(compat.as_bytes(subfeed_t.name))
+        except Exception as e:
+          e.message = ('Cannot interpret feed_list key as Tensor: '
+                       + e.message)
+          e.args = (e.message,)
+          raise e
+
+    # Set up a graph with feeds and fetches for partial run.
+    def _setup_fn(session, feed_list, fetch_list, target_list):
+      self._extend_graph()
+      return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
+                                     target_list)
+
+    return self._do_call(_setup_fn, self._session, feed_list, unique_fetches,
+                         target_list)
+
+  def _process_fetches(self, fetches):
+    """Validate and process fetches."""
+    def _fetch_fn(fetch):
+      for tensor_type, fetch_fn, _, _ in BaseSession._REGISTERED_EXPANSIONS:
+        if isinstance(fetch, tensor_type):
+          return fetch_fn(fetch)
+      raise TypeError('Fetch argument %r has invalid type %r'
+                      % (fetch, type(fetch)))
+
+    # Validate and process fetches.
     is_list_fetch = isinstance(fetches, (list, tuple))
     if not is_list_fetch:
       fetches = [fetches]
@@ -357,6 +459,26 @@ class BaseSession(SessionInterface):
         fetch_info.append((subfetch_names, fetch_contraction_fn))
 
     unique_fetch_targets = list(unique_fetch_targets)
+    return unique_fetch_targets, target_list, fetch_info
+
+  def _run(self, handle, fetches, feed_dict):
+    """Performs either run or partial_run, depending on the existence of `handle`."""
+    def _feed_fn(feed, feed_val):
+      for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
+        if isinstance(feed, tensor_type):
+          return feed_fn(feed, feed_val)
+      raise TypeError('Feed argument %r has invalid type %r'
+                      % (feed, type(feed)))
+
+    # Check session.
+    if self._closed:
+      raise RuntimeError('Attempted to use a closed Session.')
+    if self.graph.version == 0:
+      raise RuntimeError('The Session graph is empty. Add operations to the '
+                         'graph before calling run().')
+
+    # Validate and process fetches.
+    unique_fetches, target_list, fetch_info = self._process_fetches(fetches)
 
     # Create request.
     feed_dict_string = {}
@@ -386,13 +508,14 @@ class BaseSession(SessionInterface):
           feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val
 
     # Run request and get response.
-    results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
+    results = self._do_run(handle, target_list, unique_fetches,
+                           feed_dict_string)
 
     # User may have fetched the same tensor multiple times, but we
     # only fetch them from the runtime once. Furthermore, they may
     # be wrapped as a tuple of tensors. Here we map the results back
     # to what the client asked for.
-    fetched_results = dict(zip(unique_fetch_targets, results))
+    fetched_results = dict(zip(unique_fetches, results))
     ret = []
     for fetch_names, fetch_contraction_fn in fetch_info:
       if fetch_names:
@@ -401,7 +524,7 @@ class BaseSession(SessionInterface):
       else:
         ret.append(None)
 
-    if is_list_fetch:
+    if isinstance(fetches, (list, tuple)):
       return ret
     else:
       return ret[0]
@@ -409,10 +532,11 @@ class BaseSession(SessionInterface):
   # Captures the name of a node in an error status.
   _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
 
-  def _do_run(self, target_list, fetch_list, feed_dict):
+  def _do_run(self, handle, target_list, fetch_list, feed_dict):
     """Runs a step based on the given fetches and feeds.
 
     Args:
+      handle: a handle for partial_run. None if this is just a call to run().
      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run, but not fetched.
      fetch_list: A list of byte arrays corresponding to names of tensors to
@@ -426,28 +550,26 @@ class BaseSession(SessionInterface):
        name of an operation, the first Tensor output of that operation
        will be returned for that element.
     """
-    try:
+    def _run_fn(session, feed_dict, fetch_list, target_list):
       # Ensure any changes to the graph are reflected in the runtime.
-      with self._extend_lock:
-        if self._graph.version > self._current_version:
-          graph_def = self._graph.as_graph_def(
-              from_version=self._current_version)
+      self._extend_graph()
+      return tf_session.TF_Run(session, feed_dict, fetch_list, target_list)
 
-          try:
-            status = tf_session.TF_NewStatus()
-            tf_session.TF_ExtendGraph(
-                self._session, graph_def.SerializeToString(), status)
-            if tf_session.TF_GetCode(status) != 0:
-              raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
-            self._opened = True
-          finally:
-            tf_session.TF_DeleteStatus(status)
-
-          self._current_version = self._graph.version
+    def _prun_fn(session, handle, feed_dict, fetch_list):
+      if target_list:
+        raise RuntimeError('partial_run() requires empty target_list.')
+      return tf_session.TF_PRun(session, handle, feed_dict, fetch_list)
 
-      return tf_session.TF_Run(self._session, feed_dict, fetch_list,
-                               target_list)
+    if handle is None:
+      return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
+                           target_list)
+    else:
+      return self._do_call(_prun_fn, self._session, handle, feed_dict,
+                           fetch_list)
 
+  def _do_call(self, fn, *args):
+    try:
+      return fn(*args)
     except tf_session.StatusNotOK as e:
       e_type, e_value, e_traceback = sys.exc_info()
       error_message = compat.as_text(e.error_message)
@@ -466,6 +588,25 @@ class BaseSession(SessionInterface):
       # pylint: enable=protected-access
       six.reraise(e_type, e_value, e_traceback)
 
+  def _extend_graph(self):
+    # Ensure any changes to the graph are reflected in the runtime.
+    with self._extend_lock:
+      if self._graph.version > self._current_version:
+        graph_def = self._graph.as_graph_def(
+            from_version=self._current_version)
+
+        try:
+          status = tf_session.TF_NewStatus()
+          tf_session.TF_ExtendGraph(
+              self._session, graph_def.SerializeToString(), status)
+          if tf_session.TF_GetCode(status) != 0:
+            raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
+          self._opened = True
+        finally:
+          tf_session.TF_DeleteStatus(status)
+
+        self._current_version = self._graph.version
+
 
 class Session(BaseSession):
   """A class for running TensorFlow operations.
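Since the partial-run client API is spread across several methods above, here is a compact end-to-end sketch (editor's addition, mirroring the docstring example and the tests in session_test.py below; it assumes the era's public `tf.*` re-exports such as `tf.mul`):

```python
import tensorflow as tf

with tf.Session() as sess:
    a = tf.placeholder(tf.float32, shape=[])
    b = tf.placeholder(tf.float32, shape=[])
    c = tf.placeholder(tf.float32, shape=[])
    r1 = tf.add(a, b)
    r2 = tf.mul(r1, c)  # pre-1.0 name for elementwise multiply

    # Declare every feed and fetch up front; the tensor values arrive later.
    h = sess.partial_run_setup([r1, r2], [a, b, c])
    res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})   # 3.0
    res = sess.partial_run(h, r2, feed_dict={c: res * 17})  # 3.0 * 51.0 == 153.0
```

The point of the handle is that the second `partial_run` continues the same pending execution: `r1` is not recomputed, and feeding `c` late lets arbitrary Python code sit between the two graph steps.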
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index c7d5453f57..39f720629a 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -811,5 +811,50 @@ class SessionTest(test_util.TensorFlowTestCase): sess_2.run(c_1.op) self.assertEqual(2.0, sess_2.run(c_2)) + def testPartialRun(self): + with session.Session() as sess: + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.mul(r1, c) + + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + temp = res * 17 + res = sess.partial_run(h, r2, feed_dict={c: temp}) + self.assertEqual(153, res) + + def testPartialRunIncomplete(self): + with session.Session() as sess: + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.mul(r1, c) + + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + + def testConcurrentPartialRun(self): + with session.Session() as sess: + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.mul(r1, c) + + h1 = sess.partial_run_setup([r1], [a, b, c]) + h2 = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + temp = res * 19 + res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9}) + self.assertEqual(66, res) + res = sess.partial_run(h2, r2, feed_dict={c: 7}) + self.assertEqual(462, res) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 3bf7623eb2..8e8e3132f1 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -170,6 +170,10 @@ tensorflow::ImportNumpy(); tensorflow::PyObjectVector temp) { $1 = &temp; } +%typemap(in, numinputs=0) char** out_handle ( + char* temp) { + $1 = &temp; +} // Raise a StatusNotOK exception if the out_status is not OK; // otherwise build a Python list of outputs and return it. @@ -196,6 +200,19 @@ tensorflow::ImportNumpy(); } } +// Raise a StatusNotOK exception if the out_status is not OK; +// otherwise return the handle as a python string object. +%typemap(argout, fragment="StatusNotOK") ( + tensorflow::Status* out_status, char** out_handle) { + if (!$1->ok()) { + RaiseStatusNotOK(*$1, $descriptor(tensorflow::Status*)); + SWIG_fail; + } else { + $result = PyString_FromStringAndSize(*$2, strlen(*$2)); + delete *$2; + } +} + //////////////////////////////////////////////////////////////////////////////// // END TYPEMAPS FOR tensorflow::TF_Run_wrapper() //////////////////////////////////////////////////////////////////////////////// @@ -264,6 +281,26 @@ tensorflow::ImportNumpy(); %unignore TF_Run; %unignore EqualGraphDefWrapper; +// Include the wrapper for TF_PRunSetup from tf_session_helper.h. + +// The %exception block above releases the Python GIL for the length +// of each wrapped method. We disable this behavior for TF_PRunSetup +// because it uses the Python allocator. 
+%noexception tensorflow::TF_PRunSetup_wrapper; +%rename(TF_PRunSetup) tensorflow::TF_PRunSetup_wrapper; +%unignore tensorflow; +%unignore TF_PRunSetup; + +// Include the wrapper for TF_PRun from tf_session_helper.h. + +// The %exception block above releases the Python GIL for the length +// of each wrapped method. We disable this behavior for TF_PRun +// because it uses the Python allocator. +%noexception tensorflow::TF_PRun_wrapper; +%rename(TF_PRun) tensorflow::TF_PRun_wrapper; +%unignore tensorflow; +%unignore TF_PRun; + %include "tensorflow/python/client/tf_session_helper.h" %unignoreall diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 54ebc2447e..2698230b11 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -425,13 +425,11 @@ Safe_PyObjectPtr make_safe(PyObject* o) { return Safe_PyObjectPtr(o, Py_DECREF_wrapper); } -// Wrapper for TF_Run that converts the arguments to appropriate types. -// If *out_status is OK, the caller becomes the owner of the PyObjects -// in *out_values. -void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, - const NameVector& output_names, - const NameVector& target_nodes, Status* out_status, - PyObjectVector* out_values) { +void TF_Run_wrapper_helper(TF_Session* session, const char* handle, + const FeedVector& inputs, + const NameVector& output_names, + const NameVector& target_nodes, Status* out_status, + PyObjectVector* out_values) { // 1. Convert the feed inputs to the appropriate form for TF_Run. NameVector input_names; Safe_PyObjectVector @@ -514,10 +512,20 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, // 3. Actually call TF_Run(). Py_BEGIN_ALLOW_THREADS; - TF_Run(session, input_names.data(), inputs_unsafe.data(), input_names.size(), - const_cast<const char**>(output_names.data()), outputs.data(), - output_names.size(), const_cast<const char**>(target_nodes.data()), - target_nodes.size(), status.get()); + if (handle == nullptr) { + TF_Run(session, input_names.data(), inputs_unsafe.data(), + input_names.size(), const_cast<const char**>(output_names.data()), + outputs.data(), output_names.size(), + const_cast<const char**>(target_nodes.data()), target_nodes.size(), + status.get()); + } else { + TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(), + input_names.size(), const_cast<const char**>(output_names.data()), + outputs.data(), output_names.size(), + const_cast<const char**>(target_nodes.data()), target_nodes.size(), + status.get()); + } + Py_END_ALLOW_THREADS; // 4. The TensorFlow runtime has taken ownership of the fed tensors, @@ -558,6 +566,49 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, *out_status = Status::OK(); } +// Wrapper for TF_Run that converts the arguments to appropriate types. +// If *out_status is OK, the caller becomes the owner of the PyObjects +// in *out_values. +void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, + const NameVector& output_names, + const NameVector& target_nodes, Status* out_status, + PyObjectVector* out_values) { + TF_Run_wrapper_helper(session, nullptr, inputs, output_names, target_nodes, + out_status, out_values); +} + +// Wrapper for TF_PRunSetup that converts the arguments to appropriate types. +// If *out_status is OK, the caller becomes the owner of *out_handle. 
+void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names, + const NameVector& output_names, + const NameVector& target_nodes, Status* out_status, + char** out_handle) { + Safe_TF_StatusPtr status = make_safe(TF_NewStatus()); + Py_BEGIN_ALLOW_THREADS; + TF_PRunSetup( + session, const_cast<const char**>(input_names.data()), input_names.size(), + const_cast<const char**>(output_names.data()), output_names.size(), + const_cast<const char**>(target_nodes.data()), target_nodes.size(), + out_handle, status.get()); + Py_END_ALLOW_THREADS; + + if (TF_GetCode(status.get()) != TF_OK) { + *out_status = TF_Status_to_Status(status.get()); + return; + } + *out_status = Status::OK(); +} + +// Wrapper for TF_PRun that converts the arguments to appropriate types. +// If *out_status is OK, the caller becomes the owner of the PyObjects +// in *out_values. +void TF_PRun_wrapper(TF_Session* session, const char* handle, + const FeedVector& inputs, const NameVector& output_names, + Status* out_status, PyObjectVector* out_values) { + TF_Run_wrapper_helper(session, handle, inputs, output_names, NameVector(), + out_status, out_values); +} + void ImportNumpy() { import_array1(); } string EqualGraphDefWrapper(const string& actual, const string& expected) { diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 60bbab69f3..a9a38da6e5 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -65,7 +65,7 @@ typedef std::vector<Safe_PyObjectPtr> Safe_PyObjectVector; Safe_PyObjectPtr make_safe(PyObject* o); // Run the graph associated with the session starting with the -// supplied inputs[]. Regardless of success of failure, inputs[] are +// supplied inputs[]. Regardless of success or failure, inputs[] are // stolen by the implementation (i.e. the implementation will // eventually call Py_DECREF on each array input). // @@ -80,6 +80,35 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, const NameVector& target_nodes, Status* out_status, PyObjectVector* out_values); +// Set up the graph with the intended feeds and fetches for partial run. +// *out_handle is owned by the caller. +// +// On success, returns a handle that is used for subsequent PRun calls. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. +// +// NOTE: This is EXPERIMENTAL and subject to change. +void TF_PRunSetup_wrapper(TF_Session* session, const NameVector& input_names, + const NameVector& output_names, + const NameVector& target_nodes, Status* out_status, + char** out_handle); + +// Continue to run the graph with additional feeds and fetches. The +// execution state is uniquely identified by the handle. +// +// On success, the tensors corresponding to output_names[0,noutputs-1] +// are placed in out_values[], and these outputs[] become the property +// of the caller (the caller must eventually call Py_DECREF on them). +// +// On failure, out_status contains a tensorflow::Status with an error +// message. +// +// NOTE: This is EXPERIMENTAL and subject to change. +void TF_PRun_wrapper(TF_Session* session, const char* handle, + const FeedVector& inputs, const NameVector& output_names, + Status* out_status, PyObjectVector* out_values); + // Import numpy. This wrapper function exists so that the // PY_ARRAY_UNIQUE_SYMBOL can be safely defined in a .cc file to // avoid weird linking issues. 
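The TF_PRunSetup/TF_PRun wrappers above back the experimental partial-run path exercised by session_test.py: partial_run_setup declares every feed and fetch up front and returns a handle, and each later partial_run feeds and fetches a subset of them against that handle. A minimal sketch of the call pattern, mirroring testPartialRun; the concrete values and the intermediate Python arithmetic are illustrative only, and the public tf.Session is assumed to expose the same partial_run methods as the internal session.Session used in the test.

    import tensorflow as tf

    with tf.Session() as sess:
      a = tf.placeholder(tf.float32, shape=[])
      b = tf.placeholder(tf.float32, shape=[])
      c = tf.placeholder(tf.float32, shape=[])
      r1 = tf.add(a, b)
      r2 = tf.mul(r1, c)

      # Declare all feeds and fetches once; the runtime keeps the partial
      # execution alive and hands back a handle for it.
      h = sess.partial_run_setup([r1, r2], [a, b, c])

      # Step 1: feed a and b, fetch r1.
      r1_val = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})      # 3.0

      # Arbitrary Python work can happen here before the same execution
      # is continued by feeding c and fetching r2.
      r2_val = sess.partial_run(h, r2, feed_dict={c: r1_val * 17})  # 153.0

As testPartialRunIncomplete shows, it is also legal to stop after the first partial_run without ever fetching r2.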
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 4845cadbd3..f4389e58b2 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3338,6 +3338,8 @@ class GraphKeys(object): MOVING_AVERAGE_VARIABLES = "moving_average_variables" # Key to collected regularization losses at graph construction. REGULARIZATION_LOSSES = "regularization_losses" + # Key to collect concatenated sharded variables. + CONCATENATED_VARIABLES = "concatenated_variables" def add_to_collection(name, value): diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 5f593bce3d..c3bcfed900 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -114,20 +114,6 @@ void AppendWithinWidth(string* dest, StringPiece append, int width) { } } -void RemoveDescriptionsFromOpDef(OpDef* op_def) { - for (int i = 0; i < op_def->input_arg_size(); ++i) { - op_def->mutable_input_arg(i)->clear_description(); - } - for (int i = 0; i < op_def->output_arg_size(); ++i) { - op_def->mutable_output_arg(i)->clear_description(); - } - for (int i = 0; i < op_def->attr_size(); ++i) { - op_def->mutable_attr(i)->clear_description(); - } - op_def->clear_summary(); - op_def->clear_description(); -} - // Like DataTypeString() but uses the Python names for the // float types. string PythonDataTypeString(DataType dtype) { diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index c9344d63ff..eac25fb2c6 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -255,7 +255,7 @@ class LSTMTest(tf.test.TestCase): input_size = 5 batch_size = 2 num_proj = 4 - num_proj_shards = 4 + num_proj_shards = 3 num_unit_shards = 2 max_length = 8 with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: @@ -281,12 +281,37 @@ class LSTMTest(tf.test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) + def _testTooManyShards(self, use_gpu): + num_units = 3 + input_size = 5 + num_proj = 4 + num_proj_shards = 4 + num_unit_shards = 2 + max_length = 8 + with self.test_session(use_gpu=use_gpu, graph=tf.Graph()): + initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) + + inputs = max_length * [ + tf.placeholder(tf.float32, shape=(None, input_size))] + + cell = tf.nn.rnn_cell.LSTMCell( + num_units, + input_size=input_size, + use_peepholes=True, + num_proj=num_proj, + num_unit_shards=num_unit_shards, + num_proj_shards=num_proj_shards, + initializer=initializer) + + with self.assertRaises(ValueError): + tf.nn.rnn(cell, inputs, dtype=tf.float32) + def _testDoubleInput(self, use_gpu): num_units = 3 input_size = 5 batch_size = 2 num_proj = 4 - num_proj_shards = 4 + num_proj_shards = 3 num_unit_shards = 2 max_length = 8 with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: @@ -318,7 +343,7 @@ class LSTMTest(tf.test.TestCase): input_size = 5 batch_size = 2 num_proj = 4 - num_proj_shards = 4 + num_proj_shards = 3 num_unit_shards = 2 max_length = 8 with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: @@ -369,7 +394,7 @@ class LSTMTest(tf.test.TestCase): input_size = 5 batch_size = 2 num_proj = 4 - num_proj_shards = 4 + num_proj_shards = 3 num_unit_shards = 2 max_length = 8 with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess: @@ -494,6 +519,10 @@ class LSTMTest(tf.test.TestCase): 
self._testProjSharding(use_gpu=False) self._testProjSharding(use_gpu=True) + def testTooManyShards(self): + self._testTooManyShards(use_gpu=False) + self._testTooManyShards(use_gpu=True) + def testShardNoShardEquivalentOutput(self): self._testShardNoShardEquivalentOutput(use_gpu=False) self._testShardNoShardEquivalentOutput(use_gpu=True) diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py new file mode 100644 index 0000000000..13c519857d --- /dev/null +++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py @@ -0,0 +1,256 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for SparseSoftmaxCrossEntropyWithLogits op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order + +# pylint: disable=unused-import +import tensorflow.python.platform +# pylint: enable=unused_import + +import numpy as np +import tensorflow as tf +import sys +import time + +from tensorflow.python.client import graph_util +from tensorflow.python.ops import sparse_ops + +# pylint: enable=g-bad-import-order + + +class SparseXentTest(tf.test.TestCase): + + def _npXent(self, features, labels): + batch_dim = 0 + class_dim = 1 + batch_size = features.shape[batch_dim] + e = np.exp(features - + np.reshape(np.amax(features, axis=class_dim), [batch_size, 1])) + probs = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1]) + labels_mat = np.zeros_like(probs).astype(probs.dtype) + labels_mat[np.arange(batch_size), labels] = 1.0 + bp = (probs - labels_mat) + l = -np.sum(labels_mat * np.log(probs + 1.0e-20), axis=1) + return l, bp + + def _testXent(self, np_features, np_labels, use_gpu=False): + np_loss, np_backprop = self._npXent(np_features, np_labels) + with self.test_session(use_gpu=use_gpu) as sess: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + np_features, np_labels) + backprop = loss.op.outputs[1] + tf_loss, tf_backprop = sess.run([loss, backprop]) + self.assertAllClose(np_loss, tf_loss) + self.assertAllClose(np_backprop, tf_backprop) + + def _testAll(self, features, labels): + self._testXent(features, labels, use_gpu=False) + self._testXent(features, labels, use_gpu=True) + + def _testSingleClass(self, use_gpu=False): + with self.test_session(use_gpu=use_gpu) as sess: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + np.array([[1.], [-1.], [0.]]).astype(np.float32), + np.array([1, 1, 1]).astype(np.int64)) + backprop = loss.op.outputs[1] + tf_loss, tf_backprop = sess.run([loss, backprop]) + # loss = -1.0*log(1.0), 1.0*log(1.0), 0.0*log(0.0) == 1.0 + self.assertAllClose([0.0, 0.0, 0.0], tf_loss) + self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop) + + # def testSingleClass(self): + # self._testSingleClass(True) + # self._testSingleClass(False) + + def testRankTooLarge(self): + np_features 
= np.array( + [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(np.float32) + np_labels = np.array([1, 2]).astype(np.int64) + self.assertRaisesRegexp( + ValueError, "must have rank 2", + tf.nn.sparse_softmax_cross_entropy_with_logits, np_features, np_labels) + + def testNpXent(self): + # We create 2 batches of logits for testing. + # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3. + # batch 1 has a bit of difference: 1, 2, 3, 4, with target 0. + features = [[1., 1., 1., 1.], [1., 2., 3., 4.]] + labels = [3, 0] + + # For batch 0, we expect the uniform distribution: 0.25, 0.25, 0.25, 0.25 + # With a hard target 3, the backprop is [0.25, 0.25, 0.25, -0.75] + # The loss for this batch is -log(0.25) = 1.386 + # + # For batch 1, we have: + # exp(0) = 1 + # exp(1) = 2.718 + # exp(2) = 7.389 + # exp(3) = 20.085 + # SUM = 31.192 + # So we have as probabilities: + # exp(0) / SUM = 0.032 + # exp(1) / SUM = 0.087 + # exp(2) / SUM = 0.237 + # exp(3) / SUM = 0.644 + # With a hard 1, the backprop is [0.032 - 1.0 = -0.968, 0.087, 0.237, 0.644] + # The loss for this batch is [1.0 * -log(0.25), 1.0 * -log(0.032)] + # = [1.3862, 3.4420] + np_loss, np_backprop = self._npXent( + np.array(features), np.array(labels, dtype=np.int64)) + self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75], + [-0.968, 0.087, 0.237, 0.6439]]), + np_backprop, + rtol=1.e-3, atol=1.e-3) + self.assertAllClose(np.array([1.3862, 3.4420]), np_loss, + rtol=1.e-3, atol=1.e-3) + + def testShapeMismatch(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.nn.sparse_softmax_cross_entropy_with_logits( + [[0., 1.], [2., 3.]], [[0, 2]]) + + def testNotMatrix(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.nn.sparse_softmax_cross_entropy_with_logits( + [0., 1., 2., 3.], [0, 2]) + + def testFloat(self): + self._testAll( + np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32), + np.array([3, 0]).astype(np.int64)) + + def testDouble(self): + self._testXent( + np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64), + np.array([0, 3]).astype(np.int64), + use_gpu=False) + + def testGradient(self): + with self.test_session(): + l = tf.constant([3, 0, 1], dtype=tf.int64, name="l") + f = tf.constant([0.1, 0.2, 0.3, 0.4, + 0.1, 0.4, 0.9, 1.6, + 0.1, 0.8, 2.7, 6.4], shape=[3, 4], + dtype=tf.float64, name="f") + x = tf.nn.sparse_softmax_cross_entropy_with_logits(f, l, name="xent") + err = tf.test.compute_gradient_error(f, [3, 4], x, [3]) + print("cross entropy gradient err = ", err) + self.assertLess(err, 5e-8) + + +def _sparse_vs_dense_xent_benchmark_dense(labels, logits): + labels = tf.identity(labels) + logits = tf.identity(logits) + with tf.device("/cpu:0"): # Sparse-to-dense must be on CPU + batch_size = tf.shape(logits)[0] + num_entries = tf.shape(logits)[1] + length = batch_size * num_entries + labels += num_entries * tf.range(batch_size) + target = sparse_ops.sparse_to_dense( + labels, tf.pack([length]), 1.0, 0.0) + target = tf.reshape(target, tf.pack([-1, num_entries])) + crossent = tf.nn.softmax_cross_entropy_with_logits( + logits, target, name="SequenceLoss/CrossEntropy") + crossent_sum = tf.reduce_sum(crossent) + grads = tf.gradients([crossent_sum], [logits])[0] + + return (crossent_sum, grads) + + +def _sparse_vs_dense_xent_benchmark_sparse(labels, logits): + # Using sparse_softmax_cross_entropy_with_logits + labels = labels.astype(np.int64) + labels = tf.identity(labels) + logits = tf.identity(logits) + crossent = 
tf.nn.sparse_softmax_cross_entropy_with_logits( + logits, labels, name="SequenceLoss/CrossEntropy") + crossent_sum = tf.reduce_sum(crossent) + grads = tf.gradients([crossent_sum], [logits])[0] + + return (crossent_sum, grads) + + +def sparse_vs_dense_xent_benchmark(batch_size, num_entries, use_gpu): + config = tf.ConfigProto() + config.allow_soft_placement = True + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + labels = np.random.randint(num_entries, size=batch_size).astype(np.int32) + logits = np.random.randn(batch_size, num_entries).astype(np.float32) + + def _timer(sess, ops): + # Warm in + for _ in range(20): + sess.run(ops) + + # Timing run + start = time.time() + for _ in range(20): + sess.run(ops) + end = time.time() + + return (end - start)/20.0 # Average runtime per iteration + + # Using sparse_to_dense and softmax_cross_entropy_with_logits + with tf.Session(config=config) as sess: + if not use_gpu: + with tf.device(graph_util.pin_to_cpu): + ops = _sparse_vs_dense_xent_benchmark_dense(labels, logits) + else: + ops = _sparse_vs_dense_xent_benchmark_dense(labels, logits) + delta_dense = _timer(sess, ops) + + # Using sparse_softmax_cross_entropy_with_logits + with tf.Session(config=config) as sess: + if not use_gpu: + with tf.device(graph_util.pin_to_cpu): + ops = _sparse_vs_dense_xent_benchmark_sparse(labels, logits) + else: + ops = _sparse_vs_dense_xent_benchmark_sparse(labels, logits) + delta_sparse = _timer(sess, ops) + + print( + "%d \t %d \t %s \t %f \t %f \t %f" + % (batch_size, num_entries, use_gpu, delta_dense, delta_sparse, + delta_sparse/delta_dense)) + + +def main(_): + print("Sparse Xent vs. SparseToDense + Xent") + print("batch \t depth \t gpu \t dt(dense) \t dt(sparse) " + "\t dt(sparse)/dt(dense)") + for use_gpu in (False, True): + for batch_size in (32, 64, 128): + for num_entries in (100, 1000, 10000): + sparse_vs_dense_xent_benchmark( + batch_size, num_entries, use_gpu) + sparse_vs_dense_xent_benchmark( + 32, 100000, use_gpu) + sparse_vs_dense_xent_benchmark( + 8, 1000000, use_gpu) + + +if __name__ == "__main__": + if "--benchmarks" in sys.argv: + sys.argv.remove("--benchmarks") + tf.app.run() + else: + tf.test.main() diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index bc8c8b07eb..196d2d0931 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -416,23 +416,20 @@ class TensorArrayTest(tf.test.TestCase): dtype=tf.float32, tensor_array_name="foo", size=3) time_0 = tf.identity(0) - def body(time, flow, state): + def body(time, h_t, state): sliced = tf.slice(v0, begin=tf.pack([time, 0]), size=[1, -1]) sliced = tf.squeeze(sliced) out = sliced + var + state state += sliced - h_n = h - h_n._flow = flow - h_n = h_n.write(time, out) - return (time+1, h_n.flow, state) + h_t = h_t.write(time, out) + return (time+1, h_t, state) - (unused_0, final_flow, unused_2) = control_flow_ops.While( + (unused_0, h_final, unused_2) = control_flow_ops.While( cond=lambda time, unused_1, unused_2: time < 3, body=body, - loop_vars=(time_0, h.flow, state0), + loop_vars=(time_0, h, state0), parallel_iterations=3) - h._flow = final_flow - vout = h.pack() + vout = h_final.pack() grad_val = -np.arange(3*5, dtype=np.float32).reshape(3, 5) v0_grad = tf.gradients([vout], [v0], [grad_val])[0] diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index d9caa55d5a..51455a9ee7 100644 --- 
a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -68,6 +68,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import six from six.moves import xrange # pylint: disable=redefined-builtin @@ -81,6 +82,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_ops # pylint: disable=wildcard-import,undefined-variable from tensorflow.python.ops.gen_control_flow_ops import * from tensorflow.python.platform import logging @@ -283,6 +285,25 @@ def _SwitchRefOrTensor(data, pred, name="Switch"): return switch(data, pred, name=name) +def _convert_tensorarrays_to_flows(tensors_or_tensor_arrays): + return [ta.flow if isinstance(ta, tensor_array_ops.TensorArray) + else ta + for ta in tensors_or_tensor_arrays] + + +def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows): + if len(tensors_or_tensorarrays) != len(tensors_or_flows): + raise ValueError( + "Lengths of original Tensor list and new list do not match: %d vs. %d" + % (len(tensors_or_tensorarrays), len(tensors_or_flows))) + return [ + tensor_array_ops.TensorArray( + dtype=ta.dtype, handle=ta.handle, flow=t_or_flow) + if isinstance(ta, tensor_array_ops.TensorArray) + else t_or_flow + for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)] + + class ControlFlowOpWrapper(object): """A wrapper class for Operation. @@ -1402,6 +1423,10 @@ class WhileContext(ControlFlowContext): def BuildLoop(self, pred, body, loop_vars): """Add the loop termination condition and body to the graph.""" + # Keep original_loop_vars to identify which are TensorArrays + original_loop_vars = loop_vars + # Convert TensorArrays to their flow variables + loop_vars = _convert_tensorarrays_to_flows(loop_vars) loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) # Let the context know the loop variables so the loop variables # would be added in the outer contexts properly. @@ -1421,18 +1446,28 @@ class WhileContext(ControlFlowContext): self._pivot_for_pred = merge_vars[0] # Build the graph for pred. - c = ops.convert_to_tensor(pred(*merge_vars)) + merge_vars_with_tensor_arrays = ( + _convert_flows_to_tensorarrays(original_loop_vars, merge_vars)) + c = ops.convert_to_tensor(pred(*merge_vars_with_tensor_arrays)) self._pivot = loop_cond(c, name="LoopCond") switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] # Build the graph for body. vars_for_body = [_Identity(x[1]) for x in switch_vars] self._pivot_for_body = vars_for_body[0] + # Convert TensorArray flow variables inside the context back into + # their associated TensorArrays for calling the body.
+ vars_for_body_with_tensor_arrays = ( + _convert_flows_to_tensorarrays(original_loop_vars, vars_for_body)) - body_result = body(*vars_for_body) - if not isinstance(body_result, (list, _basetuple)): + body_result = body(*vars_for_body_with_tensor_arrays) + if not isinstance(body_result, collections.Sequence): body_result = [body_result] - result = ops.convert_n_to_tensor_or_indexed_slices(body_result) + # Store body_result to keep track of TensorArrays returned by body + original_body_result = body_result + # Convert TensorArrays returned by body into their flow variables + result = _convert_tensorarrays_to_flows(body_result) + result = ops.convert_n_to_tensor_or_indexed_slices(result) next_vars = [_NextIteration(x) for x in result] # Add the back edges to complete the loop. @@ -1450,7 +1485,14 @@ class WhileContext(ControlFlowContext): # Exit the loop. self.ExitResult(exit_vars) - return exit_vars[0] if len(exit_vars) == 1 else exit_vars + + # Convert TensorArray flow variables outside the context back into + # their associated TensorArrays for returning to caller. + exit_vars_with_tensor_arrays = ( + _convert_flows_to_tensorarrays(original_body_result, exit_vars)) + return (exit_vars_with_tensor_arrays[0] + if len(exit_vars) == 1 + else exit_vars_with_tensor_arrays) def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True, @@ -1462,6 +1504,10 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True, tensors of the same length and with the same types as the input. `loop_vars` is a list of tensors that is passed to both `cond` and `body`. + In addition to regular Tensors or IndexedSlices, the body may accept and + return TensorArray objects. The flows of the TensorArray objects will + be appropriately forwarded between loops and during gradient calculations. + While `cond` evaluates to true, `body` is executed. Args: diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index c3d966527e..2dc3ab5095 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -151,6 +151,7 @@ TensorFlow provides several operations that help you perform classification. @@sigmoid_cross_entropy_with_logits @@softmax @@softmax_cross_entropy_with_logits +@@sparse_softmax_cross_entropy_with_logits ## Embeddings diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 0a7e3b214a..72ec25b96c 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -158,6 +158,14 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _): return _BroadcastMul(grad_0, op.outputs[1]), None +@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits") +def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _): + # grad_0 is the backprop for cost, and we multiply it with the gradients + # (which is output[1]) + # There is no gradient for the labels + return _BroadcastMul(grad_0, op.outputs[1]), None + + @ops.RegisterGradient("Conv2D") def _Conv2DGrad(op, grad): return [nn_ops.conv2d_backprop_input(array_ops.shape(op.inputs[0]), diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index bdc32360a3..6c0cd339e6 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -157,6 +157,12 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None): example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both. + **NOTE:**: While the classes are mutually exclusive, their probabilities + need not be. 
All that is required is that each row of `labels` is + a valid probability distribution. If using exclusive `labels` + (wherein one and only one class is true at a time), see + `sparse_softmax_cross_entropy_with_logits`. + + **WARNING:** This op expects unscaled logits, since it performs a `softmax` on `logits` internally for efficiency. Do not call this op with the output of `softmax`, as it will produce incorrect results. @@ -180,6 +186,57 @@ return cost +def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None): + """Computes sparse softmax cross entropy between `logits` and `labels`. + + Measures the probability error in discrete classification tasks in which the + classes are mutually exclusive (each entry is in exactly one class). For + example, each CIFAR-10 image is labeled with one and only one label: an image + can be a dog or a truck, but not both. + + **NOTE:** For this operation, the probability of a given label is considered + exclusive. That is, soft classes are not allowed, and the `labels` vector + must provide a single specific index for the true class for each row of + `logits` (each minibatch entry). For soft softmax classification with + a probability distribution for each entry, see + `softmax_cross_entropy_with_logits`. + + **WARNING:** This op expects unscaled logits, since it performs a `softmax` + on `logits` internally for efficiency. Do not call this op with the + output of `softmax`, as it will produce incorrect results. + + `logits` must have the shape `[batch_size, num_classes]` + and the dtype (either `float32` or `float64`). + + `labels` must have the shape `[batch_size]` and the dtype `int64`. + + Args: + logits: Unscaled log probabilities. + labels: Each entry `labels[i]` must be an index in `[0, num_classes)`. + name: A name for the operation (optional). + + Returns: + A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the + softmax cross entropy loss. + """ + # The second output tensor contains the gradients. We use it in + # _SparseSoftmaxCrossEntropyWithLogitsGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._sparse_softmax_cross_entropy_with_logits( + logits, labels, name=name) + return cost + + +@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits") +def _SparseSoftmaxCrossEntropyWithLogitsShape(op): + """Shape function for SparseSoftmaxCrossEntropyWithLogits op.""" + logits_shape = op.inputs[0].get_shape() + input_shape = logits_shape.with_rank(2) + batch_size = input_shape[0] + # labels_shape + op.inputs[1].get_shape().merge_with(tensor_shape.vector(batch_size)) + return [tensor_shape.vector(batch_size.value), input_shape] + + @ops.RegisterShape("SoftmaxCrossEntropyWithLogits") def _SoftmaxCrossEntropyWithLogitsShape(op): """Shape function for SoftmaxCrossEntropyWithLogits op.""" diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index 28f33b1e1a..820af12626 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -210,24 +210,42 @@ class BasicLSTMCell(RNNCell): return new_h, array_ops.concat(1, [new_c, new_h]) -def _get_sharded_variable(name, shape, initializer, dtype, num_shards): - """Get a list of sharded variables with the given dtype and initializer.""" - unit_shard_size = int(math.ceil(shape[1] / num_shards)) +def _get_concat_variable(name, shape, dtype, num_shards): + """Get a sharded variable concatenated into one tensor.""" + sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) + if len(sharded_variable) == 1: + return sharded_variable[0] + + concat_name = name + "/concat" + concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" + for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): + if value.name == concat_full_name: + return value + + concat_variable = array_ops.concat(0, sharded_variable, name=concat_name) + ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, + concat_variable) + return concat_variable + + +def _get_sharded_variable(name, shape, dtype, num_shards): + """Get a list of sharded variables with the given dtype.""" + if num_shards > shape[0]: + raise ValueError("Too many shards: shape=%s, num_shards=%d" % + (shape, num_shards)) + unit_shard_size = int(math.floor(shape[0] / num_shards)) + remaining_rows = shape[0] - unit_shard_size * num_shards shards = [] for i in range(num_shards): - current_size = min(unit_shard_size, shape[1] - unit_shard_size * i) - shards.append(vs.get_variable(name + "_%d" % i, [shape[0], current_size], - initializer=initializer, dtype=dtype)) + current_size = unit_shard_size + if i < remaining_rows: + current_size += 1 + shards.append(vs.get_variable(name + "_%d" % i, [current_size, shape[1]], + dtype=dtype)) return shards -def _matmul_with_sharded_variable(tensor, sharded_tensor): - """Multiply tensor with each tensor in sharded_tensor, column-concatenated.""" - return array_ops.concat(1, [math_ops.matmul(tensor, shard) - for shard in sharded_tensor]) - - class LSTMCell(RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. 
@@ -317,10 +335,11 @@ class LSTMCell(RNNCell): dtype = input_.dtype - with vs.variable_scope(scope or type(self).__name__): # "LSTMCell" - sharded_w = _get_sharded_variable( + with vs.variable_scope(scope or type(self).__name__, + initializer=self._initializer): # "LSTMCell" + concat_w = _get_concat_variable( "W", [self.input_size + num_proj, 4 * self._num_units], - self._initializer, dtype, self._num_unit_shards) + dtype, self._num_unit_shards) b = vs.get_variable( "B", shape=[4 * self._num_units], @@ -328,24 +347,17 @@ class LSTMCell(RNNCell): # i = input_gate, j = new_input, f = forget_gate, o = output_gate cell_inputs = array_ops.concat(1, [input_, m_prev]) - lstm_matrix = nn_ops.bias_add( - _matmul_with_sharded_variable(cell_inputs, sharded_w), b) + lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) i, j, f, o = array_ops.split(1, 4, lstm_matrix) # Diagonal connections if self._use_peepholes: w_f_diag = vs.get_variable( - "W_F_diag", shape=[self._num_units], - initializer=self._initializer, - dtype=dtype) + "W_F_diag", shape=[self._num_units], dtype=dtype) w_i_diag = vs.get_variable( - "W_I_diag", shape=[self._num_units], - initializer=self._initializer, - dtype=dtype) + "W_I_diag", shape=[self._num_units], dtype=dtype) w_o_diag = vs.get_variable( - "W_O_diag", shape=[self._num_units], - initializer=self._initializer, - dtype=dtype) + "W_O_diag", shape=[self._num_units], dtype=dtype) if self._use_peepholes: c = (sigmoid(f + 1 + w_f_diag * c_prev) * c_prev + @@ -362,11 +374,11 @@ class LSTMCell(RNNCell): m = sigmoid(o) * tanh(c) if self._num_proj is not None: - sharded_w_proj = _get_sharded_variable( - "W_P", [self._num_units, self._num_proj], self._initializer, + concat_w_proj = _get_concat_variable( + "W_P", [self._num_units, self._num_proj], dtype, self._num_proj_shards) - m = _matmul_with_sharded_variable(m, sharded_w_proj) + m = math_ops.matmul(m, concat_w_proj) return m, array_ops.concat(1, [c, m]) diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py index 23eb02c39c..05cb4f12ad 100644 --- a/tensorflow/python/ops/tensor_array_grad.py +++ b/tensorflow/python/ops/tensor_array_grad.py @@ -102,9 +102,8 @@ def _TensorArrayWriteGrad(op, flow): dtype = op.get_attr("T") grad_source = _GetGradSource(flow) g = tensor_array_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad( - source=grad_source) - with ops.control_dependencies([flow]): - grad = g.read(index) + source=grad_source, flow=flow) + grad = g.read(index) return [None, None, grad, flow] @@ -144,8 +143,7 @@ def _TensorArrayUnpackGrad(op, flow): dtype = op.get_attr("T") grad_source = _GetGradSource(flow) g = tensor_array_ops.TensorArray(size=None, dtype=dtype, handle=handle).grad( - source=grad_source) - with ops.control_dependencies([flow]): - grad = g.pack() + source=grad_source, flow=flow) + grad = g.pack() return [None, grad, flow] # pylint: enable=protected-access diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index c12506694f..901dfbe913 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -46,8 +46,8 @@ class TensorArray(object): @@grad """ - def __init__( - self, dtype, size=None, tensor_array_name=None, handle=None, name=None): + def __init__(self, dtype, size=None, tensor_array_name=None, + handle=None, flow=None, name=None): """Construct a new TensorArray or wrap an existing TensorArray handle. 
Args: @@ -59,6 +59,8 @@ class TensorArray(object): set, handle should be None. handle: (optional) A `Tensor` handle to an existing TensorArray. If this is set, tensor_array_name should be None. + flow: (optional) A float `Tensor` scalar coming from an existing + TensorArray.flow. name: A name for the operation (optional). Raises: @@ -73,16 +75,15 @@ class TensorArray(object): if handle is None and size is None: raise ValueError("Size must be provided if handle is not provided") - with ops.op_scope([handle, size], name, "TensorArray") as scope: + self._dtype = dtype + with ops.op_scope([handle, size, flow], name, "TensorArray") as scope: if handle: self._handle = handle else: self._handle = gen_data_flow_ops._tensor_array( dtype=dtype, size=size, tensor_array_name=tensor_array_name, name=scope) - - self._flow = constant_op.constant(0, dtype=_dtypes.float32) - self._dtype = dtype + self._flow = flow or constant_op.constant(0, dtype=_dtypes.float32) @property def flow(self): @@ -90,14 +91,19 @@ class TensorArray(object): return self._flow @property + def dtype(self): + """The data type of this TensorArray.""" + return self._dtype + + @property def handle(self): """The reference to the TensorArray.""" return self._handle - def grad(self, source): + def grad(self, source, flow=None): g_handle = gen_data_flow_ops._tensor_array_grad( handle=self._handle, source=source) - g = TensorArray(dtype=self._dtype, size=None, handle=g_handle) + g = TensorArray(dtype=self._dtype, size=None, handle=g_handle, flow=flow) return g def read(self, index, name=None): diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index 600f083840..8a7d0925ce 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -35,8 +35,13 @@ namespace perftools { namespace gputools { namespace internal { +// TensorFlow OSS configure uses the following lines to configure versions. For +// any modifications of the format, please make sure the script still works. +string GetCudaVersion() { return "7.0"; } +string GetCudnnVersion() { return "6.5"; } + /* static */ port::Status DsoLoader::GetCublasDsoHandle(void** dso_handle) { - return GetDsoHandle(FindDsoPath("libcublas.so.7.0", + return GetDsoHandle(FindDsoPath("libcublas.so." + GetCudaVersion(), "third_party/gpus/cuda/lib64"), dso_handle); } @@ -46,18 +51,19 @@ namespace internal { // different version number than other CUDA libraries. See b/22397368 for // some details about the complications surrounding this. return GetDsoHandle( - FindDsoPath("libcudnn.so.6.5", "third_party/gpus/cuda/lib64"), + FindDsoPath("libcudnn.so." + GetCudnnVersion(), + "third_party/gpus/cuda/lib64"), dso_handle); } /* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) { - return GetDsoHandle(FindDsoPath("libcufft.so.7.0", + return GetDsoHandle(FindDsoPath("libcufft.so." + GetCudaVersion(), "third_party/gpus/cuda/lib64"), dso_handle); } /* static */ port::Status DsoLoader::GetCurandDsoHandle(void** dso_handle) { - return GetDsoHandle(FindDsoPath("libcurand.so.7.0", + return GetDsoHandle(FindDsoPath("libcurand.so." + GetCudaVersion(), "third_party/gpus/cuda/lib64"), dso_handle); } @@ -70,7 +76,7 @@ namespace internal { /* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) { return GetDsoHandle( - FindDsoPath("libcupti.so.7.0", + FindDsoPath("libcupti.so." 
+ GetCudaVersion(), "third_party/gpus/cuda/extras/CUPTI/lib64"), dso_handle); } @@ -92,8 +98,6 @@ namespace internal { if (*dso_handle == nullptr) { LOG(INFO) << "Couldn't open CUDA library " << path << ". LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH"); - // TODO(b/22689637): Eliminate unnecessary ToString once StrCat has been - // moved to the open-sourceable version. return port::Status( port::error::FAILED_PRECONDITION, port::StrCat("could not dlopen DSO: ", path, "; dlerror: ", dlerror())); diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD index bd4229b5f4..52518130f6 100644 --- a/tensorflow/tensorboard/BUILD +++ b/tensorflow/tensorboard/BUILD @@ -15,7 +15,7 @@ filegroup( py_library( name = "tensorboard_handler", - srcs = ["tensorboard_handler.py"], + srcs = ["backend/tensorboard_handler.py"], deps = [ ":float_wrapper", "//tensorflow/python:platform", @@ -26,14 +26,14 @@ py_library( py_library( name = "float_wrapper", - srcs = ["float_wrapper.py"], + srcs = ["backend/float_wrapper.py"], srcs_version = "PY2AND3", ) py_test( name = "float_wrapper_test", size = "small", - srcs = ["float_wrapper_test.py"], + srcs = ["backend/float_wrapper_test.py"], deps = [ ":float_wrapper", "//tensorflow/python:platform_test", @@ -43,7 +43,7 @@ py_test( py_binary( name = "tensorboard", - srcs = ["tensorboard.py"], + srcs = ["backend/tensorboard.py"], data = [":tensorboard_frontend"], deps = [ ":tensorboard_handler", diff --git a/tensorflow/tensorboard/backend/__init__.py b/tensorflow/tensorboard/backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/tensorboard/backend/__init__.py diff --git a/tensorflow/tensorboard/float_wrapper.py b/tensorflow/tensorboard/backend/float_wrapper.py index 5d5f0ea6a1..5d5f0ea6a1 100644 --- a/tensorflow/tensorboard/float_wrapper.py +++ b/tensorflow/tensorboard/backend/float_wrapper.py diff --git a/tensorflow/tensorboard/float_wrapper_test.py b/tensorflow/tensorboard/backend/float_wrapper_test.py index 7f731b911b..ae29880285 100644 --- a/tensorflow/tensorboard/float_wrapper_test.py +++ b/tensorflow/tensorboard/backend/float_wrapper_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import tensorflow.python.platform from tensorflow.python.platform import googletest -from tensorflow.tensorboard import float_wrapper +from tensorflow.tensorboard.backend import float_wrapper _INFINITY = float('inf') diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/backend/tensorboard.py index bbb98ff4aa..30a31c6468 100644 --- a/tensorflow/tensorboard/tensorboard.py +++ b/tensorflow/tensorboard/backend/tensorboard.py @@ -38,7 +38,7 @@ from tensorflow.python.platform import resource_loader from tensorflow.python.platform import status_bar from tensorflow.python.summary import event_accumulator from tensorflow.python.summary import event_multiplexer -from tensorflow.tensorboard import tensorboard_handler +from tensorflow.tensorboard.backend import tensorboard_handler flags.DEFINE_string('logdir', None, """logdir specifies the directory where TensorBoard will look to find TensorFlow event files that it can display. 
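Looking back at the control_flow_ops.While and TensorArray changes earlier in this diff: a TensorArray can now be passed directly through loop_vars, with While swapping it for its flow tensor on entry and rebuilding a TensorArray around the exit flow, as the rewritten tensor_array_ops_test demonstrates. A rough sketch of the pattern, written against the internal modules the way that test is; the accumulated values and the name "acc" are illustrative only.

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import tensor_array_ops

    acc = tensor_array_ops.TensorArray(dtype=tf.float32, size=3,
                                       tensor_array_name="acc")
    time_0 = tf.identity(0)

    def body(time, ta):
      # write() returns a TensorArray whose flow tensor While threads on
      # to the next iteration.
      ta = ta.write(time, tf.cast(time, tf.float32) * 2.0)
      return (time + 1, ta)

    # The TensorArray goes straight into loop_vars; cond and body both
    # receive a TensorArray rather than its flow tensor.
    _, acc_final = control_flow_ops.While(
        cond=lambda time, unused_ta: time < 3,
        body=body,
        loop_vars=(time_0, acc))

    result = acc_final.pack()  # a [3] tensor; evaluates to [0., 2., 4.]

Because the flow is threaded through the loop, gradients with respect to values written inside the body propagate as the gradient check in the updated tensor_array_ops_test exercises.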
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/backend/tensorboard_handler.py index c5ab674121..46c29e3d9b 100644 --- a/tensorflow/tensorboard/tensorboard_handler.py +++ b/tensorflow/tensorboard/backend/tensorboard_handler.py @@ -45,7 +45,7 @@ from tensorflow.python.platform import logging from tensorflow.python.platform import resource_loader from tensorflow.python.summary import event_accumulator from tensorflow.python.util import compat -from tensorflow.tensorboard import float_wrapper +from tensorflow.tensorboard.backend import float_wrapper DATA_PREFIX = '/data' diff --git a/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html b/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html index 10f03b0006..b782e3bc65 100644 --- a/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html +++ b/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-loader.html @@ -15,6 +15,7 @@ future for loading older images. img { width: 100%; height: 100%; + image-rendering: pixelated; } </style> <template> diff --git a/third_party/gpus/cuda/BUILD b/third_party/gpus/cuda/BUILD index 690728c5d2..47adb1d9d0 100644 --- a/third_party/gpus/cuda/BUILD +++ b/third_party/gpus/cuda/BUILD @@ -1,6 +1,10 @@ licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like load("/tensorflow/tensorflow", "if_cuda") +load("//tensorflow/core:platform/default/build_config.bzl", + "tf_get_cuda_version", + "tf_get_cudnn_version", + ) package(default_visibility = ["//visibility:public"]) @@ -53,10 +57,10 @@ cc_library( cc_library( name = "cudart", srcs = [ - "lib64/libcudart.so.7.0", + "lib64/libcudart.so." + tf_get_cuda_version(), ], data = [ - "lib64/libcudart.so.7.0", + "lib64/libcudart.so." + tf_get_cuda_version(), ], includes = ["include/"], visibility = ["//visibility:public"], @@ -66,10 +70,10 @@ cc_library( cc_library( name = "cublas", srcs = [ - "lib64/libcublas.so.7.0", + "lib64/libcublas.so." + tf_get_cuda_version(), ], data = [ - "lib64/libcublas.so.7.0", + "lib64/libcublas.so." + tf_get_cuda_version(), ], includes = ["include/"], visibility = ["//visibility:public"], @@ -79,10 +83,10 @@ cc_library( cc_library( name = "cudnn", srcs = [ - "lib64/libcudnn.so.6.5", + "lib64/libcudnn.so." + tf_get_cudnn_version(), ], data = [ - "lib64/libcudnn.so.6.5", + "lib64/libcudnn.so." + tf_get_cudnn_version(), ], includes = ["include/"], visibility = ["//visibility:public"], @@ -92,10 +96,10 @@ cc_library( cc_library( name = "cufft", srcs = [ - "lib64/libcufft.so.7.0", + "lib64/libcufft.so." + tf_get_cuda_version(), ], data = [ - "lib64/libcufft.so.7.0", + "lib64/libcufft.so." + tf_get_cuda_version(), ], includes = ["include/"], visibility = ["//visibility:public"], @@ -130,10 +134,10 @@ genrule( "include/cublas.h", "include/cudnn.h", "lib64/libcudart_static.a", - "lib64/libcublas.so.7.0", - "lib64/libcudnn.so.6.5", - "lib64/libcudart.so.7.0", - "lib64/libcufft.so.7.0", + "lib64/libcublas.so." + tf_get_cuda_version(), + "lib64/libcudnn.so." + tf_get_cudnn_version(), + "lib64/libcudart.so." + tf_get_cuda_version(), + "lib64/libcufft.so." 
+ tf_get_cuda_version(), ], cmd = if_cuda( # Under cuda config, create all the symbolic links to the actual cuda files @@ -147,10 +151,10 @@ genrule( "touch $(@D)/include/cublas.h", "touch $(@D)/include/cudnn.h", "touch $(@D)/lib64/libcudart_static.a", - "touch $(@D)/lib64/libcublas.so.7.0", - "touch $(@D)/lib64/libcudnn.so.6.5", - "touch $(@D)/lib64/libcudart.so.7.0", - "touch $(@D)/lib64/libcufft.so.7.0" + "touch $(@D)/lib64/libcublas.so." + tf_get_cuda_version(), + "touch $(@D)/lib64/libcudnn.so." + tf_get_cudnn_version(), + "touch $(@D)/lib64/libcudart.so." + tf_get_cuda_version(), + "touch $(@D)/lib64/libcufft.so." + tf_get_cuda_version(), ]), ), local = 1, diff --git a/third_party/gpus/cuda/cuda_config.sh b/third_party/gpus/cuda/cuda_config.sh index 44e07433a0..87c35349c0 100755 --- a/third_party/gpus/cuda/cuda_config.sh +++ b/third_party/gpus/cuda/cuda_config.sh @@ -16,7 +16,7 @@ # A simple script to configure the Cuda tree needed for the TensorFlow GPU -# build. We need both Cuda toolkit 7.0 and Cudnn 6.5. +# build. We need both Cuda toolkit $TF_CUDA_VERSION and Cudnn $TF_CUDNN_VERSION. # Useage: # * User edit cuda.config to point both Cuda toolkit and Cudnn libraries to their local path # * run cuda_config.sh to generate symbolic links in the source tree to reflect @@ -62,8 +62,8 @@ function CudaError { cat << EOF ############################################################################## ############################################################################## -Cuda 7.0 toolkit is missing. -1. Download and install the CUDA 7.0 toolkit and CUDNN 6.5 library; +Cuda $TF_CUDA_VERSION toolkit is missing. +1. Download and install the CUDA $TF_CUDA_VERSION toolkit and CUDNN $TF_CUDNN_VERSION library; 2. Run configure from the root of the source tree, before rerunning bazel; Please refer to README.md for more details. ############################################################################## @@ -78,8 +78,8 @@ function CudnnError { cat << EOF ############################################################################## ############################################################################## -Cudnn 6.5 is missing. -1. Download and install the CUDA 7.0 toolkit and CUDNN 6.5 library; +Cudnn $TF_CUDNN_VERSION is missing. +1. Download and install the CUDA $TF_CUDA_VERSION toolkit and CUDNN $TF_CUDNN_VERSION library; 2. Run configure from the root of the source tree, before rerunning bazel; Please refer to README.md for more details. ############################################################################## @@ -110,18 +110,18 @@ if [ "$CHECK_ONLY" == "1" ]; then CheckAndLinkToSrcTree CudaError include/cublas.h CheckAndLinkToSrcTree CudnnError include/cudnn.h CheckAndLinkToSrcTree CudaError lib64/libcudart_static.a - CheckAndLinkToSrcTree CudaError lib64/libcublas.so.7.0 - CheckAndLinkToSrcTree CudnnError lib64/libcudnn.so.6.5 - CheckAndLinkToSrcTree CudaError lib64/libcudart.so.7.0 - CheckAndLinkToSrcTree CudaError lib64/libcufft.so.7.0 + CheckAndLinkToSrcTree CudaError lib64/libcublas.so.$TF_CUDA_VERSION + CheckAndLinkToSrcTree CudnnError lib64/libcudnn.so.$TF_CUDNN_VERSION + CheckAndLinkToSrcTree CudaError lib64/libcudart.so.$TF_CUDA_VERSION + CheckAndLinkToSrcTree CudaError lib64/libcufft.so.$TF_CUDA_VERSION exit 0 fi # Actually configure the source tree for TensorFlow's canonical view of Cuda # libraries. -if test ! -e ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.7.0; then - CudaError "cannot find ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.7.0" +if test ! 
-e ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.$TF_CUDA_VERSION; then + CudaError "cannot find ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so.$TF_CUDA_VERSION" fi if test ! -d ${CUDNN_INSTALL_PATH}; then @@ -137,13 +137,13 @@ else CudnnError "cannot find cudnn.h under: ${CUDNN_INSTALL_PATH}" fi -# Locate libcudnn.so.6.5 -if test -e ${CUDNN_INSTALL_PATH}/libcudnn.so.6.5; then +# Locate libcudnn.so.${$TF_CUDNN_VERSION} +if test -e ${CUDNN_INSTALL_PATH}/libcudnn.so.$TF_CUDNN_VERSION; then CUDNN_LIB_PATH=${CUDNN_INSTALL_PATH} -elif test -e ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so.6.5; then +elif test -e ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so.$TF_CUDNN_VERSION; then CUDNN_LIB_PATH=${CUDNN_INSTALL_PATH}/lib64 else - CudnnError "cannot find libcudnn.so.6.5 under: ${CUDNN_INSTALL_PATH}" + CudnnError "cannot find libcudnn.so.$TF_CUDNN_VERSION under: ${CUDNN_INSTALL_PATH}" fi # Helper function to build symbolic links for all files under a directory. @@ -182,4 +182,4 @@ LinkAllFiles ${CUDA_TOOLKIT_PATH}/nvvm $OUTPUTDIR/third_party/gpus/cuda/nvvm || # Set up symbolic link for cudnn ln -sf $CUDNN_HEADER_PATH/cudnn.h $OUTPUTDIR/third_party/gpus/cuda/include/cudnn.h || exit -1 -ln -sf $CUDNN_LIB_PATH/libcudnn.so.6.5 $OUTPUTDIR/third_party/gpus/cuda/lib64/libcudnn.so.6.5 || exit -1 +ln -sf $CUDNN_LIB_PATH/libcudnn.so.$TF_CUDNN_VERSION $OUTPUTDIR/third_party/gpus/cuda/lib64/libcudnn.so.$TF_CUDNN_VERSION || exit -1 |
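Circling back to the sparse_softmax_cross_entropy_with_logits op added in nn_ops.py and exercised in sparse_xent_op_test.py above: it takes one int64 class index per row instead of a dense target distribution, so callers can drop the sparse_to_dense step that the test's benchmark uses as a baseline. A small sketch using the positional argument order shown in the docstring; the numbers are the ones the test itself checks.

    import tensorflow as tf

    logits = tf.constant([[1., 1., 1., 1.],
                          [1., 2., 3., 4.]], dtype=tf.float32)
    # One class index per row: shape [batch_size], dtype int64.
    labels = tf.constant([3, 0], dtype=tf.int64)

    # Sparse form: no one-hot target matrix is ever built.
    sparse_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

    # Dense equivalent, for comparison with the benchmark's baseline.
    dense_labels = tf.constant([[0., 0., 0., 1.],
                                [1., 0., 0., 0.]], dtype=tf.float32)
    dense_loss = tf.nn.softmax_cross_entropy_with_logits(logits, dense_labels)

    with tf.Session() as sess:
      print(sess.run(sparse_loss))  # approx. [1.3862, 3.4420]
      print(sess.run(dense_loss))   # same values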