diff options
author | Piotr Padlewski <prazek@google.com> | 2018-09-14 11:28:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-14 11:32:52 -0700 |
commit | c20a7b81d79d30db9e990309ddb419bcb48120cc (patch) | |
tree | 9ea682cf79bac18653e7690785e0f5e7117b6b8b /tensorflow/core/framework | |
parent | 89f9080ed0d1a43cb2fa253997b2553c6916f364 (diff) |
[tf.data] Introducing an optimization that parallelizes map transformations.
Stateless MapDatasets can be parallelized by switching to ParallelMapDataset. We set `num_parallel_calls` to 2 for now, but in the future a special value will be used that results in the optimal value being selected dynamically at runtime.
This patch also exposed a memory leak, which has been fixed.
PiperOrigin-RevId: 213015223
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/function.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/framework/function_testlib.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/framework/function_testlib.h | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/op_segment.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/op_segment.h | 4 |
6 files changed, 66 insertions, 7 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d979353d2f..a17959a448 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1294,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create( for (const auto& r : ret_def) { fdef.mutable_ret()->insert({r.first, r.second}); } + + auto* op_def_registry = OpRegistry::Global(); + // Check if any op is stateful. + for (const auto& n : node_def) { + const OpDef* op_def = nullptr; + auto status = op_def_registry->LookUpOpDef(n.op, &op_def); + // Lookup can fail if e.g. we are calling a function that was not yet + // defined. If it happens, conservatively assume the op is stateful. + if (!status.ok() || op_def->is_stateful()) { + fdef.mutable_signature()->set_is_stateful(true); + } + } return fdef; } @@ -1355,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name, strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first); } } + if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true); } // Returns diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index c5a4f661d2..d5c203d276 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -91,6 +91,40 @@ FunctionDef IsZero() { }); } +FunctionDef RandomUniform() { + const Tensor kZero = test::AsScalar<int64>(0); + const Tensor kTen = test::AsScalar<int64>(10); + + return FDH::Define( + // Name + "RandomUniform", + // Args + {"x: T"}, + // Return values + {"random_uniform: int64"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + {{{"random_uniform/shape"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform/min"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform/max"}, + "Const", + {}, + {{"value", kTen}, {"dtype", DT_INT64}}}, + {{"random_uniform"}, + "RandomUniformInt", 
+ {}, + {{"T", DT_INT64}, + {"Tout", DT_INT64}, + {"seed", 87654321}, + {"seed2", 42}}}}); +} + FunctionDef XTimesTwo() { const Tensor kTwo = test::AsScalar<int64>(2); return FDH::Define( diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index ad61a76f16..a01743423b 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -84,6 +84,9 @@ FunctionDef NonZero(); // x: T -> bool. FunctionDef IsZero(); +// x: T -> int64 +FunctionDef RandomUniform(); + // x:T, y:T -> y:T, x:T FunctionDef Swap(); diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c694e10193..80f2b12987 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ -// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : OpKernel(context, - std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {} + : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {} OpKernel::OpKernel(OpKernelConstruction* context, std::unique_ptr<const NodeDef> node_def) @@ -525,10 +524,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input( return nullptr; } } - // TODO(rmlarsen): Use MakeUnique here. There is already a copy in - // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of - // general cleanup of ownership in this code. 
- std::unique_ptr<Tensor> output_tensor(new Tensor()); + + auto output_tensor = MakeUnique<Tensor>(); CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); return output_tensor; } diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index dfc5aa7747..75ed4a4eaf 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) { delete item; } +bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op) { + // OpSegment should not own kernel if the node is stateless, or a function. + return lib->IsStateful(node_op) && + lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr; +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h index 4433a2554f..37d939ea2b 100644 --- a/tensorflow/core/framework/op_segment.h +++ b/tensorflow/core/framework/op_segment.h @@ -60,6 +60,10 @@ class OpSegment { Status FindOrCreate(const string& session_handle, const string& node_name, OpKernel** kernel, CreateKernelFn create_fn); + // Returns true if OpSegment should own the kernel. + static bool ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op); + private: // op name -> OpKernel typedef std::unordered_map<string, OpKernel*> KernelMap; |