Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/BUILD | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt | 5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt | 5
-rw-r--r--  tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc | 3
-rw-r--r--  tensorflow/core/common_runtime/eager/BUILD | 144
-rw-r--r--  tensorflow/core/common_runtime/eager/context.cc | 12
-rw-r--r--  tensorflow/core/common_runtime/eager/context.h | 23
-rw-r--r--  tensorflow/core/common_runtime/eager/execute.cc | 42
-rw-r--r--  tensorflow/core/common_runtime/function.cc | 18
-rw-r--r--  tensorflow/core/common_runtime/function.h | 6
-rw-r--r--  tensorflow/core/common_runtime/graph_execution_state.cc | 9
-rw-r--r--  tensorflow/core/common_runtime/lower_if_op.cc | 2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/BUILD | 2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/BUILD | 16
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h | 97
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc | 11
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h | 10
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 30
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 17
-rw-r--r--  tensorflow/core/framework/device_base.cc | 33
-rw-r--r--  tensorflow/core/framework/device_base.h | 15
-rw-r--r--  tensorflow/core/framework/device_base_test.cc | 62
-rw-r--r--  tensorflow/core/framework/resource_mgr.cc | 4
-rw-r--r--  tensorflow/core/framework/resource_var.h | 2
-rw-r--r--  tensorflow/core/framework/shape_inference.cc | 14
-rw-r--r--  tensorflow/core/framework/shape_inference.h | 2
-rw-r--r--  tensorflow/core/graph/gradients.cc | 11
-rw-r--r--  tensorflow/core/graph/subgraph.cc | 12
-rw-r--r--  tensorflow/core/grappler/clusters/BUILD | 1
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.h | 3
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine.cc | 9
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine.h | 3
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine_test.cc | 8
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.cc | 12
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.h | 5
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc | 72
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h | 5
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc | 53
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 15
-rw-r--r--  tensorflow/core/grappler/op_types.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 66
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 72
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc | 11
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc | 18
-rw-r--r--  tensorflow/core/grappler/optimizers/dependency_optimizer.cc | 12
-rw-r--r--  tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc | 64
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 27
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer_test.cc | 3
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc | 26
-rw-r--r--  tensorflow/core/kernels/cwise_op_igammas.cc | 3
-rw-r--r--  tensorflow/core/kernels/cwise_op_random_grad.cc | 25
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h | 4
-rw-r--r--  tensorflow/core/kernels/cwise_ops_gradients.h | 3
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc | 2
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 90
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 310
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.h | 10
-rw-r--r--  tensorflow/core/kernels/string_split_op.cc | 2
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 209
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 7
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 88
-rw-r--r--  tensorflow/core/ops/random_ops.cc | 7
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc | 17
-rw-r--r--  tensorflow/core/platform/fingerprint.h | 2
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto | 6
-rw-r--r--  tensorflow/core/util/work_sharder.cc | 19
-rw-r--r--  tensorflow/core/util/work_sharder.h | 45
-rw-r--r--  tensorflow/core/util/work_sharder_test.cc | 17
71 files changed, 1457 insertions(+), 510 deletions(-)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b6b48a077c..ef8c3f358a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -89,6 +89,7 @@ load(
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
+ "tf_features_nomodules_if_android",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -998,6 +999,7 @@ tf_gen_op_libs(
"nn_ops",
"no_op",
"parsing_ops",
+ "random_grad",
"random_ops",
"remote_fused_graph_ops",
"resource_variable_ops",
@@ -2339,6 +2341,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [
FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
+ "framework/resource_var.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
@@ -3369,6 +3372,7 @@ tf_cc_tests(
"framework/bfloat16_test.cc",
"framework/cancellation_test.cc",
"framework/common_shape_fns_test.cc",
+ "framework/device_base_test.cc",
"framework/function_test.cc",
"framework/graph_def_util_test.cc",
"framework/graph_to_functiondef_test.cc",
diff --git a/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt b/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt
new file mode 100644
index 0000000000..747a8badfd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "IgammaGradA"
+ visibility: HIDDEN
+ summary: "Computes the gradient of `igamma(a, x)` wrt `a`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt
new file mode 100644
index 0000000000..d2bd76f8b9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "RandomGammaGrad"
+ visibility: HIDDEN
+ summary: "Computes the derivative of a Gamma random sample w.r.t. `alpha`."
+}
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index d66963ec74..b4bf1c408f 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -74,6 +74,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_constant_folding(RewriterConfig::OFF);
+ options.config.mutable_graph_options()
+ ->mutable_rewrite_options()
+ ->set_min_graph_nodes(-1);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index b5120f2872..7f28f3b793 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -22,14 +22,19 @@ tf_cuda_library(
"eager_executor.h",
],
visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
+ deps = select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
tf_cuda_library(
@@ -44,17 +49,23 @@ tf_cuda_library(
deps = [
":eager_executor",
":kernel_and_device",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- "//tensorflow/core/distributed_runtime:worker_session",
- "//tensorflow/core/distributed_runtime/eager:eager_client",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:worker_session",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ ],
+ }),
)
tf_cuda_library(
@@ -86,14 +97,20 @@ tf_cuda_library(
":context",
":eager_executor",
":kernel_and_device",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+ }),
)
tf_cuda_library(
@@ -106,14 +123,19 @@ tf_cuda_library(
":context",
":eager_executor",
":tensor_handle",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+ }),
)
tf_cuda_library(
@@ -125,14 +147,20 @@ tf_cuda_library(
"kernel_and_device.h",
],
visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
+ deps = select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ "//util/hash:farmhash_fingerprint",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
tf_cc_test(
@@ -168,14 +196,20 @@ cc_library(
":eager_operation",
":kernel_and_device",
":tensor_handle",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/distributed_runtime/eager:eager_client",
- "//tensorflow/core/distributed_runtime/eager:remote_execute_node",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ "//tensorflow/core/distributed_runtime/eager:remote_execute_node",
+ ],
+ }),
)
tf_cuda_library(
@@ -183,13 +217,15 @@ tf_cuda_library(
srcs = ["attr_builder.cc"],
hdrs = ["attr_builder.h"],
visibility = ["//tensorflow:internal"],
- deps = select({
+ deps = [
+ ":kernel_and_device",
+ "//tensorflow/c:c_api",
+ ] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
+ "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
- ":kernel_and_device",
- "//tensorflow/c:c_api",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 8381cb58d2..8a87ba7a19 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -38,10 +38,11 @@ EagerContext::EagerContext(const SessionOptions& opts,
InitDeviceMapAndAsync();
}
+#ifndef __ANDROID__
EagerContext::EagerContext(
const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<GrpcServer> server,
+ std::unique_ptr<ServerInterface> server,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts)
@@ -55,12 +56,13 @@ EagerContext::EagerContext(
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
async_default_(async),
+ remote_device_manager_(std::move(remote_device_manager)),
server_(std::move(server)),
remote_eager_workers_(std::move(remote_eager_workers)),
- remote_device_manager_(std::move(remote_device_manager)),
remote_contexts_(remote_contexts) {
InitDeviceMapAndAsync();
}
+#endif
void EagerContext::InitDeviceMapAndAsync() {
if (async_default_) {
@@ -125,10 +127,11 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
}
EagerContext::~EagerContext() {
+#ifndef __ANDROID__
if (server_) {
// TODO(nareshmodi): Fix this.
LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
- "GrpcServer doesn't support clean shutdown.";
+ "Servers don't support clean shutdown.";
server_.release();
}
@@ -158,6 +161,7 @@ EagerContext::~EagerContext() {
}
counter.Wait();
+#endif
executor_.WaitForAllPendingNodes().IgnoreError();
ClearCaches();
@@ -224,6 +228,7 @@ Status GetTaskName(Device* d, string* task_name) {
}
} // namespace
+#ifndef __ANDROID__
Status EagerContext::GetClientAndContextID(Device* device,
eager::EagerClient** client,
uint64* context_id) {
@@ -253,5 +258,6 @@ Status EagerContext::GetClientAndContextID(Device* device,
return Status::OK();
}
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 096ed3112e..601b9e4545 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -29,8 +29,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#endif
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -75,21 +77,22 @@ class EagerContext {
// workers.
//
// Additional remote-specific args are:
- // - server: A GrpcServer that exports the tensorflow.WorkerService. Note
- // that this class expects the server to already have been started.
+ // - server: A ServerInterface that exports the tensorflow.WorkerService.
+ // Note that this class expects the server to already have been started.
// - remote_eager_workers: A cache from which we can get "EagerClient"s to
// communicate with remote eager services.
// - remote_device_mgr: A DeviceMgr* which contains all remote devices
// (should contain no local devices).
// - remote_contexts: A map containing task name to remote context ID.
+#ifndef __ANDROID__
explicit EagerContext(
const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<GrpcServer> server,
+ std::unique_ptr<ServerInterface> server,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts);
-
+#endif
~EagerContext();
// Returns the function library runtime for the given device.
@@ -174,9 +177,10 @@ class EagerContext {
FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
+#ifndef __ANDROID__
Status GetClientAndContextID(Device* device, eager::EagerClient** client,
uint64* context_id);
-
+#endif
private:
void InitDeviceMapAndAsync();
@@ -228,16 +232,19 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
+ const std::unique_ptr<DeviceMgr> remote_device_manager_;
+
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
// be).
- std::unique_ptr<GrpcServer> server_;
+#ifndef __ANDROID__
+ std::unique_ptr<ServerInterface> server_;
const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
- const std::unique_ptr<DeviceMgr> remote_device_manager_;
const gtl::FlatMap<string, uint64> remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
+#endif
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index c619857b78..14aa520e19 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
+#endif
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -39,6 +41,11 @@ namespace tensorflow {
namespace {
+// Copy of the definition in third_party/tensorflow/compiler/jit/defs.h
+// Copied here because we don't currently compile XLA on windows. So, can't
+// depend on it directly.
+const char* const kXlaCompileAttr = "_XlaCompile";
+
// Initializes the step stats if needed.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
// Lazily initialize the RunMetadata with information about all devices if
@@ -472,6 +479,15 @@ Status EagerLocalExecute(EagerOperation* op,
device == nullptr ? "unspecified" : device->name());
KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
if (kernel == nullptr) {
+ // If we are running a function on explicitly requested TPU,
+ // compile it with XLA.
+ // Note that it is not ideal, but currently ok, to set this
+ // attribute after computing the kernel cache key above.
+ if (op->is_function() && device != nullptr &&
+ device->device_type() == "TPU") {
+ op->MutableAttrs()->Set(kXlaCompileAttr, true);
+ }
+
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
if (device == nullptr) {
status = SelectDevice(ndef, ctx, &device);
@@ -559,9 +575,19 @@ Status EagerLocalExecute(EagerOperation* op,
return status;
}
-Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
- uint64 context_id, TensorHandle** retvals,
+Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) {
+#ifdef __ANDROID__
+ return errors::Unimplemented(
+ "Eager's remote execution is not available on Android devices.");
+#else
+ EagerContext* ctx = op->EagerContext();
+
+ eager::EagerClient* eager_client;
+ uint64 context_id;
+ TF_RETURN_IF_ERROR(
+ ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
+
eager::EnqueueRequest request;
eager::EnqueueResponse response;
@@ -622,7 +648,6 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
}
tensorflow::Device* op_device = op->Device();
- EagerContext* ctx = op->EagerContext();
const tensorflow::uint64 id = remote_op->id();
for (int i = 0; i < *num_retvals; i++) {
@@ -657,6 +682,7 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
}
return Status::OK();
+#endif
}
} // namespace
@@ -669,15 +695,7 @@ Status EagerExecute(EagerOperation* op,
return EagerLocalExecute(op, retvals, num_retvals);
}
- auto* ctx = op->EagerContext();
-
- tensorflow::eager::EagerClient* eager_client;
- tensorflow::uint64 context_id;
- TF_RETURN_IF_ERROR(
- ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
-
- return EagerRemoteExecute(op, eager_client, context_id, retvals->data(),
- num_retvals);
+ return EagerRemoteExecute(op, retvals->data(), num_retvals);
}
Status EagerExecute(EagerContext* ctx, Device* device,
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 68d37ddbcd..6d8cea8297 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1188,11 +1188,13 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
return true;
}
-// Given a "caller" in "graph", which is a function call of a function
+// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
-// edges properly.
+// edges properly. "override_device" specifies whether inlining should replace
+// explicitly specified devices inside fbody with the callee's device.
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
- Node* caller, const FunctionBody* fbody) {
+ Node* caller, const FunctionBody* fbody,
+ bool override_device) {
if (!ValidateInlining(caller, fbody)) {
LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
<< DebugString(fbody->graph);
@@ -1227,7 +1229,9 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
- ndef.set_device(caller->def().device());
+ if (override_device || ndef.device().empty()) {
+ ndef.set_device(caller->def().device());
+ }
Node* clone = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
node_map[n->id()] = clone;
@@ -1581,6 +1585,12 @@ FunctionBody* SymbolicGradientHelper::Compute() {
g->RemoveNode(n);
}
gbody_->ret_types = fbody_->arg_types;
+ // TODO(apassos): use the right dtype for gradients of resource variables
+ for (int i = 0; i < gbody_->ret_types.size(); ++i) {
+ if (gbody_->ret_types[i] == DT_RESOURCE) {
+ gbody_->ret_types[i] = DT_FLOAT;
+ }
+ }
gbody_->ret_nodes.clear();
// Add new return nodes to the function gradient body for each node
// in 'x_grad_nodes'.
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index a0f9fcae0a..a274f1ef51 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -155,9 +155,11 @@ FunctionBody* SymbolicGradient(const FunctionBody& f);
// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
-// edges properly.
+// edges properly. "override_device" specifies whether inlining should replace
+// explicitly specified devices inside fbody with the callee's device.
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
- Node* caller, const FunctionBody* fbody);
+ Node* caller, const FunctionBody* fbody,
+ bool override_device = true);
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index eb710bdbc5..58018689d5 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -43,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/util/util.h"
#ifndef IS_MOBILE_PLATFORM
-#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
@@ -476,21 +475,15 @@ Status GraphExecutionState::OptimizeGraph(
}
}
- std::unordered_map<string, DeviceProperties> device_map;
Device* cpu_device = nullptr;
for (const auto& device : device_set_->devices()) {
- DeviceProperties props = grappler::GetDeviceInfo(device->parsed_name());
- if (props.type() == "UNKNOWN") {
- continue;
- }
- device_map[device->name()] = props;
if (device->parsed_name().id == 0 &&
StringPiece(device->parsed_name().type) == "CPU" &&
device->GetAllocator(AllocatorAttributes()) != nullptr) {
cpu_device = device;
}
}
- grappler::VirtualCluster cluster(device_map, device_set_);
+ grappler::VirtualCluster cluster(device_set_);
GraphDef new_graph;
TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
item, rewrite_options, cpu_device, &cluster, &new_graph));
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index 567c81870c..dfce7c23e7 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -206,7 +206,7 @@ Status InlineCallInGraph(Node* n, Graph* g) {
&fbody));
// TODO(jpienaar): Improve this interface to make the need to delete it
// explicit.
- InlineFunctionBody(g->flib_def(), g, n, fbody);
+ InlineFunctionBody(g->flib_def(), g, n, fbody, false);
delete fbody;
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 66c4e5d7a9..4a10d99a60 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -286,7 +286,9 @@ cc_library(
"//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
"//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
index 6b44d8cecf..d09a85c6a5 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
@@ -43,25 +43,11 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:ptr_util",
"//tensorflow/core/distributed_runtime/eager:eager_service_impl",
+ "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
],
)
-
-cc_library(
- name = "eager_grpc_server_lib",
- hdrs = ["eager_grpc_server_lib.h"],
- deps = [
- ":grpc_eager_service_impl",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
- "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
- "//tensorflow/core/distributed_runtime/eager:eager_service_impl",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
- "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
- ],
-)
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
deleted file mode 100644
index 9b863ccee5..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
-#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
-#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
-
-namespace tensorflow {
-namespace eager {
-
-class EagerGrpcServer : public GrpcServer {
- public:
- static Status Create(const ServerDef& server_def,
- std::unique_ptr<EagerGrpcServer>* server) {
- std::unique_ptr<EagerGrpcServer> ret(new EagerGrpcServer(server_def));
-
- TF_RETURN_IF_ERROR(ret->InitEager());
-
- *server = std::move(ret);
-
- return Status::OK();
- }
-
- Status Start() override {
- TF_RETURN_IF_ERROR(GrpcServer::Start());
-
- eager_service_->Start();
-
- return Status::OK();
- }
-
- Status Stop() override {
- TF_RETURN_IF_ERROR(GrpcServer::Stop());
-
- eager_service_->Stop();
-
- return Status::OK();
- }
-
- using GrpcServer::channel_cache;
- using GrpcServer::master_env;
- using GrpcServer::worker_env;
-
- private:
- EagerGrpcServer(const ServerDef& server_def)
- : GrpcServer(server_def, Env::Default()),
- worker_name_(
- strings::StrCat("/job:", server_def.job_name(),
- "/replica:0/task:", server_def.task_index())) {}
-
- Status InitEager() {
- TF_RETURN_IF_ERROR(this->Init(
- [this](const WorkerEnv* worker_env,
- ::grpc::ServerBuilder* server_builder) {
- this->eager_service_.reset(
- new eager::GrpcEagerServiceImpl(worker_env, server_builder));
- },
- nullptr, nullptr));
-
- worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
- "", worker_name_,
- std::unique_ptr<WorkerCacheInterface>(
- new WorkerCacheWrapper(master_env()->worker_cache)),
- worker_env()->device_mgr, {});
-
- auto* r = worker_env()->rendezvous_mgr->Find(0);
- return r->Initialize(worker_session_.get());
- }
-
- std::unique_ptr<GrpcEagerServiceImpl> eager_service_;
- std::shared_ptr<WorkerSession> worker_session_;
- const string worker_name_;
-}; // namespace eager
-
-} // namespace eager
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
index b36c6dce86..52e06c263d 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
@@ -18,10 +18,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -36,7 +34,7 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl(
cq_ = server_builder->AddCompletionQueue();
}
-void GrpcEagerServiceImpl::DriveCQ() {
+void GrpcEagerServiceImpl::HandleRPCsLoop() {
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcEagerServiceImpl, \
@@ -74,12 +72,7 @@ void GrpcEagerServiceImpl::DriveCQ() {
}
}
-void GrpcEagerServiceImpl::Start() {
- // TODO(nareshmodi) separate thread for driving CQ
- request_handler_threadpool_->Schedule([this]() { DriveCQ(); });
-}
-
-void GrpcEagerServiceImpl::Stop() {
+void GrpcEagerServiceImpl::Shutdown() {
// This enqueues a special event (with a null tag)
// that causes the completion queue to be shut down on the
// polling thread.
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index e94aedf535..9a94026342 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -20,16 +20,16 @@ limitations under the License.
#include "grpcpp/completion_queue.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
namespace tensorflow {
namespace eager {
// This class is a wrapper that handles communication for gRPC.
-class GrpcEagerServiceImpl {
+class GrpcEagerServiceImpl : public AsyncServiceInterface {
public:
template <class RequestMessage, class ResponseMessage>
using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
@@ -39,8 +39,8 @@ class GrpcEagerServiceImpl {
::grpc::ServerBuilder* server_builder);
virtual ~GrpcEagerServiceImpl() {}
- void Start();
- void Stop();
+ void HandleRPCsLoop() override;
+ void Shutdown() override;
private:
#define HANDLER(method) \
@@ -66,8 +66,6 @@ class GrpcEagerServiceImpl {
EagerServiceImpl local_impl_;
- void DriveCQ();
-
std::unique_ptr<::grpc::Alarm> shutdown_alarm_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index e7914740ae..aa334f9424 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
@@ -42,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -81,6 +83,7 @@ GrpcServer::~GrpcServer() {
delete master_service_;
delete worker_service_;
+ delete eager_service_;
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them.
@@ -192,6 +195,8 @@ Status GrpcServer::Init(
worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
+ eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
+
// extra service:
if (service_func != nullptr) {
service_func(&worker_env_, &builder);
@@ -264,7 +269,15 @@ Status GrpcServer::Init(
LocalMaster::Register(target(), master_impl_.get(),
config.operation_timeout_in_ms());
- return Status::OK();
+ // Generate a dummy worker session that is used to register the
+ // Rendezvous for eager (we use Step 0 for eager).
+ worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
+ "", name_prefix,
+ std::unique_ptr<WorkerCacheInterface>(
+ new WorkerCacheWrapper(master_env_.worker_cache)),
+ worker_env_.device_mgr, {});
+ auto* r = worker_env()->rendezvous_mgr->Find(0);
+ return r->Initialize(worker_session_.get());
}
Status GrpcServer::Init(
@@ -365,6 +378,9 @@ Status GrpcServer::Start() {
worker_thread_.reset(
env_->StartThread(ThreadOptions(), "TF_worker_service",
[this] { worker_service_->HandleRPCsLoop(); }));
+ eager_thread_.reset(
+ env_->StartThread(ThreadOptions(), "TF_eager_service",
+ [this] { eager_service_->HandleRPCsLoop(); }));
state_ = STARTED;
LOG(INFO) << "Started server with target: " << target();
return Status::OK();
@@ -407,6 +423,7 @@ Status GrpcServer::Join() {
case STOPPED:
master_thread_.reset();
worker_thread_.reset();
+ eager_thread_.reset();
return Status::OK();
default:
LOG(FATAL);
@@ -443,6 +460,17 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
return Status::OK();
}
+/* static */
+Status GrpcServer::Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server) {
+ std::unique_ptr<GrpcServer> ret(
+ new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
+ ServiceInitFunction service_func = nullptr;
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr));
+ *out_server = std::move(ret);
+ return Status::OK();
+}
+
namespace {
class GrpcServerFactory : public ServerFactory {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 9e53330f85..115148b84e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -63,6 +63,8 @@ class GrpcServer : public ServerInterface {
public:
static Status Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server);
+ static Status Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server);
// Destruction is only supported in the factory method. Clean
// shutdown is not currently implemented for this server type.
@@ -74,6 +76,11 @@ class GrpcServer : public ServerInterface {
Status Join() override;
const string target() const override;
+ WorkerEnv* worker_env() { return &worker_env_; }
+ MasterEnv* master_env() { return &master_env_; }
+
+ std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
+
protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
@@ -115,11 +122,6 @@ class GrpcServer : public ServerInterface {
// This method may only be called after `this->Init()` returns successfully.
int bound_port() const { return bound_port_; }
- WorkerEnv* worker_env() { return &worker_env_; }
- MasterEnv* master_env() { return &master_env_; }
-
- std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
-
const ServerDef& server_def() const { return server_def_; }
private:
@@ -158,6 +160,11 @@ class GrpcServer : public ServerInterface {
AsyncServiceInterface* worker_service_ = nullptr;
std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
+ // TensorFlow Eager implementation, and RPC polling thread.
+ AsyncServiceInterface* eager_service_ = nullptr;
+ std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_);
+ std::shared_ptr<WorkerSession> worker_session_;
+
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc
index e30ee84cc3..9108c32942 100644
--- a/tensorflow/core/framework/device_base.cc
+++ b/tensorflow/core/framework/device_base.cc
@@ -13,11 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
+
#include "tensorflow/core/framework/device_base.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/util/work_sharder.h"
+
namespace tensorflow {
-DeviceBase::~DeviceBase() {}
+DeviceBase::~DeviceBase() { gtl::STLDeleteElements(&eigen_cpu_devices_); }
const DeviceAttributes& DeviceBase::attributes() const {
LOG(FATAL) << "Device does not implement attributes()";
@@ -27,4 +33,29 @@ const string& DeviceBase::name() const {
LOG(FATAL) << "Device does not implement name()";
}
+void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
+ // Eigen::ThreadPoolDevice is a very cheap struct (one pointer and
+ // an int). Therefore, we can afford a pre-allocated array of
+ // Eigen::ThreadPoolDevice. Here, we ensure that
+ // Eigen::ThreadPoolDevices in eigen_cpu_devices_ has increasingly
+ // larger numThreads.
+ for (int i = 1; i <= d->numThreads(); ++i) {
+ eigen_cpu_devices_.push_back(
+ new Eigen::ThreadPoolDevice(d->getPool(), i /* numThreads() */));
+ }
+}
+
+const Eigen::ThreadPoolDevice* DeviceBase::eigen_cpu_device() {
+ // Based on GetPerThreadMaxParallelism(), we return a different
+ // pre-allocated Eigen::ThreadPoolDevice. All these ThreadPoolDevice
+ // use the same underlying threadpool. But they use different
+ // nominal numThreads() hoping that the user of the returned
+ // Eigen::ThreadPoolDevice may not aggressively occupy all the
+ // threads in the underlying threadpool.
+ const int parallelism = std::max<int>(
+ 1,
+ std::min<int>(GetPerThreadMaxParallelism(), eigen_cpu_devices_.size()));
+ return eigen_cpu_devices_[parallelism - 1];
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index ec26d92a61..922d34fac9 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
-#include <unordered_map>
+#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -154,9 +154,7 @@ class DeviceBase {
}
// Does not take ownership.
- void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
- eigen_cpu_device_ = d;
- }
+ void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
#ifdef TENSORFLOW_USE_SYCL
void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
@@ -186,11 +184,12 @@ class DeviceBase {
virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
- virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
- CHECK(eigen_cpu_device_ != nullptr);
- return eigen_cpu_device_;
+ const bool has_eigen_cpu_device() const {
+ return !eigen_cpu_devices_.empty();
}
+ virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
+
#ifdef TENSORFLOW_USE_SYCL
virtual const Eigen::SyclDevice* eigen_sycl_device() const {
CHECK(eigen_sycl_device_ != nullptr);
@@ -242,7 +241,7 @@ class DeviceBase {
// Set by GPUs as well as by TPU devices.
GpuDeviceInfo* gpu_device_info_ = nullptr;
thread::ThreadPool* device_thread_pool_ = nullptr;
- Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
+ std::vector<Eigen::ThreadPoolDevice*> eigen_cpu_devices_;
#ifdef TENSORFLOW_USE_SYCL
Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
#endif
diff --git a/tensorflow/core/framework/device_base_test.cc b/tensorflow/core/framework/device_base_test.cc
new file mode 100644
index 0000000000..6909559ea2
--- /dev/null
+++ b/tensorflow/core/framework/device_base_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/device_base.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+TEST(DeviceBaseTest, CpuDevice) {
+ DeviceBase dbase(Env::Default());
+ thread::ThreadPool pool(Env::Default(), "test", 16);
+ EigenThreadPoolWrapper wrapper(&pool);
+ Eigen::ThreadPoolDevice eigen_device(&wrapper, pool.NumThreads());
+ ASSERT_FALSE(dbase.has_eigen_cpu_device());
+ dbase.set_eigen_cpu_device(&eigen_device);
+ ASSERT_TRUE(dbase.has_eigen_cpu_device());
+
+ {
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 16);
+ }
+
+ {
+ ScopedPerThreadMaxParallelism maxp(4);
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 4);
+ }
+
+ {
+ ScopedPerThreadMaxParallelism maxp(1);
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 1);
+ }
+
+ {
+ ScopedPerThreadMaxParallelism maxp(1000);
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 16);
+ }
+}
+
+} // namespace tensorflow
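[Editor's note: the following usage sketch is not part of the patch; the helper function is hypothetical.] With the change above, a kernel that wants to cap its intra-op parallelism no longer needs its own Eigen device: it scopes a per-thread limit and asks the DeviceBase for its Eigen CPU device, which hands back one of the pre-allocated ThreadPoolDevices with the matching nominal thread count, all sharing the same underlying threadpool.

    #define EIGEN_USE_THREADS

    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
    #include "tensorflow/core/framework/device_base.h"
    #include "tensorflow/core/util/work_sharder.h"

    namespace tensorflow {

    // Hypothetical helper: runs Eigen work with at most 4 threads (nominally).
    void RunWithLimitedParallelism(DeviceBase* device) {
      ScopedPerThreadMaxParallelism limit(4);
      // Returns the pre-allocated ThreadPoolDevice whose numThreads() is
      // min(4, size of the underlying pool); the pool itself is unchanged.
      const Eigen::ThreadPoolDevice* d = device->eigen_cpu_device();
      // ... evaluate expressions with out.device(*d) = expr; ...
    }

    }  // namespace tensorflow
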
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 21fc6c1bd5..0a19861efd 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -60,8 +60,8 @@ namespace internal {
Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
if (ctx->device()->attributes().name() != p.device()) {
return errors::InvalidArgument(
- "Trying to access resource located in device ", p.device(),
- " from device ", ctx->device()->attributes().name());
+ "Trying to access resource ", p.name(), " located in device ",
+ p.device(), " from device ", ctx->device()->attributes().name());
}
return Status::OK();
}
diff --git a/tensorflow/core/framework/resource_var.h b/tensorflow/core/framework/resource_var.h
index 872b8f8b30..ff7b3e78a7 100644
--- a/tensorflow/core/framework/resource_var.h
+++ b/tensorflow/core/framework/resource_var.h
@@ -29,6 +29,8 @@ class Var : public ResourceBase {
Var(const Var&) = delete;
Var& operator=(const Var&) = delete;
+ // When locking multiple variables, the locks must be acquired in order of
+ // increasing mu() address.
// TODO(ebrevdo): Use LockSet instead of exposing mu.
mutex* mu() { return &mu_; }
Tensor* tensor() { return &tensor_; }
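[Editor's note: the following illustration is not part of the patch; the helper function is hypothetical.] The locking-order comment added above matters when an op holds several variables at once: acquiring each Var's mu() in increasing address order gives all concurrent ops a consistent global acquisition order over any overlapping set of variables, which rules out deadlock. A minimal sketch for two distinct variables:

    #include "tensorflow/core/framework/resource_var.h"
    #include "tensorflow/core/platform/mutex.h"

    namespace tensorflow {

    // Hypothetical helper: updates two distinct Vars while holding both locks,
    // acquiring them in increasing mu() address order as documented above.
    void UpdateTwoVars(Var* a, Var* b) {
      Var* first = a->mu() < b->mu() ? a : b;
      Var* second = (first == a) ? b : a;
      mutex_lock l1(*first->mu());
      mutex_lock l2(*second->mu());
      // ... read/modify first->tensor() and second->tensor() under both locks ...
    }

    }  // namespace tensorflow
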
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index b02bc3adbe..8d597e198d 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -340,6 +340,20 @@ string InferenceContext::DebugString() const {
ProtoDebugString(*node_def_));
}
+string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
+ return strings::StrCat(DebugString(shape_and_type.shape), ":",
+ DataTypeString(shape_and_type.dtype));
+}
+
+string InferenceContext::DebugString(
+ gtl::ArraySlice<ShapeAndType> shape_and_types) {
+ std::vector<string> pieces;
+ for (const ShapeAndType& s : shape_and_types) {
+ pieces.push_back(DebugString(s));
+ }
+ return strings::StrCat("[", str_util::Join(pieces, ","), "]");
+}
+
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
if (rank > kint32max) {
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 3f3729dcf9..81258b55b3 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -381,6 +381,8 @@ class InferenceContext {
string DebugString(ShapeHandle s);
string DebugString(DimensionHandle d);
+ string DebugString(const ShapeAndType& shape_and_type);
+ string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
// Describes the whole context, for debugging purposes.
string DebugString() const;
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index 6b56613470..c1a8a63784 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -106,8 +106,15 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
AddNodeAttr("Tin", in_types, &ndef);
// The gradient node's outputs have the same types as the node 'n's
- // inputs.
- AddNodeAttr("Tout", n->input_types(), &ndef);
+ // inputs, except for resources.
+ DataTypeVector out_types = n->input_types();
+ for (int i = 0; i < out_types.size(); ++i) {
+ if (out_types[i] == DT_RESOURCE) {
+ // TODO(apassos): figure out how to get the right dtype
+ out_types[i] = DT_FLOAT;
+ }
+ }
+ AddNodeAttr("Tout", out_types, &ndef);
NameAttrList func;
func.set_name(n->type_string());
for (const auto& attr : n->attrs()) {
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index 193cf88aed..60337e30aa 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -81,7 +81,9 @@ Status FeedInputs(
// Update name_index
(*name_index)[feed_node->name()] = feed_node;
- g->AddControlEdge(g->source_node(), feed_node);
+ // Duplicate control edges aren't allowed, but feed_node was *just* created
+ // so there's no need to check for a duplicate.
+ g->AddControlEdge(g->source_node(), feed_node, true);
// Look through edges coming out of "n" for edges whose src_output() index
// matches "output_index". If found, replace the edges with a connection
@@ -107,7 +109,9 @@ Status FeedInputs(
g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
} else {
CHECK_EQ(Graph::kControlSlot, e->src_output());
- g->AddControlEdge(feed_node, e->dst());
+ // Duplicate control edges aren't allowed, but feed_node was *just*
+ // created so there's no need to check for a duplicate.
+ g->AddControlEdge(feed_node, e->dst(), true);
}
g->RemoveEdge(e);
}
@@ -160,7 +164,9 @@ Status FetchOutputs(
// Update the index.
(*name_index)[fetch_node->name()] = fetch_node;
- g->AddControlEdge(fetch_node, g->sink_node());
+ // Duplicate control edges aren't allowed, but fetch_node was *just* created
+ // so there's no need to check for a duplicate.
+ g->AddControlEdge(fetch_node, g->sink_node(), true);
out_fetch_nodes->push_back(fetch_node);
out_fetch_types->push_back(BaseType(n->output_type(id.second)));
}
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index d0b2cf01be..ab8f4bebb3 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -77,6 +77,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cluster",
+ ":utils",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index d33aaa7e4c..06db36b3aa 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -95,7 +95,7 @@ class Cluster {
// The DeviceSet is not always available, but when it is it contains a
// superset of the devices listed in GetDevices/GetDeviceNames().
- const DeviceSet* GetDeviceSet() const { return device_set_; }
+ virtual const DeviceSet* GetDeviceSet() const { return nullptr; }
// Enables collecting the allocator stats. Call with enable=true must be made
// before Provision().
@@ -124,7 +124,6 @@ class Cluster {
protected:
std::unordered_map<string, DeviceProperties> devices_;
- const DeviceSet* device_set_ = nullptr; // Not owned
const int timeout_s_;
SessionOptions options_;
RunOptions run_options_;
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 313ef90d81..b97603c890 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -368,6 +368,15 @@ Status SingleMachine::ResetSession() {
}
coordinator_.reset(new Coordinator());
+ // Build the DeviceSet.
+ device_set_.reset(new DeviceSet);
+ const DeviceMgr* device_mgr;
+ TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
+ for (auto d : device_mgr->ListDevices()) {
+ device_set_->AddDevice(d);
+ // We currently don't care about the client device.
+ }
+
return Status::OK();
}
diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h
index 0ae188e0d6..c0421dd4de 100644
--- a/tensorflow/core/grappler/clusters/single_machine.h
+++ b/tensorflow/core/grappler/clusters/single_machine.h
@@ -43,6 +43,8 @@ class SingleMachine : public Cluster {
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
+ const DeviceSet* GetDeviceSet() const override { return device_set_.get(); }
+
Status EnablePeakMemoryStats(bool enable) override;
// It requires EnableAllocatorStats(true) be called before Provision().
@@ -73,6 +75,7 @@ class SingleMachine : public Cluster {
int64 expected_init_time_s_;
std::unique_ptr<Coordinator> coordinator_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<DeviceSet> device_set_;
RunMetadata init_metadata_;
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
index 352f08fede..31b19cfcfd 100644
--- a/tensorflow/core/grappler/clusters/single_machine_test.cc
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -546,7 +546,7 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory_before));
EXPECT_EQ(device_peak_memory_before.size(), 1);
// There might be a bit memory used before session's running anything.
- EXPECT_LT(device_peak_memory_before.begin()->second, 200);
+ EXPECT_LT(device_peak_memory_before.begin()->second, 400);
RunMetadata metadata;
TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
@@ -567,8 +567,8 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
// Check memory used by resources are released after cluster destruction.
EXPECT_EQ(device_peak_memory_before.size(), 1);
EXPECT_EQ(device_peak_memory_after.size(), 1);
- EXPECT_LT(device_peak_memory_before.begin()->second, 200);
- EXPECT_LT(device_peak_memory_after.begin()->second, 200);
+ EXPECT_LT(device_peak_memory_before.begin()->second, 400);
+ EXPECT_LT(device_peak_memory_after.begin()->second, 400);
}
TEST_F(SingleMachineTest, PeakMemory) {
@@ -597,7 +597,7 @@ TEST_F(SingleMachineTest, PeakMemory) {
device_peak_memory.end());
cpu_memory =
device_peak_memory["/job:localhost/replica:0/task:0/device:CPU:0"];
- EXPECT_LT(cpu_memory, 100);
+ EXPECT_LT(cpu_memory, 200);
}
TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) {
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index 5c9b2320b5..12e3e46f65 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
@@ -38,11 +39,14 @@ VirtualCluster::VirtualCluster(
devices_ = devices;
}
-VirtualCluster::VirtualCluster(
- const std::unordered_map<string, DeviceProperties>& devices,
- const DeviceSet* device_set)
- : VirtualCluster(devices) {
+VirtualCluster::VirtualCluster(const DeviceSet* device_set)
+ : VirtualCluster(std::unordered_map<string, DeviceProperties>()) {
device_set_ = device_set;
+ for (const auto& device : device_set_->devices()) {
+ DeviceProperties props = GetDeviceInfo(device->parsed_name());
+ if (props.type() == "UNKNOWN") continue;
+ devices_[device->name()] = props;
+ }
}
VirtualCluster::~VirtualCluster() {}
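
With this change a VirtualCluster can be constructed directly from a DeviceSet; the device properties are derived from each device's parsed name via GetDeviceInfo(), and devices of unknown type are skipped. A hedged usage sketch (the factory function is illustrative; the DeviceSet is assumed to be owned and populated elsewhere, e.g. from a DeviceMgr as in SingleMachine::ResetSession above):

    #include <memory>

    #include "tensorflow/core/common_runtime/device_set.h"
    #include "tensorflow/core/grappler/clusters/virtual_cluster.h"

    // Sketch: wrap an existing DeviceSet in a VirtualCluster for cost modeling.
    std::unique_ptr<tensorflow::grappler::VirtualCluster> MakeVirtualCluster(
        const tensorflow::DeviceSet* device_set) {
      std::unique_ptr<tensorflow::grappler::VirtualCluster> cluster(
          new tensorflow::grappler::VirtualCluster(device_set));
      // The devices in the DeviceSet are mirrored into the cluster's
      // DeviceProperties map, skipping devices whose type is unknown.
      return cluster;
    }
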
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h
index eebac68e1b..6adb0b99bc 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.h
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.h
@@ -36,8 +36,7 @@ class VirtualCluster : public Cluster {
VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
OpLevelCostEstimator* node_estimator,
ReadyNodeManager* node_manager);
- VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
- const DeviceSet* device_set);
+ VirtualCluster(const DeviceSet* device_set);
~VirtualCluster() override;
@@ -48,10 +47,12 @@ class VirtualCluster : public Cluster {
Status Run(const GraphDef& item,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
+ const DeviceSet* GetDeviceSet() const override { return device_set_; }
private:
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
std::unique_ptr<ReadyNodeManager> node_manager_;
+ const DeviceSet* device_set_ = nullptr; // Not owned
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d9a08d42db..0c02876ac5 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -353,12 +353,12 @@ void VerboseLogUnknownDimensionSources(
class TopoQueue {
public:
explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order)
- : queue_(CompareNodes(topo_order)) {}
- void push(const NodeDef* n) { queue_.insert(n); }
+ : topo_order_(topo_order) {}
+ void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
const NodeDef* pop() {
CHECK(!empty());
auto it = queue_.begin();
- const NodeDef* n = *it;
+ const NodeDef* n = it->first;
queue_.erase(it);
return n;
}
@@ -367,20 +367,16 @@ class TopoQueue {
std::size_t size() const { return queue_.size(); }
private:
+ using NodeAndId = std::pair<const NodeDef*, int>;
// Graph nodes are created in (roughly) topological order. Therefore we can
// use their id to ensure they're sorted topologically.
- struct CompareNodes {
- explicit CompareNodes(
- const std::unordered_map<const NodeDef*, int>& topo_ordering)
- : topo_order(topo_ordering) {}
- bool operator()(const NodeDef* lhs, const NodeDef* rhs) const {
- return topo_order.at(lhs) < topo_order.at(rhs);
+ struct OrderByIdAscending {
+ bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
+ return lhs.second < rhs.second;
}
-
- private:
- const std::unordered_map<const NodeDef*, int>& topo_order;
};
- std::set<const NodeDef*, CompareNodes> queue_;
+ const std::unordered_map<const NodeDef*, int>& topo_order_;
+ std::set<NodeAndId, OrderByIdAscending> queue_;
};
// Processes symbolic shapes.
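
The rewritten TopoQueue keeps (node, topological id) pairs in an ordered set rather than a comparator that looks up the topo-order map on every comparison, so pops come out in ascending topological order and each comparison is a plain integer compare instead of two hash lookups. A self-contained sketch of the same idea, with plain ints standing in for NodeDef pointers:

    #include <cassert>
    #include <set>
    #include <unordered_map>
    #include <utility>

    int main() {
      // Topological ids assigned at graph-construction time.
      std::unordered_map<int, int> topo_order = {{10, 2}, {20, 0}, {30, 1}};

      using NodeAndId = std::pair<int, int>;  // (node, topological id)
      struct OrderByIdAscending {
        bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
          return lhs.second < rhs.second;
        }
      };
      std::set<NodeAndId, OrderByIdAscending> queue;

      for (int node : {10, 20, 30}) queue.emplace(node, topo_order.at(node));

      // Nodes pop in topological order: 20, then 30, then 10.
      assert(queue.begin()->first == 20);
      return 0;
    }
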
@@ -1082,6 +1078,9 @@ Status GraphProperties::UpdateShapes(
// itself.
TF_RETURN_IF_ERROR(
UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
+ } else if (IsQueue(*n)) {
+ // Set shapes and types of Queue ops, if needed.
+ TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
} else {
auto c = shape_refiner->GetNodeContext(n);
if (c && c->op_data && c->op_data->is_function_op) {
@@ -1147,6 +1146,53 @@ Status GraphProperties::PropagateShapes(
return Status::OK();
}
+Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
+ SymbolicShapeRefiner* shape_refiner,
+ bool* new_shapes) {
+ auto ctx = shape_refiner->GetNodeContext(queue_node);
+ if (!ctx) {
+ TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
+ ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
+ }
+ auto* ic = ctx->inference_context.get();
+
+ auto* outputs = ic->output_handle_shapes_and_types(0);
+ if (outputs) {
+ // Shapes and types are already set, presumably by Enqueue ops.
+ return shape_refiner->UpdateNode(queue_node, new_shapes);
+ }
+
+ if (queue_node->attr().count("shapes") <= 0 ||
+ queue_node->attr().count("component_types") <= 0 ||
+ queue_node->attr().at("shapes").list().shape_size() !=
+ queue_node->attr().at("component_types").list().type_size()) {
+    // The shapes and component_types attrs are missing or inconsistent.
+ return shape_refiner->UpdateNode(queue_node, new_shapes);
+ }
+
+ // Extract types and shapes from Queue attr.
+ const auto& shapes = queue_node->attr().at("shapes").list().shape();
+ const auto& types = queue_node->attr().at("component_types").list().type();
+ std::vector<ShapeAndType> shapes_and_types;
+ for (int i = 0; i < types.size(); i++) {
+ const auto& shape = shapes[i];
+ ShapeHandle shape_handle;
+ TF_RETURN_IF_ERROR(
+ ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
+ DataType data_type =
+ queue_node->attr().at("component_types").list().type(i);
+ ShapeAndType shape_and_type(shape_handle, data_type);
+ shapes_and_types.push_back(shape_and_type);
+ }
+ ic->set_output_handle_shapes_and_types(0, shapes_and_types);
+
+ // Queue node is updated with output_handle_shapes_and_types, so set
+  // new_shapes and ignore the result from UpdateNode().
+ *new_shapes = true;
+ bool dummy_new_shapes = false;
+ return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
+}
+
Status GraphProperties::UpdateEnqueue(
const NodeDef* enqueue_node,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 8703613a12..f716cd72c9 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -91,6 +91,11 @@ class GraphProperties {
resource_handles,
SymbolicShapeRefiner* shape_refiner, bool* new_shapes);
+ // Update the shapes and types of the Queue node, if not set by Enqueue node.
+ static Status UpdateQueue(const NodeDef* queue_node,
+ SymbolicShapeRefiner* shape_refiner,
+ bool* new_shapes);
+
// Update the output shapes of a Merge node, and enqueue its fanout in
// new_shapes if needed.
Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 3e44b222fd..aa787ae620 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -262,6 +262,59 @@ TEST_F(GraphPropertiesTest, VarHandles) {
EXPECT_EQ(7, prop.shape().dim(1).size());
}
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: ?", PropToString(props1[0]));
+}
+
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
+ ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: [3,7,1]", PropToString(props1[0]));
+}
+
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
+ ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: [3,7,-1]", PropToString(props1[0]));
+}
+
TEST_F(GraphPropertiesTest, Queues) {
// Create a graph with known input shapes, and propagate the shapes through a
// couple of queues.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2227904dbf..bdeb5c66fc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,6 +135,18 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
+bool IsElementWiseMonotonic(const NodeDef& node) {
+ static const std::unordered_set<string>* element_wise_monotonic_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Relu",
+ "Relu6",
+ "Sigmoid",
+ "Sqrt",
+ "Tanh",
+ }));
+ return element_wise_monotonic_ops->count(node.op()) > 0;
+}
+
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
bool IsEnter(const NodeDef& node) {
@@ -617,7 +629,8 @@ bool HasOpDef(const NodeDef& node) {
}
bool IsIdempotent(const NodeDef& node) {
- return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node);
+ return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
+ !ModifiesFrameInfo(node);
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 7110a9c63d..2de7d8cc9a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,6 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 33c2a0d420..8ca726df0b 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -679,6 +679,7 @@ cc_library(
deps = [
":constant_folding",
":graph_optimizer",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 9d500f8f54..90be051764 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1722,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage {
~RemoveIdempotentStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsIdempotent(*node) && !IsInPreserveSet(*node);
+ return node->input_size() == 1 && IsIdempotent(*node) &&
+ !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- auto root_scope_and_name = ParseNodeScopeAndName(node->name());
- const string new_name = OptimizedNodeName(root_scope_and_name);
- if (input->op() == node->op() && input->device() == node->device() &&
- IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) {
- NodeDef* new_input_node = AddCopyNode(new_name, input);
- ForwardControlDependencies(new_input_node, {node});
- *simplified_node_name = new_input_node->name();
+ if (input->op() == node->op() && input->device() == node->device()) {
+ *simplified_node_name = node->input(0);
}
return Status::OK();
}
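
The copy-node machinery is no longer needed because the stage now only fires on single-input idempotent ops: when f(f(x)) = f(x), the outer node's consumers can be forwarded straight to the inner node. A toy, TensorFlow-free illustration of the idempotence property (Relu is one of the ops that satisfies it):

    #include <algorithm>
    #include <cassert>

    int main() {
      // Relu is idempotent: applying it twice equals applying it once.
      auto relu = [](float v) { return std::max(v, 0.0f); };
      for (float x : {-3.0f, 0.0f, 2.5f}) {
        assert(relu(relu(x)) == relu(x));
      }
      return 0;
    }
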
@@ -2600,6 +2596,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
};
+// Performs conversions like:
+// Max(Sqrt(x)) => Sqrt(Max(x))
+// Checks for a max/min reduction over element-wise monotonic functions, such
+// as Sqrt, Sigmoid, Tanh, etc.
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
+ public:
+ explicit OptimizeMaxOrMinOfMonotonicStage(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
+ ctx_ext) {}
+ ~OptimizeMaxOrMinOfMonotonicStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMax(*node) || IsMin(*node);
+ }
+
+ Status TrySimplify(NodeDef* reduction_node,
+ string* simplified_node_name) override {
+ NodeDef* inner_function;
+ TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
+ // Optimize only if:
+ // 1. inner_function's Op is element-wise monotonic
+ // 2. inner_function's output is not being consumed elsewhere.
+ if (IsElementWiseMonotonic(*inner_function) &&
+ (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ // Swap the first inputs of the inner function Op & the reduction Op.
+ NodeDef* inner_input;
+ TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
+ reduction_node->set_input(0, inner_input->name());
+ UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ }
+ return Status::OK();
+ }
+
+ void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ const string& node_name = node->name();
+ const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (consumer->input(i) == node_name && consumer->name() != new_input) {
+ consumer->set_input(i, new_input);
+ ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
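
The new stage relies on the identity max_i f(x_i) = f(max_i x_i) for an element-wise, monotonically increasing f (and the analogous identity for min), which lets the reduction run before the comparatively expensive element-wise op. A small TensorFlow-free numeric check of that identity:

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <vector>

    int main() {
      std::vector<double> x = {1.0, 4.0, 9.0, 16.0};

      // Max(Sqrt(x)): apply the element-wise op everywhere, then reduce.
      std::vector<double> sqrt_x;
      for (double v : x) sqrt_x.push_back(std::sqrt(v));
      double max_of_sqrt = *std::max_element(sqrt_x.begin(), sqrt_x.end());

      // Sqrt(Max(x)): reduce first, then apply the monotonic function once.
      double sqrt_of_max = std::sqrt(*std::max_element(x.begin(), x.end()));

      assert(max_of_sqrt == sqrt_of_max);  // 4.0 either way.
      return 0;
    }
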
@@ -2878,6 +2926,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
if (options_.convert_log1p)
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
+ if (options_.optimize_max_or_min_of_monotonic)
+ pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 9a6081dcd8..824ef35ef6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -63,6 +63,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool hoist_common_factor_out_of_aggregation = true;
bool hoist_cwise_unary_chains = false;
bool minimize_broadcasts = true;
+ bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
bool remove_identity_transpose = true;
bool remove_involution = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 177c237fe7..d0e6b04679 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -269,6 +269,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.convert_log1p = true;
}
+
+ void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.optimize_max_or_min_of_monotonic = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2971,12 +2976,8 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
- Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
- Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
- Output sn1 =
- ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a);
- Output sn2 =
- ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1);
+ Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
+ Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
Output id1 = ops::Identity(s.WithOpName("id1"), a);
Output id2 = ops::Identity(s.WithOpName("id2"), id1);
@@ -2992,32 +2993,24 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
EnableOnlyRemoveIdempotent(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(11, output.node_size());
+ EXPECT_EQ(7, output.node_size());
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "out1") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0));
- found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") {
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("a", node.input(0));
- EXPECT_EQ("^ctrl1", node.input(1));
- EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("sn1", node.input(0));
found++;
} else if (node.name() == "out2") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0));
+ EXPECT_EQ("id1", node.input(0));
found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") {
- EXPECT_EQ("Identity", node.op());
+ } else if (node.name() == "sn1") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("a", node.input(0));
found++;
}
}
- EXPECT_EQ(4, found);
+ EXPECT_EQ(3, found);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());
@@ -3125,5 +3118,46 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
}
}
+TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+  // Check that the inputs were switched.
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "sqrt") {
+ EXPECT_EQ("Sqrt", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Max", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index f4b384ec1e..76c928f995 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -354,12 +354,14 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
if (op == "TensorArraySizeV3") {
- const NodeDef* array = node_map_->GetNode(node->input(0));
- if (array->attr().count("dynamic_size") != 0 &&
- array->attr().at("dynamic_size").b()) {
+ const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
+ if (array->input_size() == 0 ||
+ (array->attr().count("dynamic_size") != 0 &&
+ array->attr().at("dynamic_size").b())) {
continue;
}
- const NodeDef* array_size = node_map_->GetNode(array->input(0));
+ const NodeDef* array_size =
+ CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
if (IsReallyConstant(*array_size)) {
// Don't materialize 0 sizes to avoid triggering incorrect static
// checks. A 0 sized array that can't grow isn't useful anyway.
@@ -374,6 +376,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (value.flat<int32>()(0) == 0) {
continue;
}
+
node->set_op("Const");
*node->mutable_attr() = array_size->attr();
node->set_input(0, AsControlDependency(NodeName(node->input(0))));
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 9f051ca248..b9765b9292 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3000,6 +3000,10 @@ TEST_F(ConstantFoldingTest, Enter) {
TEST_F(ConstantFoldingTest, TensorArraySize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
+ Output placeholder =
+ ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
+ ops::Placeholder::Shape(TensorShape({2})));
+ Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
auto dynamic_array =
ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
ops::TensorArray::DynamicSize(true));
@@ -3010,6 +3014,8 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
static_array.handle, static_array.flow);
+ auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
+ placeholder, foo);
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
@@ -3026,11 +3032,13 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(5, output.node_size());
- EXPECT_EQ("dynamic_sz", output.node(3).name());
- EXPECT_EQ("TensorArraySizeV3", output.node(3).op());
- EXPECT_EQ("static_sz", output.node(4).name());
- EXPECT_EQ("Const", output.node(4).op());
+ EXPECT_EQ(8, output.node_size());
+ EXPECT_EQ("dynamic_sz", output.node(5).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
+ EXPECT_EQ("static_sz", output.node(6).name());
+ EXPECT_EQ("Const", output.node(6).op());
+ EXPECT_EQ("placeholder_sz", output.node(7).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(7).op());
auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
EXPECT_EQ(2, tensors_expected.size());
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 3f5bab9d3b..fdd82b9603 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -260,14 +260,14 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
}
continue;
}
+ // Replace a normal input with a control input.
const string ctrl_input = ConstantFolding::AddControlDependency(
old_input, optimized_graph_, node_map_.get());
- if (ctrl_inputs.insert(ctrl_input).second) {
- node->set_input(pos, ctrl_input);
- node_map_->UpdateInput(node_name, old_input, ctrl_input);
- const NodeDef* old_input_node = node_map_->GetNode(old_input);
- nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
- }
+ ctrl_inputs.insert(ctrl_input);
+ node->set_input(pos, ctrl_input);
+ node_map_->UpdateInput(node_name, old_input, ctrl_input);
+ const NodeDef* old_input_node = node_map_->GetNode(old_input);
+ nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
++pos;
}
node->set_op("NoOp");
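
The previous code only rewrote an input when its control form was newly inserted into ctrl_inputs, which left a stale data input behind for nodes with repeated inputs such as Add(x, x); the new ChangeToNoop_RepeatedInput test below exercises exactly that case. A minimal sketch of the difference, using plain strings in place of NodeDefs:

    #include <cassert>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> inputs = {"x", "x"};  // e.g. Add(x, x)
      std::set<std::string> ctrl_inputs;

      // Old behaviour: only rewrite when the control input is newly inserted.
      std::vector<std::string> old_result = inputs;
      for (auto& input : old_result) {
        if (ctrl_inputs.insert("^" + input).second) input = "^" + input;
      }
      assert(old_result[1] == "x");  // Second occurrence left as a data input.

      // New behaviour: rewrite every position unconditionally.
      std::vector<std::string> new_result = inputs;
      for (auto& input : new_result) input = "^" + input;
      assert(new_result[0] == "^x" && new_result[1] == "^x");
      return 0;
    }
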
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 0ae3b4ec34..c0f07562af 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -124,25 +124,62 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size(), output.node_size());
+ int found = 0;
for (int i = 0; i < item.graph.node_size(); ++i) {
const NodeDef& node = item.graph.node(i);
- if (node.name() == "add") {
- EXPECT_EQ("NoOp", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^y", node.input(1));
- } else if (node.name() == "id1") {
+ // "add" should get turned into a NoOp and removed.
+ EXPECT_NE("add", node.name());
+ if (node.name() == "id1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
+ ++found;
} else if (node.name() == "id2") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^x", node.input(1));
+ ++found;
+ }
+ }
+ EXPECT_EQ(2, found);
+}
+
+TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+ Output add = ops::Add(s.WithOpName("add"), x, x);
+ Output id1 =
+ ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"id1"};
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ int found = 0;
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ const NodeDef& node = item.graph.node(i);
+ // "add" should get turned into a NoOp and removed.
+ EXPECT_NE("add", node.name());
+ if (node.name() == "id1") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++found;
}
}
+ EXPECT_EQ(1, found);
}
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
@@ -400,6 +437,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 3, output.node_size());
+ int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_NE("id_a", node.name());
EXPECT_NE("id_b", node.name());
@@ -407,30 +445,36 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
if (node.name() == "a_a" || node.name() == "a_b") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("x", node.input(0));
+ ++found;
}
if (node.name() == "a_c" || node.name() == "a_d") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("z", node.input(0));
EXPECT_EQ("^x", node.input(1));
+ ++found;
}
if (node.name() == "b_a") {
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
EXPECT_EQ("^z", node.input(2));
+ ++found;
}
if (node.name() == "c_a") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
+ ++found;
}
if (node.name() == "c_b") {
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("z", node.input(0));
EXPECT_EQ("^x", node.input(1));
EXPECT_EQ("^y", node.input(2));
+ ++found;
}
}
+ EXPECT_EQ(found, 7);
}
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
@@ -460,17 +504,20 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
+ int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_NE("id0", node.name());
if (node.name() == "or0") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("switch:1", node.input(0));
EXPECT_EQ("switch:1", node.input(1));
+ ++found;
}
if (node.name() == "or1") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("switch:1", node.input(0));
EXPECT_EQ("y", node.input(1));
+ ++found;
}
if (node.name() == "or2") {
// or1 should be unchanged.
@@ -478,8 +525,10 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("y", node.input(1));
EXPECT_EQ("^id1", node.input(2));
+ ++found;
}
}
+ EXPECT_EQ(found, 3);
}
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
@@ -535,6 +584,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 2, output.node_size());
+ bool found = false;
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
// "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1",
@@ -545,8 +595,10 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
+ found = true;
}
}
+ EXPECT_TRUE(found);
}
TEST_F(DependencyOptimizerTest, IdentityInputs) {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 143d9dc1c6..b1f31ad0d0 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -42,6 +42,7 @@ namespace grappler {
namespace {
constexpr int kDefaultNumberOfIterations = 2;
+constexpr int kDefaultMinGraphNodes = 4;
int64 NumEdges(const GraphDef& graph) {
int64 num_edges = 0;
@@ -194,6 +195,15 @@ Status MetaOptimizer::InitializeOptimizersByName(
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
+ : cfg_.min_graph_nodes();
+ if (item.graph.node_size() < min_graph_nodes) {
+    VLOG(3) << "Skipping optimization, graph has fewer than " << min_graph_nodes
+ << " nodes.";
+ *optimized_graph = item.graph;
+ return Status::OK();
+ }
+
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
@@ -202,10 +212,11 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
- << " num_optimizers=" << optimizers.size();
+ << " num_optimizers=" << optimizers.size()
+ << ", num nodes = " << item.graph.node_size();
if (optimizers.empty()) {
- VLOG(3) << "Skip graph optimization, no optimizers registered";
+ VLOG(3) << "Skipping graph optimization, no optimizers registered";
*optimized_graph = item.graph;
return Status::OK();
}
@@ -221,8 +232,15 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphOptimizer* sa_optimizer = nullptr;
for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
- VLOG(4) << "Starting optimization iteration " << iteration + 1;
+ // Don't bother optimizing further if the graph is already tiny.
+ if (optimized_graph->node_size() < min_graph_nodes) {
+ VLOG(3) << "Stopping after iteration " << iteration
+ << ", graph is tiny (#nodes = " << optimized_graph->node_size()
+ << " < " << min_graph_nodes << ")";
+ break;
+ }
+ VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
@@ -235,7 +253,6 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get();
continue;
}
-
Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item,
optimized_graph, &optimization_result);
if (status.ok()) is_optimized = true;
@@ -297,7 +314,7 @@ Status MetaOptimizer::RunOptimizer(
PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
", time = ", duration_ms, "ms.");
}
- VLOG(4) << optimizer->name() << ": " << result;
+ VLOG(1) << optimizer->name() << ": " << result;
OptimizerResult optimizer_result{optimizer->name(), result};
optimization_result->results.push_back(optimizer_result);
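
The skip threshold is driven by RewriterConfig::min_graph_nodes: a value of 0 falls back to the default of 4, and a negative value effectively disables the check, which is what the updated tests below rely on. A hedged configuration sketch (the helper function name is illustrative):

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Sketch: force the meta optimizer to run even on tiny test graphs.
    tensorflow::RewriterConfig MakeConfigForTinyGraphs() {
      tensorflow::RewriterConfig config;
      config.set_min_graph_nodes(-1);  // Disable the minimum-size check.
      config.set_meta_optimizer_iterations(tensorflow::RewriterConfig::TWO);
      return config;
    }
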
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 8247cce339..9a03c7dfef 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -74,6 +74,7 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
TestOptimizer::SetOptimized(false);
RewriterConfig rewriter_config;
rewriter_config.add_optimizers("TestOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
GraphDef output;
@@ -89,6 +90,7 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
RewriterConfig rewriter_config;
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
GraphDef output;
@@ -104,6 +106,7 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
rewriter_config.set_function_optimization(RewriterConfig::ON);
rewriter_config.add_optimizers("function");
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
diff --git a/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc
index 5a529bd8ca..508a47deda 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_BINARY2(igamma, float, double);
+DEFINE_BINARY2(igamma_grad_a, float, double);
DEFINE_BINARY2(igammac, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc
new file mode 100644
index 0000000000..fd0a95ecc5
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY2(random_gamma_grad, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_igammas.cc b/tensorflow/core/kernels/cwise_op_igammas.cc
index 4b5f888bc1..cadda3b723 100644
--- a/tensorflow/core/kernels/cwise_op_igammas.cc
+++ b/tensorflow/core/kernels/cwise_op_igammas.cc
@@ -14,12 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER2(BinaryOp, CPU, "Igamma", functor::igamma, float, double);
+REGISTER2(BinaryOp, CPU, "IgammaGradA", functor::igamma_grad_a, float, double);
REGISTER2(BinaryOp, CPU, "Igammac", functor::igammac, float, double);
#if GOOGLE_CUDA
REGISTER2(BinaryOp, GPU, "Igamma", functor::igamma, float, double);
+REGISTER2(BinaryOp, GPU, "IgammaGradA", functor::igamma_grad_a, float, double);
REGISTER2(BinaryOp, GPU, "Igammac", functor::igammac, float, double);
#endif
} // namespace tensorflow
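
For reference, Igamma(a, x) is the regularized lower incomplete gamma function, and the new IgammaGradA kernel evaluates its partial derivative with respect to the shape parameter a (backed by Eigen's scalar_igamma_der_a_op, registered in cwise_ops_gradients.h below):

    P(a, x) = \frac{1}{\Gamma(a)} \int_0^{x} t^{a-1} e^{-t}\, dt,
    \qquad
    \mathrm{IgammaGradA}(a, x) = \frac{\partial}{\partial a} P(a, x).
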
diff --git a/tensorflow/core/kernels/cwise_op_random_grad.cc b/tensorflow/core/kernels/cwise_op_random_grad.cc
new file mode 100644
index 0000000000..8e388ead9e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_random_grad.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(BinaryOp, CPU, "RandomGammaGrad", functor::random_gamma_grad, float,
+ double);
+#if GOOGLE_CUDA
+REGISTER2(BinaryOp, GPU, "RandomGammaGrad", functor::random_gamma_grad, float,
+ double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 8b015df4e1..1b1a704d42 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -771,6 +771,10 @@ template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};
template <typename T>
+struct random_gamma_grad
+ : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};
+
+template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};
template <typename T>
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 82cdae9a34..7a6f14babc 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -202,6 +202,9 @@ struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};
template <typename T>
struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};
+template <typename T>
+struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
+
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index f33e9cec29..b476a452a5 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -779,7 +779,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
}
private:
- void Init(OpKernelContext* ctx, DoneCallback done) {
+ void Init(OpKernelContext* ctx, const DoneCallback& done) {
IteratorResource* iterator = nullptr;
ContainerInfo cinfo;
Status s = TryInit(ctx, &iterator, &cinfo);
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 586677a2d6..aa40f95cde 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -219,8 +219,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
- cond_var_.notify_all();
}
+ cond_var_.notify_all();
return ProcessBatch(ctx, result, out_tensors, end_of_sequence);
}
@@ -286,7 +286,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void Callback(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<BatchResult>& result,
const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) {
+ int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -318,36 +318,37 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(result->mu);
result->num_elements++;
}
- {
- mutex_lock l(mu_);
- CallCompleted(result);
- }
+ CallCompleted(result);
}
void CallCompleted(const std::shared_ptr<BatchResult>& result)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- num_calls_--;
+ LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ num_calls_--;
+ result->num_calls--;
+ }
cond_var_.notify_all();
- result->num_calls--;
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result,
- int64 offset) {
+ int64 offset) LOCKS_EXCLUDED(mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
Status status =
input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
+ bool return_early;
{
- mutex_lock l(mu_);
- mutex_lock l2(result->mu);
+ mutex_lock l(result->mu);
result->end_of_input = result->end_of_input || end_of_input;
result->status.Update(status);
- if (result->end_of_input || !result->status.ok()) {
- CallCompleted(result);
- return;
- }
+ return_early = result->end_of_input || !result->status.ok();
+ }
+ if (return_early) {
+ CallCompleted(result);
+ return;
}
// Call `captured_func_(input_element)`, using `Callback` to store the
@@ -468,36 +469,43 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return result->status;
}
- void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- mutex_lock l(mu_);
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
+ LOCKS_EXCLUDED(mu_) {
+ std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
+ new_calls.reserve(dataset()->num_parallel_calls_);
while (true) {
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
- cond_var_.wait(l);
- }
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ (num_calls_ >= dataset()->num_parallel_calls_ ||
+ batch_results_.size() > MaxBatchResults() ||
+ (batch_results_.size() == MaxBatchResults() &&
+ call_counter_ % dataset()->batch_size_ == 0))) {
+ cond_var_.wait(l);
+ }
- if (cancelled_) {
- return;
- }
+ if (cancelled_) {
+ return;
+ }
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
- if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.emplace_back(
- new BatchResult(dataset()->batch_size_));
+ while (num_calls_ < dataset()->num_parallel_calls_ &&
+ (batch_results_.size() < MaxBatchResults() ||
+ (batch_results_.size() == MaxBatchResults() &&
+ call_counter_ % dataset()->batch_size_ != 0))) {
+ if (call_counter_ % dataset()->batch_size_ == 0) {
+ batch_results_.emplace_back(
+ new BatchResult(dataset()->batch_size_));
+ }
+ int64 offset = call_counter_++ % dataset()->batch_size_;
+ new_calls.emplace_back(batch_results_.back(), offset);
+ num_calls_++;
}
- std::shared_ptr<BatchResult> result = batch_results_.back();
- int64 offset = call_counter_++ % dataset()->batch_size_;
- num_calls_++;
- mu_.unlock();
- CallFunction(ctx, result, offset);
- mu_.lock();
}
+
+ for (const auto& call : new_calls) {
+ CallFunction(ctx, call.first, call.second);
+ }
+ new_calls.clear();
}
}
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 3fa6b0d3a9..15f3dc3b1d 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -151,8 +151,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- invocation_results_(params.dataset->num_parallel_calls_) {}
+ : DatasetIterator<Dataset>(params) {}
~Iterator() override {
// TODO(mrry): Replace this cancellation logic with a
@@ -160,13 +159,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
// but it would be possible to thread a cancellation manager
// through the IteratorContext to upstream,
// potentially-blocking iterators, when we add these.
- {
- mutex_lock l(mu_);
- for (size_t i = 0; i < dataset()->num_parallel_calls_; ++i) {
- if (invocation_results_[i].notification) {
- invocation_results_[i].notification->WaitForNotification();
- }
- }
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
}
}
@@ -177,173 +176,191 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- mutex_lock l(mu_);
-
- // Ensure that there are `dataset()->num_parallel_calls_`
- // invocations of `func_` outstanding at once.
- while (input_impl_ && (num_inputs_consumed_ - num_outputs_consumed_ <
- dataset()->num_parallel_calls_)) {
- InvokeFunctionLocked(ctx);
- }
-
- if (!input_impl_ && num_inputs_consumed_ == num_outputs_consumed_) {
- *end_of_sequence = true;
- return Status::OK();
- }
-
- // Read the next result out of `invocation_results_`, which
- // acts as a circular buffer.
- const size_t result_index =
- num_outputs_consumed_ % dataset()->num_parallel_calls_;
- InvocationResult* result = &invocation_results_[result_index];
- *end_of_sequence = false;
- if (result->notification) {
- result->notification->WaitForNotification();
- if (result->status.ok()) {
- std::swap(*out_tensors, result->return_values);
+ std::shared_ptr<InvocationResult> result;
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty()) {
+ cond_var_.wait(l);
}
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
}
- ++num_outputs_consumed_;
- if (errors::IsOutOfRange(result->status)) {
- // `f` may deliberately raise `errors::OutOfRange` to indicate
- // that we should terminate the iteration early.
- *end_of_sequence = true;
- return Status::OK();
- } else {
- return result->status;
- }
+ cond_var_.notify_all();
+ result->notification.WaitForNotification();
+ return ProcessResult(result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- if (input_impl_) {
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- } else {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("end_of_input"), ""));
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
}
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_inputs_consumed"),
- num_inputs_consumed_));
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name("num_outputs_consumed"), num_outputs_consumed_));
-
- for (size_t i = 0; i < dataset()->num_parallel_calls_; i++) {
- if (invocation_results_[i].notification) {
- invocation_results_[i].notification->WaitForNotification();
- TF_RETURN_IF_ERROR(
- WriteStatusLocked(writer, i, invocation_results_[i].status));
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("invocation_results[", i, "].size")),
- invocation_results_[i].return_values.size()));
- for (size_t j = 0; j < invocation_results_[i].return_values.size();
- j++) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(
- strings::StrCat("invocation_results[", i, "][", j, "]")),
- invocation_results_[i].return_values[j]));
- }
- } else {
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("invocation_results[", i, "]_empty")),
+ full_name(strings::StrCat("invocation_results[", i,
+ "].end_of_input")),
""));
}
}
-
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- if (reader->Contains(full_name("end_of_input"))) {
- input_impl_.reset();
- } else {
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- }
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_inputs_consumed"),
- &num_inputs_consumed_));
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_outputs_consumed"),
- &num_outputs_consumed_));
- for (size_t i = 0; i < dataset()->num_parallel_calls_; i++) {
- InvocationResult* result = &invocation_results_[i];
- *result = InvocationResult();
- if (!reader->Contains(full_name(
- strings::StrCat("invocation_results[", i, "]_empty")))) {
- result->notification.reset(new Notification);
- result->notification->Notify();
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
- size_t num_return_values;
- {
- int64 size;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name(strings::StrCat(
- "invocation_results[", i, "].size")),
- &size));
- num_return_values = static_cast<size_t>(size);
- if (num_return_values != size) {
- return errors::InvalidArgument(strings::StrCat(
- full_name(
- strings::StrCat("invocation_results[", i, "].size")),
- ": ", size, " is not a valid value of type size_t."));
- }
- }
- result->return_values.reserve(num_return_values);
- for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(reader->ReadTensor(
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
full_name(
- strings::StrCat("invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
}
}
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->end_of_input = reader->Contains(full_name(
+ strings::StrCat("invocation_results[", i, "].end_of_input")));
+ result->notification.Notify();
}
return Status::OK();
}
private:
struct InvocationResult {
+ Notification notification;
Status status;
- std::unique_ptr<Notification> notification;
std::vector<Tensor> return_values;
+ bool end_of_input;
};
- void InvokeFunctionLocked(IteratorContext* ctx)
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- DCHECK(input_impl_);
- DCHECK(num_inputs_consumed_ - num_outputs_consumed_ <
- dataset()->num_parallel_calls_);
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+ }
+ }
- // The result of invoking the function will be written into the next
- // slot in `invocation_results_`, which acts as a circular buffer.
- const size_t result_index =
- num_inputs_consumed_ % dataset()->num_parallel_calls_;
- InvocationResult* result = &invocation_results_[result_index];
- *result = InvocationResult();
+ void CallCompleted(const std::shared_ptr<InvocationResult>& result)
+ LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ num_calls_--;
+ }
+ result->notification.Notify();
+ cond_var_.notify_all();
+ }
+ void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
+ LOCKS_EXCLUDED(mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
- bool end_of_input = false;
- result->status =
- input_impl_->GetNext(ctx, &input_element, &end_of_input);
- if (end_of_input) {
- input_impl_.reset();
- result->status = errors::OutOfRange("");
- } else {
- ++num_inputs_consumed_;
+ result->status = input_impl_->GetNext(ctx.get(), &input_element,
+ &result->end_of_input);
+ if (result->end_of_input || !result->status.ok()) {
+ CallCompleted(result);
+ return;
}
- if (result->status.ok()) {
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification`
- // to unblock a consumer.
- result->notification.reset(new Notification);
- dataset()->captured_func_->RunAsync(
- ctx, std::move(input_element), &result->return_values,
- [result, result_index](Status ret_status) {
- result->status.Update(ret_status);
- result->notification->Notify();
- });
+ // Call `func_(input_element)`, store the result in
+ // `result->return_values`, and notify `result->notification` to unblock
+ // a consumer.
+ auto done = [this, result](Status status) {
+ result->status.Update(status);
+ CallCompleted(result);
+ };
+ dataset()->captured_func_->RunAsync(ctx.get(), std::move(input_element),
+ &result->return_values, done);
+ }
+
+ int64 MaxInvocationResults() { return dataset()->num_parallel_calls_; }
+
+ Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ if (!result->end_of_input && result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (errors::IsOutOfRange(result->status)) {
+ // `f` may deliberately raise `errors::OutOfRange` to indicate that we
+ // should terminate the iteration early.
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ *end_of_sequence = result->end_of_input;
+ return result->status;
+ }
+
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ std::vector<std::shared_ptr<InvocationResult>> new_calls;
+ new_calls.reserve(dataset()->num_parallel_calls_);
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ (num_calls_ >= dataset()->num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ cond_var_.wait(l);
+ }
+ if (cancelled_) {
+ return;
+ }
+ while (num_calls_ < dataset()->num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ new_calls.push_back(invocation_results_.back());
+ num_calls_++;
+ }
+ }
+ cond_var_.notify_all();
+ for (const auto& call : new_calls) {
+ CallFunction(ctx, call);
+ }
+ new_calls.clear();
}
}
@@ -386,11 +403,22 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
strings::StrCat("invocation_results[", index, "].error_message"));
}
+ // Used for coordination between the main thread and the runner thread.
mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- std::vector<InvocationResult> invocation_results_ GUARDED_BY(mu_);
- int64 num_inputs_consumed_ GUARDED_BY(mu_) = 0;
- int64 num_outputs_consumed_ GUARDED_BY(mu_) = 0;
+ // Used for coordination between the main thread and the runner thread. In
+ // particular, the runner thread should only schedule new calls when the
+ // number of in-flight calls is less than the user specified level of
+ // parallelism and there are slots available in the `invocation_results_`
+ // buffer.
+ condition_variable cond_var_;
+ // Counts the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<IteratorBase> input_impl_;
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
};
const DatasetBase* const input_;
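
The runner-thread/consumer coordination above follows a standard bounded-buffer pattern. Below is a minimal, standalone sketch of that pattern in plain C++ (no TensorFlow types; kMaxInFlight and the toy Result struct are illustrative stand-ins for num_parallel_calls_ and InvocationResult), showing how a mutex, a condition variable, and a deque of shared results let a producer thread stay ahead of a consumer without unbounded buffering.

#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

struct Result {
  int value = 0;
  bool end_of_input = false;
};

int main() {
  constexpr size_t kMaxInFlight = 4;  // plays the role of num_parallel_calls_
  std::mutex mu;
  std::condition_variable cond_var;
  std::deque<std::shared_ptr<Result>> results;  // plays the role of invocation_results_
  bool cancelled = false;

  // Producer ("runner") thread: only schedules new work while the buffer has room.
  std::thread runner([&] {
    int next = 0;
    while (true) {
      std::unique_lock<std::mutex> l(mu);
      cond_var.wait(l, [&] { return cancelled || results.size() < kMaxInFlight; });
      if (cancelled) return;
      auto r = std::make_shared<Result>();
      r->value = next++;
      r->end_of_input = (r->value >= 10);
      results.push_back(r);
      l.unlock();
      cond_var.notify_all();
    }
  });

  // Consumer: takes results in order until end of input is observed.
  while (true) {
    std::unique_lock<std::mutex> l(mu);
    cond_var.wait(l, [&] { return !results.empty(); });
    std::shared_ptr<Result> r = results.front();
    results.pop_front();
    l.unlock();
    cond_var.notify_all();
    if (r->end_of_input) break;
    std::cout << r->value << "\n";
  }

  {
    std::lock_guard<std::mutex> l(mu);
    cancelled = true;
  }
  cond_var.notify_all();
  runner.join();
  return 0;
}
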
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index d0703d7576..15004ae4df 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -24,6 +24,14 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
+
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
+
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
@@ -138,4 +146,4 @@ struct Highest {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 3996ff0027..26ab72f12e 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -184,7 +184,7 @@ class StringSplitV2Op : public OpKernel {
public:
explicit StringSplitV2Op(OpKernelConstruction* context)
: OpKernel(context), maxsplit_(-1) {
- context->GetAttr("maxsplit", &maxsplit_);
+ OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_));
}
void Compute(OpKernelContext* ctx) override {
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 726bfd63b7..11ed50d30e 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -6426,6 +6426,68 @@ op {
}
}
op {
+ name: "AsString"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_BOOL
+ }
+ }
+ }
+ attr {
+ name: "precision"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+ attr {
+ name: "scientific"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "shortest"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "width"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+ attr {
+ name: "fill"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+}
+op {
name: "Asin"
input_arg {
name: "x"
@@ -25578,6 +25640,31 @@ op {
}
}
op {
+ name: "IgammaGradA"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Igammac"
input_arg {
name: "a"
@@ -41254,6 +41341,31 @@ op {
is_stateful: true
}
op {
+ name: "RandomGammaGrad"
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "sample"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "RandomPoisson"
input_arg {
name: "shape"
@@ -68140,6 +68252,36 @@ op {
}
}
op {
+ name: "StringSplitV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "sep"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "indices"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "values"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "shape"
+ type: DT_INT64
+ }
+ attr {
+ name: "maxsplit"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+}
+op {
name: "StringStrip"
input_arg {
name: "input"
@@ -72671,6 +72813,73 @@ op {
}
}
op {
+ name: "UnsortedSegmentProd"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type_attr: "Tnumsegments"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tnumsegments"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "UnsortedSegmentSum"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0f2d438a03..fd59622b27 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -489,6 +489,13 @@ REGISTER_OP("Igamma")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+REGISTER_OP("IgammaGradA")
+ .Input("a: T")
+ .Input("x: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
REGISTER_OP("Zeta")
.Input("x: T")
.Input("q: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c609703bcb..c7f74c205a 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -1977,13 +1977,14 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_INT8
+ type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_FLOAT
type: DT_DOUBLE
type: DT_BOOL
- type: DT_INT8
}
}
}
@@ -12446,6 +12447,31 @@ op {
}
}
op {
+ name: "IgammaGradA"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Igammac"
input_arg {
name: "a"
@@ -20707,6 +20733,31 @@ op {
is_stateful: true
}
op {
+ name: "RandomGammaGrad"
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "sample"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "RandomPoisson"
input_arg {
name: "shape"
@@ -31613,6 +31664,36 @@ op {
}
}
op {
+ name: "StringSplitV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "sep"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "indices"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "values"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "shape"
+ type: DT_INT64
+ }
+ attr {
+ name: "maxsplit"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+}
+op {
name: "StringStrip"
input_arg {
name: "input"
@@ -34534,9 +34615,14 @@ op {
type: DT_UINT8
type: DT_INT16
type: DT_INT8
+ type: DT_COMPLEX64
type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
type: DT_BFLOAT16
type: DT_UINT16
+ type: DT_COMPLEX128
type: DT_HALF
type: DT_UINT32
type: DT_UINT64
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 80ffae5796..a76248e05f 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -138,6 +138,13 @@ REGISTER_OP("RandomGamma")
return Status::OK();
});
+REGISTER_OP("RandomGammaGrad")
+ .Input("alpha: T")
+ .Input("sample: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
REGISTER_OP("RandomPoisson")
.SetIsStateful()
.Input("shape: S")
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 3d0a6c2157..26499540f1 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -14,6 +14,7 @@
// ============================================================================
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -84,6 +85,22 @@ REGISTER_OP("ReadVariableOp")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FunctionDefHelper::Define(
+ // Arg defs
+ {"x: resource", "dy: float"},
+ // Ret val defs
+ {"dy: float"},
+ // Attr defs
+ {},
+ // Nodes
+ {});
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("ReadVariableOp", ReadGrad);
+
REGISTER_OP("DestroyResourceOp")
.Input("resource: resource")
.Attr("ignore_lookup_error: bool = true")
diff --git a/tensorflow/core/platform/fingerprint.h b/tensorflow/core/platform/fingerprint.h
index b47dcdedd7..720dc4c3d6 100644
--- a/tensorflow/core/platform/fingerprint.h
+++ b/tensorflow/core/platform/fingerprint.h
@@ -74,7 +74,7 @@ inline uint64 FingerprintCat64(const uint64 fp1, const uint64 fp2) {
} // namespace tensorflow
-#if defined(PLATFORM_GOOGLE)
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
#include "tensorflow/core/platform/google/fingerprint.h"
#else
#include "tensorflow/core/platform/default/fingerprint.h"
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index bbb25d6f3f..07f984ceea 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -80,6 +80,12 @@ message RewriterConfig {
// is once).
NumIterationsType meta_optimizer_iterations = 12;
+ // The minimum number of nodes in a graph for the optimizer to run. For
+ // smaller graphs, optimization is skipped.
+ // 0 means the system picks an appropriate number.
+ // < 0 means do not skip optimization.
+ int32 min_graph_nodes = 17;
+
enum MemOptType {
// The default setting (SCHEDULING and SWAPPING HEURISTICS only)
DEFAULT_MEM_OPT = 0;
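
As a rough illustration of how a client could use the new field, here is a sketch that relies only on the protobuf-generated accessors for ConfigProto/GraphOptions/RewriterConfig (the helper name is hypothetical, not part of this change):

#include "tensorflow/core/protobuf/config.pb.h"

// Builds a session config that disables the minimum-graph-size check so that
// even very small graphs still go through the meta-optimizer.
tensorflow::ConfigProto MakeConfigWithoutSizeThreshold() {
  tensorflow::ConfigProto config;
  tensorflow::RewriterConfig* rewriter =
      config.mutable_graph_options()->mutable_rewrite_options();
  rewriter->set_min_graph_nodes(-1);  // < 0 means never skip optimization.
  return config;
}
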
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index 337af07b50..f4bd2950e9 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -20,12 +20,22 @@ limitations under the License.
namespace tensorflow {
+/* ABSL_CONST_INIT */ thread_local int per_thread_max_parallelism = 1000000;
+
+void SetPerThreadMaxParallelism(int max_parallelism) {
+ CHECK_LE(0, max_parallelism);
+ per_thread_max_parallelism = max_parallelism;
+}
+
+int GetPerThreadMaxParallelism() { return per_thread_max_parallelism; }
+
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
int64 cost_per_unit, std::function<void(int64, int64)> work) {
CHECK_GE(total, 0);
if (total == 0) {
return;
}
+ max_parallelism = std::min(max_parallelism, GetPerThreadMaxParallelism());
if (max_parallelism <= 1) {
// Just inline the whole work since we only have 1 thread (core).
work(0, total);
@@ -35,6 +45,13 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
workers->ParallelFor(total, cost_per_unit, work);
return;
}
+ Sharder::Do(total, cost_per_unit, work,
+ [&workers](Sharder::Closure c) { workers->Schedule(c); },
+ max_parallelism);
+}
+
+void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
+ const Runner& runner, int max_parallelism) {
cost_per_unit = std::max(int64{1}, cost_per_unit);
// We shard [0, total) into "num_shards" shards.
// 1 <= num_shards <= num worker threads
@@ -63,7 +80,7 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
BlockingCounter counter(num_shards_used - 1);
for (int64 start = block_size; start < total; start += block_size) {
auto limit = std::min(start + block_size, total);
- workers->Schedule([&work, &counter, start, limit]() {
+ runner([&work, &counter, start, limit]() {
work(start, limit); // Compute the shard.
counter.DecrementCount(); // The shard is done.
});
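
To make the new Runner indirection concrete, here is a small usage sketch (the helper SquareAll, the pool name, and the constants are hypothetical, for illustration only) in which the runner simply forwards each shard closure to a caller-owned thread pool, which is what Shard() itself now does on the Sharder::Do path:

#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Squares every element of `v` in parallel; Sharder::Do returns once all
// shards have completed.
void SquareAll(std::vector<int64>* v) {
  thread::ThreadPool pool(Env::Default(), "sharder_demo", /*num_threads=*/4);
  auto work = [v](int64 start, int64 limit) {
    for (int64 i = start; i < limit; ++i) (*v)[i] *= (*v)[i];
  };
  auto runner = [&pool](Sharder::Closure c) { pool.Schedule(std::move(c)); };
  Sharder::Do(static_cast<int64>(v->size()), /*cost_per_unit=*/10, work,
              runner, /*max_parallelism=*/4);
}

}  // namespace tensorflow
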
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index 451da98b6b..72ce493c1b 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -41,6 +41,12 @@ namespace tensorflow {
// work(start, limit) computes the work units from [start,
// limit), i.e., [start, limit) is a shard.
//
+// Too much parallelism can also cause excessive thread switches, so
+// Shard() often limits the maximum parallelism. Each caller can provide
+// the first argument, max_parallelism. A thread can also call
+// SetPerThreadMaxParallelism() so that all later Shard() calls on that
+// thread limit their parallelism accordingly.
+//
// REQUIRES: max_parallelism >= 0
// REQUIRES: workers != nullptr
// REQUIRES: total >= 0
@@ -48,6 +54,45 @@ namespace tensorflow {
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
int64 cost_per_unit, std::function<void(int64, int64)> work);
+// Each thread has an associated option to express the desired maximum
+// parallelism. Its default is a very large quantity.
+//
+// Within the TF runtime, per-thread max parallelism affects Shard() and
+// intra-op parallelism. E.g., if SetPerThreadMaxParallelism(1) is
+// arranged to be called by a tf_compute thread, Shard() calls and Eigen
+// device assignments made in that thread afterwards become
+// single-threaded.
+void SetPerThreadMaxParallelism(int max_parallelism);
+int GetPerThreadMaxParallelism();
+
+// Helper to set and unset per-thread max parallelism.
+class ScopedPerThreadMaxParallelism {
+ public:
+ ScopedPerThreadMaxParallelism(int max_parallelism)
+ : previous_(GetPerThreadMaxParallelism()) {
+ SetPerThreadMaxParallelism(max_parallelism);
+ }
+
+ ~ScopedPerThreadMaxParallelism() { SetPerThreadMaxParallelism(previous_); }
+
+ private:
+ int previous_ = -1;
+};
+
+// Implementation details for Shard().
+class Sharder {
+ public:
+ typedef std::function<void()> Closure;
+ typedef std::function<void(Closure)> Runner;
+ typedef std::function<void(int64, int64)> Work;
+
+ // Refer to Shard()'s comment for the meaning of total, cost_per_unit,
+ // work, and max_parallelism. runner is an interface for scheduling a
+ // closure; Shard() uses a thread::ThreadPool for this purpose.
+ static void Do(int64 total, int64 cost_per_unit, const Work& work,
+ const Runner& runner, int max_parallelism);
+};
+
} // end namespace tensorflow
#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_
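
A minimal sketch of how the scoped helper composes with Shard() (the function DoubleAll and its constants are hypothetical, for illustration only): while the guard is in scope, every Shard() call made on the current thread is capped at the per-thread limit, regardless of the max_parallelism argument passed at the call site.

#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Doubles each element of `v`. The call site asks for up to 8-way
// parallelism, but the scoped guard caps the effective parallelism at 2 for
// the duration of this function; the previous limit is restored on exit.
void DoubleAll(std::vector<int64>* v, thread::ThreadPool* workers) {
  ScopedPerThreadMaxParallelism scoped(2);
  Shard(/*max_parallelism=*/8, workers, static_cast<int64>(v->size()),
        /*cost_per_unit=*/5, [v](int64 start, int64 limit) {
          for (int64 i = start; i < limit; ++i) (*v)[i] *= 2;
        });
}

}  // namespace tensorflow
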
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
index 0694566ad9..bc5a1d221f 100644
--- a/tensorflow/core/util/work_sharder_test.cc
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -28,6 +28,7 @@ namespace tensorflow {
namespace {
void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit,
+ int64 per_thread_max_parallelism,
thread::ThreadPool* threads) {
mutex mu;
int64 num_shards = 0;
@@ -46,9 +47,18 @@ void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit,
work[start] = true;
}
});
- EXPECT_EQ(num_done_work, total);
LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " "
<< num_shards;
+ EXPECT_EQ(num_done_work, total);
+ if (std::min(num_workers, per_thread_max_parallelism) <
+ threads->NumThreads()) {
+ // If the intention is to limit the parallelism explicitly, we'd
+ // better honor it. Ideally, even if per_thread_max_parallelism >
+ // num_workers, the Shard() implementation should not over-shard.
+ // Unfortunately, ThreadPoolDevice::parallelFor tends to over-shard.
+ EXPECT_LE(num_shards, 1 + per_thread_max_parallelism);
+ }
}
TEST(Shard, Basic) {
@@ -56,7 +66,10 @@ TEST(Shard, Basic) {
for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) {
for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) {
for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) {
- RunSharding(workers, total, cost_per_unit, &threads);
+ for (auto maxp : {1, 2, 4, 8, 100}) {
+ ScopedPerThreadMaxParallelism s(maxp);
+ RunSharding(workers, total, cost_per_unit, maxp, &threads);
+ }
}
}
}