Diffstat (limited to 'tensorflow/core')
71 files changed, 1457 insertions, 510 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b6b48a077c..ef8c3f358a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -89,6 +89,7 @@ load( "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", + "tf_features_nomodules_if_android", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -998,6 +999,7 @@ tf_gen_op_libs( "nn_ops", "no_op", "parsing_ops", + "random_grad", "random_ops", "remote_fused_graph_ops", "resource_variable_ops", @@ -2339,6 +2341,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests + "framework/resource_var.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", @@ -3369,6 +3372,7 @@ tf_cc_tests( "framework/bfloat16_test.cc", "framework/cancellation_test.cc", "framework/common_shape_fns_test.cc", + "framework/device_base_test.cc", "framework/function_test.cc", "framework/graph_def_util_test.cc", "framework/graph_to_functiondef_test.cc", diff --git a/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt b/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt new file mode 100644 index 0000000000..747a8badfd --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "IgammaGradA" + visibility: HIDDEN + summary: "Computes the gradient of `igamma(a, x)` wrt `a`." +} diff --git a/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt new file mode 100644 index 0000000000..d2bd76f8b9 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "RandomGammaGrad" + visibility: HIDDEN + summary: "Computes the derivative of a Gamma random sample w.r.t. `alpha`." 
+} diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index d66963ec74..b4bf1c408f 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -74,6 +74,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { options.config.mutable_graph_options() ->mutable_rewrite_options() ->set_constant_folding(RewriterConfig::OFF); + options.config.mutable_graph_options() + ->mutable_rewrite_options() + ->set_min_graph_nodes(-1); std::unique_ptr<Session> session(NewSession(options)); TF_ASSERT_OK(session->Create(def)); std::vector<std::pair<string, Tensor>> inputs; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index b5120f2872..7f28f3b793 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -22,14 +22,19 @@ tf_cuda_library( "eager_executor.h", ], visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], + }), ) tf_cuda_library( @@ -44,17 +49,23 @@ tf_cuda_library( deps = [ ":eager_executor", ":kernel_and_device", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_options", - "//tensorflow/core/distributed_runtime:worker_session", - "//tensorflow/core/distributed_runtime/eager:eager_client", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_session", + "//tensorflow/core/distributed_runtime/eager:eager_client", + ], + }), ) tf_cuda_library( @@ -86,14 +97,20 @@ tf_cuda_library( ":context", ":eager_executor", ":kernel_and_device", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_options", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + ], 
+ }), ) tf_cuda_library( @@ -106,14 +123,19 @@ tf_cuda_library( ":context", ":eager_executor", ":tensor_handle", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_options", - ], + ] + select({ + "//tensorflow:android": [ + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + ], + }), ) tf_cuda_library( @@ -125,14 +147,20 @@ tf_cuda_library( "kernel_and_device.h", ], visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + "//util/hash:farmhash_fingerprint", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], + }), ) tf_cc_test( @@ -168,14 +196,20 @@ cc_library( ":eager_operation", ":kernel_and_device", ":tensor_handle", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/eager:eager_client", - "//tensorflow/core/distributed_runtime/eager:remote_execute_node", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_execute_node", + ], + }), ) tf_cuda_library( @@ -183,13 +217,15 @@ tf_cuda_library( srcs = ["attr_builder.cc"], hdrs = ["attr_builder.h"], visibility = ["//tensorflow:internal"], - deps = select({ + deps = [ + ":kernel_and_device", + "//tensorflow/c:c_api", + ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", + "//util/hash:farmhash_fingerprint", ], "//conditions:default": [ - ":kernel_and_device", - "//tensorflow/c:c_api", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 8381cb58d2..8a87ba7a19 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -38,10 +38,11 @@ EagerContext::EagerContext(const SessionOptions& opts, InitDeviceMapAndAsync(); } +#ifndef __ANDROID__ EagerContext::EagerContext( const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous, - std::unique_ptr<GrpcServer> server, + std::unique_ptr<ServerInterface> server, std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DeviceMgr> 
remote_device_manager, const gtl::FlatMap<string, uint64>& remote_contexts) @@ -55,12 +56,13 @@ EagerContext::EagerContext( &func_lib_def_, {}, thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), async_default_(async), + remote_device_manager_(std::move(remote_device_manager)), server_(std::move(server)), remote_eager_workers_(std::move(remote_eager_workers)), - remote_device_manager_(std::move(remote_device_manager)), remote_contexts_(remote_contexts) { InitDeviceMapAndAsync(); } +#endif void EagerContext::InitDeviceMapAndAsync() { if (async_default_) { @@ -125,10 +127,11 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() { } EagerContext::~EagerContext() { +#ifndef __ANDROID__ if (server_) { // TODO(nareshmodi): Fix this. LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " - "GrpcServer doesn't support clean shutdown."; + "Servers don't support clean shutdown."; server_.release(); } @@ -158,6 +161,7 @@ EagerContext::~EagerContext() { } counter.Wait(); +#endif executor_.WaitForAllPendingNodes().IgnoreError(); ClearCaches(); @@ -224,6 +228,7 @@ Status GetTaskName(Device* d, string* task_name) { } } // namespace +#ifndef __ANDROID__ Status EagerContext::GetClientAndContextID(Device* device, eager::EagerClient** client, uint64* context_id) { @@ -253,5 +258,6 @@ Status EagerContext::GetClientAndContextID(Device* device, return Status::OK(); } +#endif } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 096ed3112e..601b9e4545 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -29,8 +29,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#ifndef __ANDROID__ #include "tensorflow/core/distributed_runtime/eager/eager_client.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#endif #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -75,21 +77,22 @@ class EagerContext { // workers. // // Additional remote-specific args are: - // - server: A GrpcServer that exports the tensorflow.WorkerService. Note - // that this class expects the server to already have been started. + // - server: A ServerInterface that exports the tensorflow.WorkerService. + // Note that this class expects the server to already have been started. // - remote_eager_workers: A cache from which we can get "EagerClient"s to // communicate with remote eager services. // - remote_device_mgr: A DeviceMgr* which contains all remote devices // (should contain no local devices). // - remote_contexts: A map containing task name to remote context ID. +#ifndef __ANDROID__ explicit EagerContext( const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous, - std::unique_ptr<GrpcServer> server, + std::unique_ptr<ServerInterface> server, std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DeviceMgr> remote_device_manager, const gtl::FlatMap<string, uint64>& remote_contexts); - +#endif ~EagerContext(); // Returns the function library runtime for the given device. 
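The hunks above show the pattern that recurs throughout this change: every distributed-runtime dependency of EagerContext (the ServerInterface, the EagerClient cache, the remote context map) is fenced behind #ifndef __ANDROID__, so mobile builds never see the types and can link the lightweight android_tensorflow_lib_lite target instead of gRPC. A minimal sketch of the idiom follows; ExampleContext and server_ are hypothetical names, not part of this commit:

    // Hypothetical sketch of the #ifndef __ANDROID__ guard idiom used above.
    #include <memory>
    #include <utility>

    #ifndef __ANDROID__
    #include "tensorflow/core/distributed_runtime/server_lib.h"  // ServerInterface
    #endif

    class ExampleContext {
     public:
    #ifndef __ANDROID__
      // Remote execution only exists when the distributed runtime is linked
      // in; Android builds compile this constructor out entirely.
      explicit ExampleContext(std::unique_ptr<tensorflow::ServerInterface> server)
          : server_(std::move(server)) {}
    #endif

     private:
    #ifndef __ANDROID__
      std::unique_ptr<tensorflow::ServerInterface> server_;
    #endif
    };

Keeping the header include, the member, and every accessor behind the same macro is what allows the BUILD files above to swap the heavyweight deps for android_tensorflow_lib_lite via select().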
@@ -174,9 +177,10 @@ class EagerContext { FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; } +#ifndef __ANDROID__ Status GetClientAndContextID(Device* device, eager::EagerClient** client, uint64* context_id); - +#endif private: void InitDeviceMapAndAsync(); @@ -228,16 +232,19 @@ class EagerContext { std::unordered_map<std::thread::id, bool> thread_local_async_ GUARDED_BY(async_map_mu_); + const std::unique_ptr<DeviceMgr> remote_device_manager_; + // The server_ is not const since we release it when the context is destroyed. // Therefore the server_ object is not marked as const (even though it should // be). - std::unique_ptr<GrpcServer> server_; +#ifndef __ANDROID__ + std::unique_ptr<ServerInterface> server_; const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; - const std::unique_ptr<DeviceMgr> remote_device_manager_; const gtl::FlatMap<string, uint64> remote_contexts_; gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>> device_to_client_cache_; +#endif }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index c619857b78..14aa520e19 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/execute_node.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#ifndef __ANDROID__ #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h" +#endif #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -39,6 +41,11 @@ namespace tensorflow { namespace { +// Copy of the definition in third_party/tensorflow/compiler/jit/defs.h +// Copied here because we don't currently compile XLA on windows. So, can't +// depend on it directly. +const char* const kXlaCompileAttr = "_XlaCompile"; + // Initializes the step stats if needed. void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) { // Lazily initialize the RunMetadata with information about all devices if @@ -472,6 +479,15 @@ Status EagerLocalExecute(EagerOperation* op, device == nullptr ? "unspecified" : device->name()); KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key); if (kernel == nullptr) { + // If we are running a function on explicitly requested TPU, + // compile it with XLA. + // Note that it is not ideal, but currently ok, to set this + // attribute after computing the kernel cache key above. 
+ if (op->is_function() && device != nullptr && + device->device_type() == "TPU") { + op->MutableAttrs()->Set(kXlaCompileAttr, true); + } + const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); if (device == nullptr) { status = SelectDevice(ndef, ctx, &device); @@ -559,9 +575,19 @@ Status EagerLocalExecute(EagerOperation* op, return status; } -Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client, - uint64 context_id, TensorHandle** retvals, +Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, int* num_retvals) { +#ifdef __ANDROID__ + return errors::Unimplemented( + "Eager's remote execution is not available on Android devices."); +#else + EagerContext* ctx = op->EagerContext(); + + eager::EagerClient* eager_client; + uint64 context_id; + TF_RETURN_IF_ERROR( + ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id)); + eager::EnqueueRequest request; eager::EnqueueResponse response; @@ -622,7 +648,6 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client, } tensorflow::Device* op_device = op->Device(); - EagerContext* ctx = op->EagerContext(); const tensorflow::uint64 id = remote_op->id(); for (int i = 0; i < *num_retvals; i++) { @@ -657,6 +682,7 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client, } return Status::OK(); +#endif } } // namespace @@ -669,15 +695,7 @@ Status EagerExecute(EagerOperation* op, return EagerLocalExecute(op, retvals, num_retvals); } - auto* ctx = op->EagerContext(); - - tensorflow::eager::EagerClient* eager_client; - tensorflow::uint64 context_id; - TF_RETURN_IF_ERROR( - ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id)); - - return EagerRemoteExecute(op, eager_client, context_id, retvals->data(), - num_retvals); + return EagerRemoteExecute(op, retvals->data(), num_retvals); } Status EagerExecute(EagerContext* ctx, Device* device, diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 68d37ddbcd..6d8cea8297 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -1188,11 +1188,13 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) { return true; } -// Given a "caller" in "graph", which is a function call of a function +// Given a "caller" in graph "g", which is a function call of a function // to "fbody". Replaces the "caller" with fbody->graph and connects -// edges properly. +// edges properly. "override_device" specifies whether inlining should replace +// explicitly specified devices inside fbody with the callee's device. void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody) { + Node* caller, const FunctionBody* fbody, + bool override_device) { if (!ValidateInlining(caller, fbody)) { LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. 
" << DebugString(fbody->graph); @@ -1227,7 +1229,9 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, for (Node* n : fbody->graph->op_nodes()) { NodeDef ndef = n->def(); ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name())); - ndef.set_device(caller->def().device()); + if (override_device || ndef.device().empty()) { + ndef.set_device(caller->def().device()); + } Node* clone = g->AddNode(ndef, &s); TF_CHECK_OK(s); node_map[n->id()] = clone; @@ -1581,6 +1585,12 @@ FunctionBody* SymbolicGradientHelper::Compute() { g->RemoveNode(n); } gbody_->ret_types = fbody_->arg_types; + // TODO(apassos): use the right dtype for gradients of resource variables + for (int i = 0; i < gbody_->ret_types.size(); ++i) { + if (gbody_->ret_types[i] == DT_RESOURCE) { + gbody_->ret_types[i] = DT_FLOAT; + } + } gbody_->ret_nodes.clear(); // Add new return nodes to the function gradient body for each node // in 'x_grad_nodes'. diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index a0f9fcae0a..a274f1ef51 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -155,9 +155,11 @@ FunctionBody* SymbolicGradient(const FunctionBody& f); // Given a "caller" in graph "g", which is a function call of a function // to "fbody". Replaces the "caller" with fbody->graph and connects -// edges properly. +// edges properly. "override_device" specifies whether inlining should replace +// explicitly specified devices inside fbody with the callee's device. void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody); + Node* caller, const FunctionBody* fbody, + bool override_device = true); // Instantiates FunctionDef into a graph. Set *fbody to point to the // FunctionBody that holds the instantiated FunctionDef. diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index eb710bdbc5..58018689d5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -43,7 +43,6 @@ limitations under the License. 
#include "tensorflow/core/util/util.h" #ifndef IS_MOBILE_PLATFORM -#include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" @@ -476,21 +475,15 @@ Status GraphExecutionState::OptimizeGraph( } } - std::unordered_map<string, DeviceProperties> device_map; Device* cpu_device = nullptr; for (const auto& device : device_set_->devices()) { - DeviceProperties props = grappler::GetDeviceInfo(device->parsed_name()); - if (props.type() == "UNKNOWN") { - continue; - } - device_map[device->name()] = props; if (device->parsed_name().id == 0 && StringPiece(device->parsed_name().type) == "CPU" && device->GetAllocator(AllocatorAttributes()) != nullptr) { cpu_device = device; } } - grappler::VirtualCluster cluster(device_map, device_set_); + grappler::VirtualCluster cluster(device_set_); GraphDef new_graph; TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer( item, rewrite_options, cpu_device, &cluster, &new_graph)); diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index 567c81870c..dfce7c23e7 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -206,7 +206,7 @@ Status InlineCallInGraph(Node* n, Graph* g) { &fbody)); // TODO(jpienaar): Improve this interface to make the need to delete it // explicit. - InlineFunctionBody(g->flib_def(), g, n, fbody); + InlineFunctionBody(g->flib_def(), g, n, fbody, false); delete fbody; return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 66c4e5d7a9..4a10d99a60 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -286,7 +286,9 @@ cc_library( "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:session_mgr", + "//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", ], alwayslink = 1, ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index 6b44d8cecf..d09a85c6a5 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -43,25 +43,11 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:ptr_util", "//tensorflow/core/distributed_runtime/eager:eager_service_impl", + "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", ], ) - -cc_library( - name = "eager_grpc_server_lib", - hdrs = ["eager_grpc_server_lib.h"], - deps = [ - ":grpc_eager_service_impl", - "//tensorflow/core:core_cpu", - "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", - "//tensorflow/core/distributed_runtime:worker_cache_wrapper", - "//tensorflow/core/distributed_runtime/eager:eager_service_impl", - 
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", - ], -) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h deleted file mode 100644 index 9b863ccee5..0000000000 --- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ -#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ - -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" -#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" - -namespace tensorflow { -namespace eager { - -class EagerGrpcServer : public GrpcServer { - public: - static Status Create(const ServerDef& server_def, - std::unique_ptr<EagerGrpcServer>* server) { - std::unique_ptr<EagerGrpcServer> ret(new EagerGrpcServer(server_def)); - - TF_RETURN_IF_ERROR(ret->InitEager()); - - *server = std::move(ret); - - return Status::OK(); - } - - Status Start() override { - TF_RETURN_IF_ERROR(GrpcServer::Start()); - - eager_service_->Start(); - - return Status::OK(); - } - - Status Stop() override { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); - - eager_service_->Stop(); - - return Status::OK(); - } - - using GrpcServer::channel_cache; - using GrpcServer::master_env; - using GrpcServer::worker_env; - - private: - EagerGrpcServer(const ServerDef& server_def) - : GrpcServer(server_def, Env::Default()), - worker_name_( - strings::StrCat("/job:", server_def.job_name(), - "/replica:0/task:", server_def.task_index())) {} - - Status InitEager() { - TF_RETURN_IF_ERROR(this->Init( - [this](const WorkerEnv* worker_env, - ::grpc::ServerBuilder* server_builder) { - this->eager_service_.reset( - new eager::GrpcEagerServiceImpl(worker_env, server_builder)); - }, - nullptr, nullptr)); - - worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( - "", worker_name_, - std::unique_ptr<WorkerCacheInterface>( - new WorkerCacheWrapper(master_env()->worker_cache)), - worker_env()->device_mgr, {}); - - auto* r = worker_env()->rendezvous_mgr->Find(0); - return r->Initialize(worker_session_.get()); - } - - std::unique_ptr<GrpcEagerServiceImpl> eager_service_; - std::shared_ptr<WorkerSession> worker_session_; - const string worker_name_; -}; // namespace eager - -} // namespace eager -} // namespace tensorflow - 
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index b36c6dce86..52e06c263d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -18,10 +18,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -36,7 +34,7 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl( cq_ = server_builder->AddCompletionQueue(); } -void GrpcEagerServiceImpl::DriveCQ() { +void GrpcEagerServiceImpl::HandleRPCsLoop() { #define ENQUEUE_REQUEST(method) \ do { \ Call<GrpcEagerServiceImpl, \ @@ -74,12 +72,7 @@ void GrpcEagerServiceImpl::DriveCQ() { } } -void GrpcEagerServiceImpl::Start() { - // TODO(nareshmodi) separate thread for driving CQ - request_handler_threadpool_->Schedule([this]() { DriveCQ(); }); -} - -void GrpcEagerServiceImpl::Stop() { +void GrpcEagerServiceImpl::Shutdown() { // This enqueues a special event (with a null tag) // that causes the completion queue to be shut down on the // polling thread. diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index e94aedf535..9a94026342 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -20,16 +20,16 @@ limitations under the License. #include "grpcpp/completion_queue.h" #include "grpcpp/server_builder.h" #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" namespace tensorflow { namespace eager { // This class is a wrapper that handles communication for gRPC. 
-class GrpcEagerServiceImpl { +class GrpcEagerServiceImpl : public AsyncServiceInterface { public: template <class RequestMessage, class ResponseMessage> using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService, @@ -39,8 +39,8 @@ class GrpcEagerServiceImpl { ::grpc::ServerBuilder* server_builder); virtual ~GrpcEagerServiceImpl() {} - void Start(); - void Stop(); + void HandleRPCsLoop() override; + void Shutdown() override; private: #define HANDLER(method) \ @@ -66,8 +66,6 @@ class GrpcEagerServiceImpl { EagerServiceImpl local_impl_; - void DriveCQ(); - std::unique_ptr<::grpc::Alarm> shutdown_alarm_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index e7914740ae..aa334f9424 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/master_session.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -81,6 +83,7 @@ GrpcServer::~GrpcServer() { delete master_service_; delete worker_service_; + delete eager_service_; // TODO(mrry): Refactor the *Env classes so that it is less fiddly // to destroy them. @@ -192,6 +195,8 @@ Status GrpcServer::Init( worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_); worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder).release(); + eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); + // extra service: if (service_func != nullptr) { service_func(&worker_env_, &builder); @@ -264,7 +269,15 @@ Status GrpcServer::Init( LocalMaster::Register(target(), master_impl_.get(), config.operation_timeout_in_ms()); - return Status::OK(); + // Generate a dummy worker session that is used to register the + // Rendezvous for eager (we use Step 0 for eager). 
+ worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( + "", name_prefix, + std::unique_ptr<WorkerCacheInterface>( + new WorkerCacheWrapper(master_env_.worker_cache)), + worker_env_.device_mgr, {}); + auto* r = worker_env()->rendezvous_mgr->Find(0); + return r->Initialize(worker_session_.get()); } Status GrpcServer::Init( @@ -365,6 +378,9 @@ Status GrpcServer::Start() { worker_thread_.reset( env_->StartThread(ThreadOptions(), "TF_worker_service", [this] { worker_service_->HandleRPCsLoop(); })); + eager_thread_.reset( + env_->StartThread(ThreadOptions(), "TF_eager_service", + [this] { eager_service_->HandleRPCsLoop(); })); state_ = STARTED; LOG(INFO) << "Started server with target: " << target(); return Status::OK(); @@ -407,6 +423,7 @@ Status GrpcServer::Join() { case STOPPED: master_thread_.reset(); worker_thread_.reset(); + eager_thread_.reset(); return Status::OK(); default: LOG(FATAL); @@ -443,6 +460,17 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return Status::OK(); } +/* static */ +Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server) { + std::unique_ptr<GrpcServer> ret( + new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); + ServiceInitFunction service_func = nullptr; + TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr)); + *out_server = std::move(ret); + return Status::OK(); +} + namespace { class GrpcServerFactory : public ServerFactory { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 9e53330f85..115148b84e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -63,6 +63,8 @@ class GrpcServer : public ServerInterface { public: static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr<ServerInterface>* out_server); + static Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr<GrpcServer>* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. @@ -74,6 +76,11 @@ class GrpcServer : public ServerInterface { Status Join() override; const string target() const override; + WorkerEnv* worker_env() { return &worker_env_; } + MasterEnv* master_env() { return &master_env_; } + + std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } + protected: Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, @@ -115,11 +122,6 @@ class GrpcServer : public ServerInterface { // This method may only be called after `this->Init()` returns successfully. int bound_port() const { return bound_port_; } - WorkerEnv* worker_env() { return &worker_env_; } - MasterEnv* master_env() { return &master_env_; } - - std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; } - const ServerDef& server_def() const { return server_def_; } private: @@ -158,6 +160,11 @@ class GrpcServer : public ServerInterface { AsyncServiceInterface* worker_service_ = nullptr; std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); + // TensorFlow Eager implementation, and RPC polling thread. 
+ AsyncServiceInterface* eager_service_ = nullptr; + std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_); + std::shared_ptr<WorkerSession> worker_session_; + std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc index e30ee84cc3..9108c32942 100644 --- a/tensorflow/core/framework/device_base.cc +++ b/tensorflow/core/framework/device_base.cc @@ -13,11 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + #include "tensorflow/core/framework/device_base.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/util/work_sharder.h" + namespace tensorflow { -DeviceBase::~DeviceBase() {} +DeviceBase::~DeviceBase() { gtl::STLDeleteElements(&eigen_cpu_devices_); } const DeviceAttributes& DeviceBase::attributes() const { LOG(FATAL) << "Device does not implement attributes()"; @@ -27,4 +33,29 @@ const string& DeviceBase::name() const { LOG(FATAL) << "Device does not implement name()"; } +void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) { + // Eigen::ThreadPoolDevice is a very cheap struct (one pointer and + // an int). Therefore, we can afford a pre-allocated array of + // Eigen::ThreadPoolDevice. Here, we ensure that + // Eigen::ThreadPoolDevices in eigen_cpu_devices_ has increasingly + // larger numThreads. + for (int i = 1; i <= d->numThreads(); ++i) { + eigen_cpu_devices_.push_back( + new Eigen::ThreadPoolDevice(d->getPool(), i /* numThreads() */)); + } +} + +const Eigen::ThreadPoolDevice* DeviceBase::eigen_cpu_device() { + // Based on GetPerThreadMaxParallelism(), we return a different + // pre-allocated Eigen::ThreadPoolDevice. All these ThreadPoolDevice + // use the same underlying threadpool. But they use different + // nominal numThreads() hoping that the user of the returned + // Eigen::ThreadPoolDevice may not aggressively occupy all the + // threads in the underlying threadpool. + const int parallelism = std::max<int>( + 1, + std::min<int>(GetPerThreadMaxParallelism(), eigen_cpu_devices_.size())); + return eigen_cpu_devices_[parallelism - 1]; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index ec26d92a61..922d34fac9 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -18,7 +18,7 @@ limitations under the License. #include <memory> #include <string> -#include <unordered_map> +#include <vector> #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" @@ -154,9 +154,7 @@ class DeviceBase { } // Does not take ownership. 
- void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) { - eigen_cpu_device_ = d; - } + void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d); #ifdef TENSORFLOW_USE_SYCL void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; } @@ -186,11 +184,12 @@ class DeviceBase { virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; } - virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { - CHECK(eigen_cpu_device_ != nullptr); - return eigen_cpu_device_; + const bool has_eigen_cpu_device() const { + return !eigen_cpu_devices_.empty(); } + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device(); + #ifdef TENSORFLOW_USE_SYCL virtual const Eigen::SyclDevice* eigen_sycl_device() const { CHECK(eigen_sycl_device_ != nullptr); @@ -242,7 +241,7 @@ class DeviceBase { // Set by GPUs as well as by TPU devices. GpuDeviceInfo* gpu_device_info_ = nullptr; thread::ThreadPool* device_thread_pool_ = nullptr; - Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr; + std::vector<Eigen::ThreadPoolDevice*> eigen_cpu_devices_; #ifdef TENSORFLOW_USE_SYCL Eigen::SyclDevice* eigen_sycl_device_ = nullptr; #endif diff --git a/tensorflow/core/framework/device_base_test.cc b/tensorflow/core/framework/device_base_test.cc new file mode 100644 index 0000000000..6909559ea2 --- /dev/null +++ b/tensorflow/core/framework/device_base_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/device_base.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +TEST(DeviceBaseTest, CpuDevice) { + DeviceBase dbase(Env::Default()); + thread::ThreadPool pool(Env::Default(), "test", 16); + EigenThreadPoolWrapper wrapper(&pool); + Eigen::ThreadPoolDevice eigen_device(&wrapper, pool.NumThreads()); + ASSERT_FALSE(dbase.has_eigen_cpu_device()); + dbase.set_eigen_cpu_device(&eigen_device); + ASSERT_TRUE(dbase.has_eigen_cpu_device()); + + { + auto d = dbase.eigen_cpu_device(); + EXPECT_EQ(d->numThreads(), 16); + } + + { + ScopedPerThreadMaxParallelism maxp(4); + auto d = dbase.eigen_cpu_device(); + EXPECT_EQ(d->numThreads(), 4); + } + + { + ScopedPerThreadMaxParallelism maxp(1); + auto d = dbase.eigen_cpu_device(); + EXPECT_EQ(d->numThreads(), 1); + } + + { + ScopedPerThreadMaxParallelism maxp(1000); + auto d = dbase.eigen_cpu_device(); + EXPECT_EQ(d->numThreads(), 16); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 21fc6c1bd5..0a19861efd 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -60,8 +60,8 @@ namespace internal { Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { if (ctx->device()->attributes().name() != p.device()) { return errors::InvalidArgument( - "Trying to access resource located in device ", p.device(), - " from device ", ctx->device()->attributes().name()); + "Trying to access resource ", p.name(), " located in device ", + p.device(), " from device ", ctx->device()->attributes().name()); } return Status::OK(); } diff --git a/tensorflow/core/framework/resource_var.h b/tensorflow/core/framework/resource_var.h index 872b8f8b30..ff7b3e78a7 100644 --- a/tensorflow/core/framework/resource_var.h +++ b/tensorflow/core/framework/resource_var.h @@ -29,6 +29,8 @@ class Var : public ResourceBase { Var(const Var&) = delete; Var& operator=(const Var&) = delete; + // When locking multiple variables, the locks must be acquired in order of + // increasing mu() address. // TODO(ebrevdo): Use LockSet instead of exposing mu. 
mutex* mu() { return &mu_; } Tensor* tensor() { return &tensor_; } diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index b02bc3adbe..8d597e198d 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -340,6 +340,20 @@ string InferenceContext::DebugString() const { ProtoDebugString(*node_def_)); } +string InferenceContext::DebugString(const ShapeAndType& shape_and_type) { + return strings::StrCat(DebugString(shape_and_type.shape), ":", + DataTypeString(shape_and_type.dtype)); +} + +string InferenceContext::DebugString( + gtl::ArraySlice<ShapeAndType> shape_and_types) { + std::vector<string> pieces; + for (const ShapeAndType& s : shape_and_types) { + pieces.push_back(DebugString(s)); + } + return strings::StrCat("[", str_util::Join(pieces, ","), "]"); +} + Status InferenceContext::WithRank(ShapeHandle shape, int64 rank, ShapeHandle* out) { if (rank > kint32max) { diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 3f3729dcf9..81258b55b3 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -381,6 +381,8 @@ class InferenceContext { string DebugString(ShapeHandle s); string DebugString(DimensionHandle d); + string DebugString(const ShapeAndType& shape_and_type); + string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types); // Describes the whole context, for debugging purposes. string DebugString() const; diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc index 6b56613470..c1a8a63784 100644 --- a/tensorflow/core/graph/gradients.cc +++ b/tensorflow/core/graph/gradients.cc @@ -106,8 +106,15 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) { AddNodeAttr("Tin", in_types, &ndef); // The gradient node's outputs have the same types as the node 'n's - // inputs. - AddNodeAttr("Tout", n->input_types(), &ndef); + // inputs, except for resources. + DataTypeVector out_types = n->input_types(); + for (int i = 0; i < out_types.size(); ++i) { + if (out_types[i] == DT_RESOURCE) { + // TODO(apassos): figure out how to get the right dtype + out_types[i] = DT_FLOAT; + } + } + AddNodeAttr("Tout", out_types, &ndef); NameAttrList func; func.set_name(n->type_string()); for (const auto& attr : n->attrs()) { diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 193cf88aed..60337e30aa 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -81,7 +81,9 @@ Status FeedInputs( // Update name_index (*name_index)[feed_node->name()] = feed_node; - g->AddControlEdge(g->source_node(), feed_node); + // Duplicate control edges aren't allowed, but feed_node was *just* created + // so there's no need to check for a duplicate. + g->AddControlEdge(g->source_node(), feed_node, true); // Look through edges coming out of "n" for edges whose src_output() index // matches "output_index". If found, replace the edges with a connection @@ -107,7 +109,9 @@ Status FeedInputs( g->AddEdge(feed_node, 0, e->dst(), e->dst_input()); } else { CHECK_EQ(Graph::kControlSlot, e->src_output()); - g->AddControlEdge(feed_node, e->dst()); + // Duplicate control edges aren't allowed, but feed_node was *just* + // created so there's no need to check for a duplicate. + g->AddControlEdge(feed_node, e->dst(), true); } g->RemoveEdge(e); } @@ -160,7 +164,9 @@ Status FetchOutputs( // Update the index. 
(*name_index)[fetch_node->name()] = fetch_node; - g->AddControlEdge(fetch_node, g->sink_node()); + // Duplicate control edges aren't allowed, but fetch_node was *just* created + // so there's no need to check for a duplicate. + g->AddControlEdge(fetch_node, g->sink_node(), true); out_fetch_nodes->push_back(fetch_node); out_fetch_types->push_back(BaseType(n->output_type(id.second))); } diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index d0b2cf01be..ab8f4bebb3 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -77,6 +77,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cluster", + ":utils", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h index d33aaa7e4c..06db36b3aa 100644 --- a/tensorflow/core/grappler/clusters/cluster.h +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -95,7 +95,7 @@ class Cluster { // The DeviceSet is not always available, but when it is it contains a // superset of the devices listed in GetDevices/GetDeviceNames(). - const DeviceSet* GetDeviceSet() const { return device_set_; } + virtual const DeviceSet* GetDeviceSet() const { return nullptr; } // Enables collecting the allocator stats. Call with enable=true must be made // before Provision(). @@ -124,7 +124,6 @@ class Cluster { protected: std::unordered_map<string, DeviceProperties> devices_; - const DeviceSet* device_set_ = nullptr; // Not owned const int timeout_s_; SessionOptions options_; RunOptions run_options_; diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 313ef90d81..b97603c890 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -368,6 +368,15 @@ Status SingleMachine::ResetSession() { } coordinator_.reset(new Coordinator()); + // Build the DeviceSet. + device_set_.reset(new DeviceSet); + const DeviceMgr* device_mgr; + TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr)); + for (auto d : device_mgr->ListDevices()) { + device_set_->AddDevice(d); + // We currently don't care about the client device. + } + return Status::OK(); } diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h index 0ae188e0d6..c0421dd4de 100644 --- a/tensorflow/core/grappler/clusters/single_machine.h +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -43,6 +43,8 @@ class SingleMachine : public Cluster { const std::vector<std::pair<string, Tensor>>& feed, const std::vector<string>& fetch, RunMetadata* metadata) override; + const DeviceSet* GetDeviceSet() const override { return device_set_.get(); } + Status EnablePeakMemoryStats(bool enable) override; // It requires EnableAllocatorStats(true) be called before Provision(). 
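Together with the cluster.h change earlier, the hunks above move DeviceSet ownership out of the Cluster base class: the base accessor now defaults to nullptr, and SingleMachine builds and owns its own DeviceSet from the session's DeviceMgr on every ResetSession(). A self-contained sketch of that design, using toy stand-in types rather than the real tensorflow::DeviceSet/DeviceMgr/grappler::Cluster:

    // Illustrative stand-ins for the real TensorFlow types.
    #include <memory>
    #include <vector>

    struct Device {};

    class DeviceSet {
     public:
      void AddDevice(Device* d) { devices_.push_back(d); }  // not owned
     private:
      std::vector<Device*> devices_;
    };

    class Cluster {
     public:
      virtual ~Cluster() = default;
      // Base default: no device set is available.
      virtual const DeviceSet* GetDeviceSet() const { return nullptr; }
    };

    class SingleMachine : public Cluster {
     public:
      const DeviceSet* GetDeviceSet() const override { return device_set_.get(); }
      // Rebuild the set whenever the underlying session (and with it the
      // device list) is reset.
      void ResetSession(const std::vector<Device*>& devices) {
        device_set_ = std::make_unique<DeviceSet>();
        for (Device* d : devices) device_set_->AddDevice(d);
      }
     private:
      std::unique_ptr<DeviceSet> device_set_;
    };

Each subclass that can actually enumerate devices (SingleMachine here, VirtualCluster below) overrides the getter, instead of the base class holding a non-owning pointer that could dangle.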
@@ -73,6 +75,7 @@ class SingleMachine : public Cluster { int64 expected_init_time_s_; std::unique_ptr<Coordinator> coordinator_; std::unique_ptr<thread::ThreadPool> thread_pool_; + std::unique_ptr<DeviceSet> device_set_; RunMetadata init_metadata_; diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc index 352f08fede..31b19cfcfd 100644 --- a/tensorflow/core/grappler/clusters/single_machine_test.cc +++ b/tensorflow/core/grappler/clusters/single_machine_test.cc @@ -546,7 +546,7 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) { TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory_before)); EXPECT_EQ(device_peak_memory_before.size(), 1); // There might be a bit memory used before session's running anything. - EXPECT_LT(device_peak_memory_before.begin()->second, 200); + EXPECT_LT(device_peak_memory_before.begin()->second, 400); RunMetadata metadata; TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata)); @@ -567,8 +567,8 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) { // Check memory used by resources are released after cluster destruction. EXPECT_EQ(device_peak_memory_before.size(), 1); EXPECT_EQ(device_peak_memory_after.size(), 1); - EXPECT_LT(device_peak_memory_before.begin()->second, 200); - EXPECT_LT(device_peak_memory_after.begin()->second, 200); + EXPECT_LT(device_peak_memory_before.begin()->second, 400); + EXPECT_LT(device_peak_memory_after.begin()->second, 400); } TEST_F(SingleMachineTest, PeakMemory) { @@ -597,7 +597,7 @@ TEST_F(SingleMachineTest, PeakMemory) { device_peak_memory.end()); cpu_memory = device_peak_memory["/job:localhost/replica:0/task:0/device:CPU:0"]; - EXPECT_LT(cpu_memory, 100); + EXPECT_LT(cpu_memory, 200); } TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) { diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 5c9b2320b5..12e3e46f65 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "tensorflow/core/grappler/costs/virtual_scheduler.h" @@ -38,11 +39,14 @@ VirtualCluster::VirtualCluster( devices_ = devices; } -VirtualCluster::VirtualCluster( - const std::unordered_map<string, DeviceProperties>& devices, - const DeviceSet* device_set) - : VirtualCluster(devices) { +VirtualCluster::VirtualCluster(const DeviceSet* device_set) + : VirtualCluster(std::unordered_map<string, DeviceProperties>()) { device_set_ = device_set; + for (const auto& device : device_set_->devices()) { + DeviceProperties props = GetDeviceInfo(device->parsed_name()); + if (props.type() == "UNKNOWN") continue; + devices_[device->name()] = props; + } } VirtualCluster::~VirtualCluster() {} diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index eebac68e1b..6adb0b99bc 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -36,8 +36,7 @@ class VirtualCluster : public Cluster { VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices, OpLevelCostEstimator* node_estimator, ReadyNodeManager* node_manager); - VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices, - const DeviceSet* device_set); + VirtualCluster(const DeviceSet* device_set); ~VirtualCluster() override; @@ -48,10 +47,12 @@ class VirtualCluster : public Cluster { Status Run(const GraphDef& item, const std::vector<std::pair<string, Tensor>>& feed, const std::vector<string>& fetch, RunMetadata* metadata) override; + const DeviceSet* GetDeviceSet() const override { return device_set_; } private: std::unique_ptr<OpLevelCostEstimator> node_estimator_; std::unique_ptr<ReadyNodeManager> node_manager_; + const DeviceSet* device_set_ = nullptr; // Not owned }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index d9a08d42db..0c02876ac5 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -353,12 +353,12 @@ void VerboseLogUnknownDimensionSources( class TopoQueue { public: explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order) - : queue_(CompareNodes(topo_order)) {} - void push(const NodeDef* n) { queue_.insert(n); } + : topo_order_(topo_order) {} + void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); } const NodeDef* pop() { CHECK(!empty()); auto it = queue_.begin(); - const NodeDef* n = *it; + const NodeDef* n = it->first; queue_.erase(it); return n; } @@ -367,20 +367,16 @@ class TopoQueue { std::size_t size() const { return queue_.size(); } private: + using NodeAndId = std::pair<const NodeDef*, int>; // Graph nodes are created in (roughly) topological order. Therefore we can // use their id to ensure they're sorted topologically. 
- struct CompareNodes { - explicit CompareNodes( - const std::unordered_map<const NodeDef*, int>& topo_ordering) - : topo_order(topo_ordering) {} - bool operator()(const NodeDef* lhs, const NodeDef* rhs) const { - return topo_order.at(lhs) < topo_order.at(rhs); + struct OrderByIdAscending { + bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const { + return lhs.second < rhs.second; } - - private: - const std::unordered_map<const NodeDef*, int>& topo_order; }; - std::set<const NodeDef*, CompareNodes> queue_; + const std::unordered_map<const NodeDef*, int>& topo_order_; + std::set<NodeAndId, OrderByIdAscending> queue_; }; // Processes symbolic shapes. @@ -1082,6 +1078,9 @@ Status GraphProperties::UpdateShapes( // itself. TF_RETURN_IF_ERROR( UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes)); + } else if (IsQueue(*n)) { + // Set shapes and types of Queue ops, if needed. + TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes)); } else { auto c = shape_refiner->GetNodeContext(n); if (c && c->op_data && c->op_data->is_function_op) { @@ -1147,6 +1146,53 @@ Status GraphProperties::PropagateShapes( return Status::OK(); } +Status GraphProperties::UpdateQueue(const NodeDef* queue_node, + SymbolicShapeRefiner* shape_refiner, + bool* new_shapes) { + auto ctx = shape_refiner->GetNodeContext(queue_node); + if (!ctx) { + TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node)); + ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node)); + } + auto* ic = ctx->inference_context.get(); + + auto* outputs = ic->output_handle_shapes_and_types(0); + if (outputs) { + // Shapes and types are already set, presumably by Enqueue ops. + return shape_refiner->UpdateNode(queue_node, new_shapes); + } + + if (queue_node->attr().count("shapes") <= 0 || + queue_node->attr().count("component_types") <= 0 || + queue_node->attr().at("shapes").list().shape_size() != + queue_node->attr().at("component_types").list().type_size()) { + // Errors in shapes and component_types attr. + return shape_refiner->UpdateNode(queue_node, new_shapes); + } + + // Extract types and shapes from Queue attr. + const auto& shapes = queue_node->attr().at("shapes").list().shape(); + const auto& types = queue_node->attr().at("component_types").list().type(); + std::vector<ShapeAndType> shapes_and_types; + for (int i = 0; i < types.size(); i++) { + const auto& shape = shapes[i]; + ShapeHandle shape_handle; + TF_RETURN_IF_ERROR( + ic->MakeShapeFromPartialTensorShape(shape, &shape_handle)); + DataType data_type = + queue_node->attr().at("component_types").list().type(i); + ShapeAndType shape_and_type(shape_handle, data_type); + shapes_and_types.push_back(shape_and_type); + } + ic->set_output_handle_shapes_and_types(0, shapes_and_types); + + // Queue node is updated with output_handle_shapes_and_types, so set + // new_shapes and ignore it from UpdateNode().
+ *new_shapes = true; + bool dummy_new_shapes = false; + return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes); +} + Status GraphProperties::UpdateEnqueue( const NodeDef* enqueue_node, const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles, diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 8703613a12..f716cd72c9 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -91,6 +91,11 @@ class GraphProperties { resource_handles, SymbolicShapeRefiner* shape_refiner, bool* new_shapes); + // Update the shapes and types of the Queue node, if not set by Enqueue node. + static Status UpdateQueue(const NodeDef* queue_node, + SymbolicShapeRefiner* shape_refiner, + bool* new_shapes); + // Update the output shapes of a Merge node, and enqueue its fanout in // new_shapes if needed. Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner, diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 3e44b222fd..aa787ae620 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -262,6 +262,59 @@ TEST_F(GraphPropertiesTest, VarHandles) { EXPECT_EQ(7, prop.shape().dim(1).size()); } +TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT}); + auto dequeue1 = + ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT}); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + + const auto props1 = properties.GetOutputProperties("Dequeue1"); + ASSERT_EQ(1, props1.size()); + EXPECT_EQ("float: ?", PropToString(props1[0])); +} + +TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT}, + ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}})); + auto dequeue1 = + ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT}); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + + const auto props1 = properties.GetOutputProperties("Dequeue1"); + ASSERT_EQ(1, props1.size()); + EXPECT_EQ("float: [3,7,1]", PropToString(props1[0])); +} + +TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT}, + ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}})); + auto dequeue1 = + ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT}); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + + const auto props1 = properties.GetOutputProperties("Dequeue1"); + ASSERT_EQ(1, props1.size()); + EXPECT_EQ("float: [3,7,-1]", PropToString(props1[0])); +} + TEST_F(GraphPropertiesTest, Queues) { // Create a graph with known input shapes, and propagate the shapes through a // couple of queues. 
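The tests above exercise the new GraphProperties::UpdateQueue path: when no Enqueue op has populated a queue's output handle, the queue's own shapes and component_types attrs seed it, provided both attrs are present and the two lists have the same length. Below is a minimal standalone C++ sketch of just that pairing-and-validation logic; PartialShape, ShapeAndType, and SeedQueueShapes are illustrative stand-ins, not the real InferenceContext/ShapeHandle machinery.

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stand-ins for TF's PartialTensorShape and DataType (illustrative only).
struct PartialShape { std::vector<int> dims; };  // -1 marks an unknown dim
struct ShapeAndType { PartialShape shape; std::string dtype; };

// Mirrors UpdateQueue's guard: use the attrs only when both are present and
// their lengths agree; otherwise report "nothing inferred".
std::optional<std::vector<ShapeAndType>> SeedQueueShapes(
    const std::vector<PartialShape>& shapes,
    const std::vector<std::string>& component_types) {
  if (shapes.empty() || component_types.empty() ||
      shapes.size() != component_types.size()) {
    return std::nullopt;  // caller falls back to regular shape refinement
  }
  std::vector<ShapeAndType> result;
  result.reserve(shapes.size());
  for (size_t i = 0; i < shapes.size(); ++i) {
    // Pair the i-th declared shape with the i-th component type, as the
    // patch does when building its shapes_and_types vector.
    result.push_back({shapes[i], component_types[i]});
  }
  return result;
}

int main() {
  // A FIFOQueue declared with shapes={{3,7,-1}} and component_types={float}
  // yields one seeded component, so a later QueueDequeue can report
  // "float: [3,7,-1]" instead of "float: ?".
  auto seeded = SeedQueueShapes({{{3, 7, -1}}}, {"float"});
  std::cout << (seeded ? seeded->size() : 0) << " component(s) seeded\n";
}

Returning std::nullopt here plays the role of the patch's fallback to a plain shape_refiner->UpdateNode() call when the attrs are missing or inconsistent.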
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 2227904dbf..bdeb5c66fc 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -135,6 +135,18 @@ bool IsDequeueOp(const NodeDef& node) { bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } +bool IsElementWiseMonotonic(const NodeDef& node) { + static const std::unordered_set<string>* element_wise_monotonic_ops = + CHECK_NOTNULL((new std::unordered_set<string>{ + "Relu", + "Relu6", + "Sigmoid", + "Sqrt", + "Tanh", + })); + return element_wise_monotonic_ops->count(node.op()) > 0; +} + bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } bool IsEnter(const NodeDef& node) { @@ -617,7 +629,8 @@ bool HasOpDef(const NodeDef& node) { } bool IsIdempotent(const NodeDef& node) { - return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node); + return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) && + !ModifiesFrameInfo(node); } } // namespace grappler diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 7110a9c63d..2de7d8cc9a 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,6 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); +bool IsElementWiseMonotonic(const NodeDef& node); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 33c2a0d420..8ca726df0b 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -679,6 +679,7 @@ cc_library( deps = [ ":constant_folding", ":graph_optimizer", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 9d500f8f54..90be051764 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1722,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage { ~RemoveIdempotentStage() override = default; bool IsSupported(const NodeDef* node) const override { - return IsIdempotent(*node) && !IsInPreserveSet(*node); + return node->input_size() == 1 && IsIdempotent(*node) && + !IsInPreserveSet(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); - auto root_scope_and_name = ParseNodeScopeAndName(node->name()); - const string new_name = OptimizedNodeName(root_scope_and_name); - if (input->op() == node->op() && input->device() == node->device() && - IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) { - NodeDef* new_input_node = AddCopyNode(new_name, input); - ForwardControlDependencies(new_input_node, {node}); - *simplified_node_name = new_input_node->name(); + if (input->op() == node->op() && input->device() == node->device()) { + *simplified_node_name = node->input(0); } return Status::OK(); } @@ -2600,6 +2596,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } }; +// Performs 
conversions like: +// Max(Sqrt(x)) => Sqrt(Max(x)) +// Checks for a max/min reduction over element-wise monotonic functions, such +// as Sqrt, Sigmoid, Tanh, etc. +class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { + public: + explicit OptimizeMaxOrMinOfMonotonicStage( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx, + ctx_ext) {} + ~OptimizeMaxOrMinOfMonotonicStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMax(*node) || IsMin(*node); + } + + Status TrySimplify(NodeDef* reduction_node, + string* simplified_node_name) override { + NodeDef* inner_function; + TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function)); + // Optimize only if: + // 1. inner_function's Op is element-wise monotonic + // 2. inner_function's output is not being consumed elsewhere. + if (IsElementWiseMonotonic(*inner_function) && + (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) { + // Swap the first inputs of the inner function Op & the reduction Op. + NodeDef* inner_input; + TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input)); + inner_function->set_input(0, reduction_node->name()); + UpdateConsumersAvoidingLoop(inner_function, reduction_node->name()); + reduction_node->set_input(0, inner_input->name()); + UpdateConsumersAvoidingLoop(reduction_node, inner_function->name()); + } + return Status::OK(); + } + + void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) { + const string& node_name = node->name(); + const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); + for (NodeDef* consumer : consumers) { + for (int i = 0; i < consumer->input_size(); ++i) { + if (consumer->input(i) == node_name && consumer->name() != new_input) { + consumer->set_input(i, new_input); + ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); + } + } + AddToOptimizationQueue(consumer); + } + } +}; + } // namespace class UniqueNodes { @@ -2878,6 +2926,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); if (options_.convert_log1p) pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); + if (options_.optimize_max_or_min_of_monotonic) + pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 9a6081dcd8..824ef35ef6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -63,6 +63,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool hoist_common_factor_out_of_aggregation = true; bool hoist_cwise_unary_chains = false; bool minimize_broadcasts = true; + bool optimize_max_or_min_of_monotonic = true; bool remove_idempotent = true; bool remove_identity_transpose = true; bool remove_involution = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 177c237fe7..d0e6b04679 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc 
@@ -269,6 +269,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.convert_log1p = true; } + + void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.optimize_max_or_min_of_monotonic = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2971,12 +2976,8 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); - Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {}); - Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {}); - Output sn1 = - ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a); - Output sn2 = - ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1); + Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a); + Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1); Output out1 = ops::Identity(s.WithOpName("out1"), sn2); Output id1 = ops::Identity(s.WithOpName("id1"), a); Output id2 = ops::Identity(s.WithOpName("id2"), id1); @@ -2992,32 +2993,24 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { EnableOnlyRemoveIdempotent(&optimizer); OptimizeTwice(&optimizer, &item, &output); - EXPECT_EQ(11, output.node_size()); + EXPECT_EQ(7, output.node_size()); int found = 0; for (const NodeDef& node : output.node()) { if (node.name() == "out1") { EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0)); - found++; - } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") { - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ("Snapshot", node.op()); - EXPECT_EQ("a", node.input(0)); - EXPECT_EQ("^ctrl1", node.input(1)); - EXPECT_EQ("^ctrl2", node.input(2)); + EXPECT_EQ("sn1", node.input(0)); found++; } else if (node.name() == "out2") { EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0)); + EXPECT_EQ("id1", node.input(0)); found++; - } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") { - EXPECT_EQ("Identity", node.op()); + } else if (node.name() == "sn1") { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("a", node.input(0)); found++; } } - EXPECT_EQ(4, found); + EXPECT_EQ(3, found); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(tensors.size(), tensors_expected.size()); @@ -3125,5 +3118,46 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { } } +TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + // Check if the inputs are 
switched + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "sqrt") { + EXPECT_EQ("Sqrt", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("reduce_max", node.input(0)); + ++required_node_count; + } else if (node.name() == "reduce_max") { + EXPECT_EQ("Max", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index f4b384ec1e..76c928f995 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -354,12 +354,14 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { } if (op == "TensorArraySizeV3") { - const NodeDef* array = node_map_->GetNode(node->input(0)); - if (array->attr().count("dynamic_size") != 0 && - array->attr().at("dynamic_size").b()) { + const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0))); + if (array->input_size() == 0 || + (array->attr().count("dynamic_size") != 0 && + array->attr().at("dynamic_size").b())) { continue; } - const NodeDef* array_size = node_map_->GetNode(array->input(0)); + const NodeDef* array_size = + CHECK_NOTNULL(node_map_->GetNode(array->input(0))); if (IsReallyConstant(*array_size)) { // Don't materialize 0 sizes to avoid triggering incorrect static // checks. A 0 sized array that can't grow isn't useful anyway. @@ -374,6 +376,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { if (value.flat<int32>()(0) == 0) { continue; } + node->set_op("Const"); *node->mutable_attr() = array_size->attr(); node->set_input(0, AsControlDependency(NodeName(node->input(0)))); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 9f051ca248..b9765b9292 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -3000,6 +3000,10 @@ TEST_F(ConstantFoldingTest, Enter) { TEST_F(ConstantFoldingTest, TensorArraySize) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({})); + Output placeholder = + ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE, + ops::Placeholder::Shape(TensorShape({2}))); + Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({})); auto dynamic_array = ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT, ops::TensorArray::DynamicSize(true)); @@ -3010,6 +3014,8 @@ TEST_F(ConstantFoldingTest, TensorArraySize) { scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow); auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"), static_array.handle, static_array.flow); + auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"), + placeholder, foo); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); @@ -3026,11 +3032,13 @@ TEST_F(ConstantFoldingTest, TensorArraySize) { status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(5, output.node_size()); - EXPECT_EQ("dynamic_sz", output.node(3).name()); - EXPECT_EQ("TensorArraySizeV3", output.node(3).op()); - 
EXPECT_EQ("static_sz", output.node(4).name()); - EXPECT_EQ("Const", output.node(4).op()); + EXPECT_EQ(8, output.node_size()); + EXPECT_EQ("dynamic_sz", output.node(5).name()); + EXPECT_EQ("TensorArraySizeV3", output.node(5).op()); + EXPECT_EQ("static_sz", output.node(6).name()); + EXPECT_EQ("Const", output.node(6).op()); + EXPECT_EQ("placeholder_sz", output.node(7).name()); + EXPECT_EQ("TensorArraySizeV3", output.node(7).op()); auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"}); EXPECT_EQ(2, tensors_expected.size()); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 3f5bab9d3b..fdd82b9603 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -260,14 +260,14 @@ void DependencyOptimizer::OptimizeNode(int node_idx, } continue; } + // Replace a normal input with a control input. const string ctrl_input = ConstantFolding::AddControlDependency( old_input, optimized_graph_, node_map_.get()); - if (ctrl_inputs.insert(ctrl_input).second) { - node->set_input(pos, ctrl_input); - node_map_->UpdateInput(node_name, old_input, ctrl_input); - const NodeDef* old_input_node = node_map_->GetNode(old_input); - nodes_to_simplify->PushBack(node_to_idx_[old_input_node]); - } + ctrl_inputs.insert(ctrl_input); + node->set_input(pos, ctrl_input); + node_map_->UpdateInput(node_name, old_input, ctrl_input); + const NodeDef* old_input_node = node_map_->GetNode(old_input); + nodes_to_simplify->PushBack(node_to_idx_[old_input_node]); ++pos; } node->set_op("NoOp"); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index 0ae3b4ec34..c0f07562af 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -124,25 +124,62 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) { TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size(), output.node_size()); + int found = 0; for (int i = 0; i < item.graph.node_size(); ++i) { const NodeDef& node = item.graph.node(i); - if (node.name() == "add") { - EXPECT_EQ("NoOp", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("^x", node.input(0)); - EXPECT_EQ("^y", node.input(1)); - } else if (node.name() == "id1") { + // "add" should get turned into a NoOp and removed. + EXPECT_NE("add", node.name()); + if (node.name() == "id1") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^y", node.input(1)); + ++found; } else if (node.name() == "id2") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("^x", node.input(1)); + ++found; + } + } + EXPECT_EQ(2, found); +} + +TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT); + Output add = ops::Add(s.WithOpName("add"), x, x); + Output id1 = + ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"id1"}; + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. 
+ item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + LOG(INFO) << output.DebugString(); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + int found = 0; + for (int i = 0; i < item.graph.node_size(); ++i) { + const NodeDef& node = item.graph.node(i); + // "add" should get turned into a NoOp and removed. + EXPECT_NE("add", node.name()); + if (node.name() == "id1") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++found; } } + EXPECT_EQ(1, found); } TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) { @@ -400,6 +437,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) { TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 3, output.node_size()); + int found = 0; for (const NodeDef& node : output.node()) { EXPECT_NE("id_a", node.name()); EXPECT_NE("id_b", node.name()); @@ -407,30 +445,36 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) { if (node.name() == "a_a" || node.name() == "a_b") { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("x", node.input(0)); + ++found; } if (node.name() == "a_c" || node.name() == "a_d") { EXPECT_EQ(2, node.input_size()); EXPECT_EQ("z", node.input(0)); EXPECT_EQ("^x", node.input(1)); + ++found; } if (node.name() == "b_a") { EXPECT_EQ(3, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^y", node.input(1)); EXPECT_EQ("^z", node.input(2)); + ++found; } if (node.name() == "c_a") { EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^y", node.input(1)); + ++found; } if (node.name() == "c_b") { EXPECT_EQ(3, node.input_size()); EXPECT_EQ("z", node.input(0)); EXPECT_EQ("^x", node.input(1)); EXPECT_EQ("^y", node.input(2)); + ++found; } } + EXPECT_EQ(found, 7); } TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) { @@ -460,17 +504,20 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) { TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 1, output.node_size()); + int found = 0; for (const NodeDef& node : output.node()) { EXPECT_NE("id0", node.name()); if (node.name() == "or0") { EXPECT_EQ(2, node.input_size()); EXPECT_EQ("switch:1", node.input(0)); EXPECT_EQ("switch:1", node.input(1)); + ++found; } if (node.name() == "or1") { EXPECT_EQ(2, node.input_size()); EXPECT_EQ("switch:1", node.input(0)); EXPECT_EQ("y", node.input(1)); + ++found; } if (node.name() == "or2") { // or1 should be unchanged. 
@@ -478,8 +525,10 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) { EXPECT_EQ("y", node.input(0)); EXPECT_EQ("y", node.input(1)); EXPECT_EQ("^id1", node.input(2)); + ++found; } } + EXPECT_EQ(found, 3); } TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) { @@ -535,6 +584,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 2, output.node_size()); + bool found = false; for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); // "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1", @@ -545,8 +595,10 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { EXPECT_EQ("Const", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0)); + found = true; } } + EXPECT_TRUE(found); } TEST_F(DependencyOptimizerTest, IdentityInputs) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 143d9dc1c6..b1f31ad0d0 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -42,6 +42,7 @@ namespace grappler { namespace { constexpr int kDefaultNumberOfIterations = 2; +constexpr int kDefaultMinGraphNodes = 4; int64 NumEdges(const GraphDef& graph) { int64 num_edges = 0; @@ -194,6 +195,15 @@ Status MetaOptimizer::InitializeOptimizersByName( Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { + int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes + : cfg_.min_graph_nodes(); + if (item.graph.node_size() < min_graph_nodes) { + VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes + << " nodes."; + *optimized_graph = item.graph; + return Status::OK(); + } + std::vector<std::unique_ptr<GraphOptimizer>> optimizers; if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) { TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers)); @@ -202,10 +212,11 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, } VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id - << " num_optimizers=" << optimizers.size(); + << " num_optimizers=" << optimizers.size() + << ", num nodes = " << item.graph.node_size(); if (optimizers.empty()) { - VLOG(3) << "Skip graph optimization, no optimizers registered"; + VLOG(3) << "Skipping graph optimization, no optimizers registered"; *optimized_graph = item.graph; return Status::OK(); } @@ -221,8 +232,15 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphOptimizer* sa_optimizer = nullptr; for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) { - VLOG(4) << "Starting optimization iteration " << iteration + 1; + // Don't bother optimizing further if the graph is already tiny. + if (optimized_graph->node_size() < min_graph_nodes) { + VLOG(3) << "Stopping after iteration " << iteration + << ", graph is tiny (#nodes = " << optimized_graph->node_size() + << " < " << min_graph_nodes << ")"; + break; + } + VLOG(4) << "Starting optimization iteration " << iteration; for (const auto& optimizer : optimizers) { // Some optimizers can run only once. 
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue; @@ -235,7 +253,6 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get(); continue; } - Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item, optimized_graph, &optimization_result); if (status.ok()) is_optimized = true; @@ -297,7 +314,7 @@ Status MetaOptimizer::RunOptimizer( PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph), ", time = ", duration_ms, "ms."); } - VLOG(4) << optimizer->name() << ": " << result; + VLOG(1) << optimizer->name() << ": " << result; OptimizerResult optimizer_result{optimizer->name(), result}; optimization_result->results.push_back(optimizer_result); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 8247cce339..9a03c7dfef 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -74,6 +74,7 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { TestOptimizer::SetOptimized(false); RewriterConfig rewriter_config; rewriter_config.add_optimizers("TestOptimizer"); + rewriter_config.set_min_graph_nodes(-1); MetaOptimizer optimizer(nullptr, rewriter_config); GraphDef output; @@ -89,6 +90,7 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) { RewriterConfig rewriter_config; rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.set_min_graph_nodes(-1); MetaOptimizer optimizer(nullptr, rewriter_config); GraphDef output; @@ -104,6 +106,7 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); rewriter_config.set_function_optimization(RewriterConfig::ON); rewriter_config.add_optimizers("function"); + rewriter_config.set_min_graph_nodes(-1); MetaOptimizer optimizer(nullptr, rewriter_config); diff --git a/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc index 5a529bd8ca..508a47deda 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc @@ -16,10 +16,12 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" +#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h" namespace tensorflow { namespace functor { DEFINE_BINARY2(igamma, float, double); +DEFINE_BINARY2(igamma_grad_a, float, double); DEFINE_BINARY2(igammac, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc new file mode 100644 index 0000000000..fd0a95ecc5 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY2(random_gamma_grad, float, double); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_igammas.cc b/tensorflow/core/kernels/cwise_op_igammas.cc index 4b5f888bc1..cadda3b723 100644 --- a/tensorflow/core/kernels/cwise_op_igammas.cc +++ b/tensorflow/core/kernels/cwise_op_igammas.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/kernels/cwise_ops_gradients.h" namespace tensorflow { REGISTER2(BinaryOp, CPU, "Igamma", functor::igamma, float, double); +REGISTER2(BinaryOp, CPU, "IgammaGradA", functor::igamma_grad_a, float, double); REGISTER2(BinaryOp, CPU, "Igammac", functor::igammac, float, double); #if GOOGLE_CUDA REGISTER2(BinaryOp, GPU, "Igamma", functor::igamma, float, double); +REGISTER2(BinaryOp, GPU, "IgammaGradA", functor::igamma_grad_a, float, double); REGISTER2(BinaryOp, GPU, "Igammac", functor::igammac, float, double); #endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_random_grad.cc b/tensorflow/core/kernels/cwise_op_random_grad.cc new file mode 100644 index 0000000000..8e388ead9e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_random_grad.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER2(BinaryOp, CPU, "RandomGammaGrad", functor::random_gamma_grad, float, + double); +#if GOOGLE_CUDA +REGISTER2(BinaryOp, GPU, "RandomGammaGrad", functor::random_gamma_grad, float, + double); +#endif +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 8b015df4e1..1b1a704d42 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -771,6 +771,10 @@ template <typename T> struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; template <typename T> +struct random_gamma_grad + : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {}; + +template <typename T> struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; template <typename T> diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h index 82cdae9a34..7a6f14babc 100644 --- a/tensorflow/core/kernels/cwise_ops_gradients.h +++ b/tensorflow/core/kernels/cwise_ops_gradients.h @@ -202,6 +202,9 @@ struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {}; template <typename T> struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {}; +template <typename T> +struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {}; + } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index f33e9cec29..b476a452a5 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -779,7 +779,7 @@ class OneShotIteratorOp : public AsyncOpKernel { } private: - void Init(OpKernelContext* ctx, DoneCallback done) { + void Init(OpKernelContext* ctx, const DoneCallback& done) { IteratorResource* iterator = nullptr; ContainerInfo cinfo; Status s = TryInit(ctx, &iterator, &cinfo); diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 586677a2d6..aa40f95cde 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -219,8 +219,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } std::swap(result, batch_results_.front()); batch_results_.pop_front(); - cond_var_.notify_all(); } + cond_var_.notify_all(); return ProcessBatch(ctx, result, out_tensors, end_of_sequence); } @@ -286,7 +286,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void Callback(const std::shared_ptr<IteratorContext>& ctx, const std::shared_ptr<BatchResult>& result, const std::shared_ptr<std::vector<Tensor>>& return_values, - int64 offset, const Status& status) { + int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) { result->UpdateStatus(status); if (status.ok()) { EnsureOutputAllocated(ctx, result, return_values); @@ -318,36 +318,37 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(result->mu); result->num_elements++; } - { - mutex_lock l(mu_); - CallCompleted(result); - } + CallCompleted(result); } void CallCompleted(const std::shared_ptr<BatchResult>& result) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - num_calls_--; + LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + num_calls_--; + result->num_calls--; + } cond_var_.notify_all(); - 
result->num_calls--; } void CallFunction(std::shared_ptr<IteratorContext> ctx, const std::shared_ptr<BatchResult>& result, - int64 offset) { + int64 offset) LOCKS_EXCLUDED(mu_) { // Get the next input element. std::vector<Tensor> input_element; bool end_of_input; Status status = input_impl_->GetNext(ctx.get(), &input_element, &end_of_input); + bool return_early; { - mutex_lock l(mu_); - mutex_lock l2(result->mu); + mutex_lock l(result->mu); result->end_of_input = result->end_of_input || end_of_input; result->status.Update(status); - if (result->end_of_input || !result->status.ok()) { - CallCompleted(result); - return; - } + return_early = result->end_of_input || !result->status.ok(); + } + if (return_early) { + CallCompleted(result); + return; } // Call `captured_func_(input_element)`, using `Callback` to store the @@ -468,36 +469,43 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return result->status; } - void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { - mutex_lock l(mu_); + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) + LOCKS_EXCLUDED(mu_) { + std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; + new_calls.reserve(dataset()->num_parallel_calls_); while (true) { - while (!cancelled_ && - (num_calls_ >= dataset()->num_parallel_calls_ || - batch_results_.size() > MaxBatchResults() || - (batch_results_.size() == MaxBatchResults() && - call_counter_ % dataset()->batch_size_ == 0))) { - cond_var_.wait(l); - } + { + mutex_lock l(mu_); + while (!cancelled_ && + (num_calls_ >= dataset()->num_parallel_calls_ || + batch_results_.size() > MaxBatchResults() || + (batch_results_.size() == MaxBatchResults() && + call_counter_ % dataset()->batch_size_ == 0))) { + cond_var_.wait(l); + } - if (cancelled_) { - return; - } + if (cancelled_) { + return; + } - while (num_calls_ < dataset()->num_parallel_calls_ && - (batch_results_.size() < MaxBatchResults() || - (batch_results_.size() == MaxBatchResults() && - call_counter_ % dataset()->batch_size_ != 0))) { - if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.emplace_back( - new BatchResult(dataset()->batch_size_)); + while (num_calls_ < dataset()->num_parallel_calls_ && + (batch_results_.size() < MaxBatchResults() || + (batch_results_.size() == MaxBatchResults() && + call_counter_ % dataset()->batch_size_ != 0))) { + if (call_counter_ % dataset()->batch_size_ == 0) { + batch_results_.emplace_back( + new BatchResult(dataset()->batch_size_)); + } + int64 offset = call_counter_++ % dataset()->batch_size_; + new_calls.emplace_back(batch_results_.back(), offset); + num_calls_++; } - std::shared_ptr<BatchResult> result = batch_results_.back(); - int64 offset = call_counter_++ % dataset()->batch_size_; - num_calls_++; - mu_.unlock(); - CallFunction(ctx, result, offset); - mu_.lock(); } + + for (const auto& call : new_calls) { + CallFunction(ctx, call.first, call.second); + } + new_calls.clear(); } } diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 3fa6b0d3a9..15f3dc3b1d 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -151,8 +151,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params), - invocation_results_(params.dataset->num_parallel_calls_) {} + : 
DatasetIterator<Dataset>(params) {} ~Iterator() override { // TODO(mrry): Replace this cancellation logic with a @@ -160,13 +159,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { // but it would be possible to thread a cancellation manager // through the IteratorContext to upstream, // potentially-blocking iterators, when we add these. - { - mutex_lock l(mu_); - for (size_t i = 0; i < dataset()->num_parallel_calls_; ++i) { - if (invocation_results_[i].notification) { - invocation_results_[i].notification->WaitForNotification(); - } - } + mutex_lock l(mu_); + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); } } @@ -177,173 +176,191 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { - mutex_lock l(mu_); - - // Ensure that there are `dataset()->num_parallel_calls_` - // invocations of `func_` outstanding at once. - while (input_impl_ && (num_inputs_consumed_ - num_outputs_consumed_ < - dataset()->num_parallel_calls_)) { - InvokeFunctionLocked(ctx); - } - - if (!input_impl_ && num_inputs_consumed_ == num_outputs_consumed_) { - *end_of_sequence = true; - return Status::OK(); - } - - // Read the next result out of `invocation_results_`, which - // acts as a circular buffer. - const size_t result_index = - num_outputs_consumed_ % dataset()->num_parallel_calls_; - InvocationResult* result = &invocation_results_[result_index]; - *end_of_sequence = false; - if (result->notification) { - result->notification->WaitForNotification(); - if (result->status.ok()) { - std::swap(*out_tensors, result->return_values); + std::shared_ptr<InvocationResult> result; + { + mutex_lock l(mu_); + EnsureRunnerThreadStarted(ctx); + while (invocation_results_.empty()) { + cond_var_.wait(l); } + std::swap(result, invocation_results_.front()); + invocation_results_.pop_front(); } - ++num_outputs_consumed_; - if (errors::IsOutOfRange(result->status)) { - // `f` may deliberately raise `errors::OutOfRange` to indicate - // that we should terminate the iteration early. - *end_of_sequence = true; - return Status::OK(); - } else { - return result->status; - } + cond_var_.notify_all(); + result->notification.WaitForNotification(); + return ProcessResult(result, out_tensors, end_of_sequence); } protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - if (input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("end_of_input"), "")); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); } - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_inputs_consumed"), - num_inputs_consumed_)); + CHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name("num_outputs_consumed"), num_outputs_consumed_)); - - for (size_t i = 0; i < dataset()->num_parallel_calls_; i++) { - if (invocation_results_[i].notification) { - invocation_results_[i].notification->WaitForNotification(); - TF_RETURN_IF_ERROR( - WriteStatusLocked(writer, i, invocation_results_[i].status)); - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("invocation_results[", i, "].size")), - invocation_results_[i].return_values.size())); - for (size_t j = 0; j < invocation_results_[i].return_values.size(); - j++) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name( - strings::StrCat("invocation_results[", i, "][", j, "]")), - invocation_results_[i].return_values[j])); - } - } else { + full_name("invocation_results.size"), invocation_results_.size())); + for (size_t i = 0; i < invocation_results_.size(); i++) { + std::shared_ptr<InvocationResult> result = invocation_results_[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat("invocation_results[", i, "][", j, "]")), + result->return_values[j])); + } + if (result->end_of_input) { TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("invocation_results[", i, "]_empty")), + full_name(strings::StrCat("invocation_results[", i, + "].end_of_input")), "")); } } - return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - if (reader->Contains(full_name("end_of_input"))) { - input_impl_.reset(); - } else { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - } - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_inputs_consumed"), - &num_inputs_consumed_)); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_outputs_consumed"), - &num_outputs_consumed_)); - for (size_t i = 0; i < dataset()->num_parallel_calls_; i++) { - InvocationResult* result = &invocation_results_[i]; - *result = InvocationResult(); - if (!reader->Contains(full_name( - strings::StrCat("invocation_results[", i, "]_empty")))) { - result->notification.reset(new Notification); - result->notification->Notify(); - TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); - size_t num_return_values; - { - int64 size; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name(strings::StrCat( - "invocation_results[", i, "].size")), - &size)); - num_return_values = static_cast<size_t>(size); - if (num_return_values != size) { - return errors::InvalidArgument(strings::StrCat( - full_name( - strings::StrCat("invocation_results[", i, "].size")), - ": ", size, " is not a valid value of type size_t.")); - } - } - result->return_values.reserve(num_return_values); - for (size_t j = 0; j < num_return_values; j++) { - result->return_values.emplace_back(); - TF_RETURN_IF_ERROR(reader->ReadTensor( + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + int64 invocation_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name("invocation_results.size"), &invocation_results_size)); 
+ for (size_t i = 0; i < invocation_results_size; i++) { + std::shared_ptr<InvocationResult> result(new InvocationResult()); + invocation_results_.push_back(result); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + size_t num_return_values; + { + int64 size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + &size)); + num_return_values = static_cast<size_t>(size); + if (num_return_values != size) { + return errors::InvalidArgument(strings::StrCat( full_name( - strings::StrCat("invocation_results[", i, "][", j, "]")), - &result->return_values.back())); + strings::StrCat("invocation_results[", i, "].size")), + ": ", size, " is not a valid value of type size_t.")); } } + result->return_values.reserve(num_return_values); + for (size_t j = 0; j < num_return_values; j++) { + result->return_values.emplace_back(); + TF_RETURN_IF_ERROR( + reader->ReadTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + &result->return_values.back())); + } + result->end_of_input = reader->Contains(full_name( + strings::StrCat("invocation_results[", i, "].end_of_input"))); + result->notification.Notify(); } return Status::OK(); } private: struct InvocationResult { + Notification notification; Status status; - std::unique_ptr<Notification> notification; std::vector<Tensor> return_values; + bool end_of_input; }; - void InvokeFunctionLocked(IteratorContext* ctx) + void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - DCHECK(input_impl_); - DCHECK(num_inputs_consumed_ - num_outputs_consumed_ < - dataset()->num_parallel_calls_); + if (!runner_thread_) { + std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, "runner_thread", + std::bind(&Iterator::RunnerThread, this, ctx_copy))); + } + } - // The result of invoking the function will be written into the next - // slot in `invocation_results_`, which acts as a circular buffer. - const size_t result_index = - num_inputs_consumed_ % dataset()->num_parallel_calls_; - InvocationResult* result = &invocation_results_[result_index]; - *result = InvocationResult(); + void CallCompleted(const std::shared_ptr<InvocationResult>& result) + LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + num_calls_--; + } + result->notification.Notify(); + cond_var_.notify_all(); + } + void CallFunction(const std::shared_ptr<IteratorContext>& ctx, + const std::shared_ptr<InvocationResult>& result) + LOCKS_EXCLUDED(mu_) { // Get the next input element. std::vector<Tensor> input_element; - bool end_of_input = false; - result->status = - input_impl_->GetNext(ctx, &input_element, &end_of_input); - if (end_of_input) { - input_impl_.reset(); - result->status = errors::OutOfRange(""); - } else { - ++num_inputs_consumed_; + result->status = input_impl_->GetNext(ctx.get(), &input_element, + &result->end_of_input); + if (result->end_of_input || !result->status.ok()) { + CallCompleted(result); + return; } - if (result->status.ok()) { - // Call `func_(input_element)`, store the result in - // `result->return_values`, and notify `result->notification` - // to unblock a consumer. 
- result->notification.reset(new Notification); - dataset()->captured_func_->RunAsync( - ctx, std::move(input_element), &result->return_values, - [result, result_index](Status ret_status) { - result->status.Update(ret_status); - result->notification->Notify(); - }); + // Call `func_(input_element)`, store the result in + // `result->return_values`, and notify `result->notification` to unblock + // a consumer. + auto done = [this, result](Status status) { + result->status.Update(status); + CallCompleted(result); + }; + dataset()->captured_func_->RunAsync(ctx.get(), std::move(input_element), + &result->return_values, done); + } + + int64 MaxInvocationResults() { return dataset()->num_parallel_calls_; } + + Status ProcessResult(const std::shared_ptr<InvocationResult>& result, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) { + if (!result->end_of_input && result->status.ok()) { + *out_tensors = std::move(result->return_values); + *end_of_sequence = false; + return Status::OK(); + } + if (errors::IsOutOfRange(result->status)) { + // `f` may deliberately raise `errors::OutOfRange` to indicate that we + // should terminate the iteration early. + *end_of_sequence = true; + return Status::OK(); + } + *end_of_sequence = result->end_of_input; + return result->status; + } + + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + std::vector<std::shared_ptr<InvocationResult>> new_calls; + new_calls.reserve(dataset()->num_parallel_calls_); + while (true) { + { + mutex_lock l(mu_); + while (!cancelled_ && + (num_calls_ >= dataset()->num_parallel_calls_ || + invocation_results_.size() >= MaxInvocationResults())) { + cond_var_.wait(l); + } + if (cancelled_) { + return; + } + while (num_calls_ < dataset()->num_parallel_calls_ && + invocation_results_.size() < MaxInvocationResults()) { + invocation_results_.emplace_back(new InvocationResult()); + new_calls.push_back(invocation_results_.back()); + num_calls_++; + } + } + cond_var_.notify_all(); + for (const auto& call : new_calls) { + CallFunction(ctx, call); + } + new_calls.clear(); } } @@ -386,11 +403,22 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { strings::StrCat("invocation_results[", index, "].error_message")); } + // Used for coordination between the main thread and the runner thread. mutex mu_; - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); - std::vector<InvocationResult> invocation_results_ GUARDED_BY(mu_); - int64 num_inputs_consumed_ GUARDED_BY(mu_) = 0; - int64 num_outputs_consumed_ GUARDED_BY(mu_) = 0; + // Used for coordination between the main thread and the runner thread. In + // particular, the runner thread should only schedule new calls when the + // number of in-flight calls is less than the user specified level of + // parallelism and there are slots available in the `invocation_results_` + // buffer. + condition_variable cond_var_; + // Counts the number of outstanding calls. + int64 num_calls_ GUARDED_BY(mu_) = 0; + std::unique_ptr<IteratorBase> input_impl_; + // Buffer for storing the invocation results. 
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_ + GUARDED_BY(mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_) = false; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index d0703d7576..15004ae4df 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -24,6 +24,14 @@ limitations under the License. // non-GPU targets. This only breaks in clang, because it's more strict for // template code and CudaAtomicMax is used in template context. + +// This file requires the following include because it uses CudaAtomicMax: +// #include "tensorflow/core/util/cuda_kernel_helper.h" + +// Unfortunately we can't add the #include, since it breaks compilation for +// non-GPU targets. This only breaks in clang, because it's more strict for +// template code and CudaAtomicMax is used in template context. + // This file requires the following include because it uses CudaAtomicMax: // #include "tensorflow/core/util/cuda_kernel_helper.h" @@ -138,4 +146,4 @@ struct Highest { } // namespace functor } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc index 3996ff0027..26ab72f12e 100644 --- a/tensorflow/core/kernels/string_split_op.cc +++ b/tensorflow/core/kernels/string_split_op.cc @@ -184,7 +184,7 @@ class StringSplitV2Op : public OpKernel { public: explicit StringSplitV2Op(OpKernelConstruction* context) : OpKernel(context), maxsplit_(-1) { - context->GetAttr("maxsplit", &maxsplit_); + OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_)); } void Compute(OpKernelContext* ctx) override { diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 726bfd63b7..11ed50d30e 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -6426,6 +6426,68 @@ op { } } op { + name: "AsString" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_FLOAT + type: DT_DOUBLE + type: DT_BOOL + } + } + } + attr { + name: "precision" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "scientific" + type: "bool" + default_value { + b: false + } + } + attr { + name: "shortest" + type: "bool" + default_value { + b: false + } + } + attr { + name: "width" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "fill" + type: "string" + default_value { + s: "" + } + } +} +op { name: "Asin" input_arg { name: "x" @@ -25578,6 +25640,31 @@ op { } } op { + name: "IgammaGradA" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "Igammac" input_arg { name: "a" @@ -41254,6 +41341,31 @@ op { is_stateful: true } op { + name: "RandomGammaGrad" + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "sample" + type_attr: "T" + } + output_arg 
{ + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "RandomPoisson" input_arg { name: "shape" @@ -68140,6 +68252,36 @@ op { } } op { + name: "StringSplitV2" + input_arg { + name: "input" + type: DT_STRING + } + input_arg { + name: "sep" + type: DT_STRING + } + output_arg { + name: "indices" + type: DT_INT64 + } + output_arg { + name: "values" + type: DT_STRING + } + output_arg { + name: "shape" + type: DT_INT64 + } + attr { + name: "maxsplit" + type: "int" + default_value { + i: -1 + } + } +} +op { name: "StringStrip" input_arg { name: "input" @@ -72671,6 +72813,73 @@ op { } } op { + name: "UnsortedSegmentProd" + input_arg { + name: "data" + type_attr: "T" + } + input_arg { + name: "segment_ids" + type_attr: "Tindices" + } + input_arg { + name: "num_segments" + type_attr: "Tnumsegments" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tnumsegments" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "UnsortedSegmentSum" input_arg { name: "data" diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 0f2d438a03..fd59622b27 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -489,6 +489,13 @@ REGISTER_OP("Igamma") .Attr("T: {float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +REGISTER_OP("IgammaGradA") + .Input("a: T") + .Input("x: T") + .Output("z: T") + .Attr("T: {float, double}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); + REGISTER_OP("Zeta") .Input("x: T") .Input("q: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index c609703bcb..c7f74c205a 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -1977,13 +1977,14 @@ op { type: "type" allowed_values { list { + type: DT_INT8 + type: DT_INT16 type: DT_INT32 type: DT_INT64 type: DT_COMPLEX64 type: DT_FLOAT type: DT_DOUBLE type: DT_BOOL - type: DT_INT8 } } } @@ -12446,6 +12447,31 @@ op { } } op { + name: "IgammaGradA" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "Igammac" input_arg { name: "a" @@ -20707,6 +20733,31 @@ op { is_stateful: true } op { + name: "RandomGammaGrad" + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "sample" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "RandomPoisson" input_arg { name: "shape" @@ -31613,6 +31664,36 @@ op { } } op { + name: "StringSplitV2" + input_arg { + name: "input" + type: DT_STRING + } + input_arg { + name: "sep" + type: DT_STRING + } 
+  output_arg {
+    name: "indices"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "values"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "shape"
+    type: DT_INT64
+  }
+  attr {
+    name: "maxsplit"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+}
+op {
   name: "StringStrip"
   input_arg {
     name: "input"
@@ -34534,9 +34615,14 @@ op {
         type: DT_UINT8
         type: DT_INT16
         type: DT_INT8
+        type: DT_COMPLEX64
         type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
         type: DT_BFLOAT16
         type: DT_UINT16
+        type: DT_COMPLEX128
         type: DT_HALF
         type: DT_UINT32
         type: DT_UINT64
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 80ffae5796..a76248e05f 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -138,6 +138,13 @@ REGISTER_OP("RandomGamma")
       return Status::OK();
     });
 
+REGISTER_OP("RandomGammaGrad")
+    .Input("alpha: T")
+    .Input("sample: T")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
 REGISTER_OP("RandomPoisson")
     .SetIsStateful()
     .Input("shape: S")
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 3d0a6c2157..26499540f1 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -14,6 +14,7 @@
 // ============================================================================
 
 #include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/resource_mgr.h"
@@ -84,6 +85,22 @@ REGISTER_OP("ReadVariableOp")
     .Attr("dtype: type")
     .SetShapeFn(ReadVariableShapeFn);
 
+Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  *g = FunctionDefHelper::Define(
+      // Arg defs
+      {"x: resource", "dy: float"},
+      // Ret val defs
+      {"dy: float"},
+      // Attr defs
+      {},
+      // Nodes
+      {});
+  // clang-format on
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("ReadVariableOp", ReadGrad);
+
 REGISTER_OP("DestroyResourceOp")
     .Input("resource: resource")
     .Attr("ignore_lookup_error: bool = true")
diff --git a/tensorflow/core/platform/fingerprint.h b/tensorflow/core/platform/fingerprint.h
index b47dcdedd7..720dc4c3d6 100644
--- a/tensorflow/core/platform/fingerprint.h
+++ b/tensorflow/core/platform/fingerprint.h
@@ -74,7 +74,7 @@ inline uint64 FingerprintCat64(const uint64 fp1, const uint64 fp2) {
 
 }  // namespace tensorflow
 
-#if defined(PLATFORM_GOOGLE)
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
 #include "tensorflow/core/platform/google/fingerprint.h"
 #else
 #include "tensorflow/core/platform/default/fingerprint.h"
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index bbb25d6f3f..07f984ceea 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -80,6 +80,12 @@ message RewriterConfig {
   // is once).
   NumIterationsType meta_optimizer_iterations = 12;
 
+  // The minimum number of nodes in a graph for the optimizer to run on.
+  // For smaller graphs, optimization is skipped.
+  // 0 means the system picks an appropriate number.
+  // < 0 means do not skip optimization.
+  int32 min_graph_nodes = 17;
+
   enum MemOptType {
     // The default setting (SCHEDULING and SWAPPING HEURISTICS only)
     DEFAULT_MEM_OPT = 0;
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index 337af07b50..f4bd2950e9 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -20,12 +20,22 @@ limitations under the License.
 
 namespace tensorflow {
 
+/* ABSL_CONST_INIT */ thread_local int per_thread_max_parallelism = 1000000;
+
+void SetPerThreadMaxParallelism(int max_parallelism) {
+  CHECK_LE(0, max_parallelism);
+  per_thread_max_parallelism = max_parallelism;
+}
+
+int GetPerThreadMaxParallelism() { return per_thread_max_parallelism; }
+
 void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
            int64 cost_per_unit, std::function<void(int64, int64)> work) {
   CHECK_GE(total, 0);
   if (total == 0) {
     return;
   }
+  max_parallelism = std::min(max_parallelism, GetPerThreadMaxParallelism());
   if (max_parallelism <= 1) {
     // Just inline the whole work since we only have 1 thread (core).
     work(0, total);
@@ -35,6 +45,13 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
     workers->ParallelFor(total, cost_per_unit, work);
     return;
   }
+  Sharder::Do(total, cost_per_unit, work,
+              [&workers](Sharder::Closure c) { workers->Schedule(c); },
+              max_parallelism);
+}
+
+void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
+                 const Runner& runner, int max_parallelism) {
   cost_per_unit = std::max(int64{1}, cost_per_unit);
   // We shard [0, total) into "num_shards" shards.
   //   1 <= num_shards <= num worker threads
@@ -63,7 +80,7 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
   BlockingCounter counter(num_shards_used - 1);
   for (int64 start = block_size; start < total; start += block_size) {
     auto limit = std::min(start + block_size, total);
-    workers->Schedule([&work, &counter, start, limit]() {
+    runner([&work, &counter, start, limit]() {
       work(start, limit);        // Compute the shard.
       counter.DecrementCount();  // The shard is done.
     });
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index 451da98b6b..72ce493c1b 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -41,6 +41,12 @@ namespace tensorflow {
 // work(start, limit) computes the work units from [start,
 // limit), i.e., [start, limit) is a shard.
 //
+// Too much parallelism can also cause excessive thread switches;
+// therefore, Shard() often limits the maximum parallelism. Each
+// caller can provide the first argument, max_parallelism. A thread
+// can also call SetPerThreadMaxParallelism() so that all subsequent
+// Shard() calls on that thread limit their parallelism.
+//
 // REQUIRES: max_parallelism >= 0
 // REQUIRES: workers != nullptr
 // REQUIRES: total >= 0
@@ -48,6 +54,45 @@ namespace tensorflow {
 void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
            int64 cost_per_unit, std::function<void(int64, int64)> work);
 
+// Each thread has an associated option to express the desired maximum
+// parallelism. Its default is a very large quantity.
+//
+// Within the TF runtime, per-thread max parallelism affects Shard()
+// and intra-op parallelism. E.g., if SetPerThreadMaxParallelism(1)
+// is called on a tf_compute thread, any Shard() calls and Eigen
+// device work issued from that thread afterwards become
+// single-threaded.
+void SetPerThreadMaxParallelism(int max_parallelism);
+int GetPerThreadMaxParallelism();
+
+// Helper to set and unset per-thread max parallelism.
+class ScopedPerThreadMaxParallelism {
+ public:
+  ScopedPerThreadMaxParallelism(int max_parallelism)
+      : previous_(GetPerThreadMaxParallelism()) {
+    SetPerThreadMaxParallelism(max_parallelism);
+  }
+
+  ~ScopedPerThreadMaxParallelism() { SetPerThreadMaxParallelism(previous_); }
+
+ private:
+  int previous_ = -1;
+};
+
+// Implementation details for Shard().
+class Sharder {
+ public:
+  typedef std::function<void()> Closure;
+  typedef std::function<void(Closure)> Runner;
+  typedef std::function<void(int64, int64)> Work;
+
+  // Refer to Shard()'s comment for the meanings of total,
+  // cost_per_unit, work, and max_parallelism. runner schedules a
+  // closure; Shard() itself uses a thread::ThreadPool for this.
+  static void Do(int64 total, int64 cost_per_unit, const Work& work,
+                 const Runner& runner, int max_parallelism);
+};
+
 }  // end namespace tensorflow
 
 #endif  // TENSORFLOW_UTIL_WORK_SHARDER_H_
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
index 0694566ad9..bc5a1d221f 100644
--- a/tensorflow/core/util/work_sharder_test.cc
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -28,6 +28,7 @@ namespace tensorflow {
 namespace {
 
 void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit,
+                 int64 per_thread_max_parallelism,
                  thread::ThreadPool* threads) {
   mutex mu;
   int64 num_shards = 0;
@@ -46,9 +47,18 @@ void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit,
           work[start] = true;
         }
       });
-  EXPECT_EQ(num_done_work, total);
   LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " "
             << num_shards;
+  EXPECT_EQ(num_done_work, total);
+  if (std::min(num_workers, per_thread_max_parallelism) <
+      threads->NumThreads()) {
+    // If the intention is to limit the parallelism explicitly, we'd
+    // better honor it. Ideally, even if per_thread_max_parallelism >
+    // num_workers, we should expect that the Shard() implementation
+    // does not over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+    // tends to over-shard.
+    EXPECT_LE(num_shards, 1 + per_thread_max_parallelism);
+  }
 }
 
 TEST(Shard, Basic) {
@@ -56,7 +66,10 @@ TEST(Shard, Basic) {
   for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) {
     for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) {
       for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) {
-        RunSharding(workers, total, cost_per_unit, &threads);
+        for (auto maxp : {1, 2, 4, 8, 100}) {
+          ScopedPerThreadMaxParallelism s(maxp);
+          RunSharding(workers, total, cost_per_unit, maxp, &threads);
+        }
      }
    }
  }
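
The per-thread limit above composes with Shard() as in the following sketch of a hypothetical caller. This is not code from the commit; the function name ExampleShardedSum, the pool, and the cost estimate are all illustrative, and only standard tensorflow headers are assumed:

#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Sums `v` with Shard(), capping sharding issued from this thread at 4.
float ExampleShardedSum(thread::ThreadPool* pool,
                        const std::vector<float>& v) {
  // The previous per-thread limit is restored when `scoped` goes out
  // of scope, so the cap does not leak into unrelated work.
  ScopedPerThreadMaxParallelism scoped(4);

  mutex mu;
  float sum = 0;
  // Shard() clamps its first argument with GetPerThreadMaxParallelism(),
  // so at most 4 shards run concurrently regardless of the pool size.
  Shard(pool->NumThreads(), pool, v.size(), /*cost_per_unit=*/10,
        [&v, &mu, &sum](int64 start, int64 limit) {
          float partial = 0;
          for (int64 i = start; i < limit; ++i) partial += v[i];
          mutex_lock l(mu);
          sum += partial;
        });
  return sum;
}

}  // namespace tensorflow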
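
Likewise, factoring the block-size arithmetic into Sharder::Do decouples it from thread::ThreadPool: any Runner that eventually executes each Closure is valid, including a synchronous one. A minimal sketch under that assumption (ExampleInlineSharding is a hypothetical function, not part of the change; it assumes total > 0):

#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Drives Sharder::Do with a runner that executes each shard inline on
// the calling thread. Every closure finishes before Do() reaches its
// BlockingCounter wait, so the call degenerates to a sequential loop
// over the shards while exercising the same sharding logic.
void ExampleInlineSharding(int64 total) {
  Sharder::Do(total, /*cost_per_unit=*/100,
              /*work=*/[](int64 start, int64 limit) {
                // Process the work items in [start, limit) here.
              },
              /*runner=*/[](Sharder::Closure c) { c(); },
              /*max_parallelism=*/8);
}

}  // namespace tensorflow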