aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Piotr Padlewski <prazek@google.com>2018-09-14 11:28:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 11:32:52 -0700
commitc20a7b81d79d30db9e990309ddb419bcb48120cc (patch)
tree9ea682cf79bac18653e7690785e0f5e7117b6b8b /tensorflow/core/framework
parent89f9080ed0d1a43cb2fa253997b2553c6916f364 (diff)
[tf.data] Introducing an optimization that parallelizes map transformations.
Stateless MapDatasets can be paralellized by switching to ParallelMapDataset. We set `num_parallel_calls` to 2 for now, but in the future a special value will be used that result in the optimal value to be selected dynamically at runtime. This patch also exposed a memory leak which was fixed. PiperOrigin-RevId: 213015223
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/function.cc13
-rw-r--r--tensorflow/core/framework/function_testlib.cc34
-rw-r--r--tensorflow/core/framework/function_testlib.h3
-rw-r--r--tensorflow/core/framework/op_kernel.cc11
-rw-r--r--tensorflow/core/framework/op_segment.cc8
-rw-r--r--tensorflow/core/framework/op_segment.h4
6 files changed, 66 insertions, 7 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index d979353d2f..a17959a448 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1294,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create(
for (const auto& r : ret_def) {
fdef.mutable_ret()->insert({r.first, r.second});
}
+
+ auto* op_def_registry = OpRegistry::Global();
+ // Check if any op is stateful.
+ for (const auto& n : node_def) {
+ const OpDef* op_def = nullptr;
+ auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
+ // Lookup can fail if e.g. we are calling a function that was not yet
+ // defined. If it happens, conservatively assume the op is stateful.
+ if (!status.ok() || op_def->is_stateful()) {
+ fdef.mutable_signature()->set_is_stateful(true);
+ }
+ }
return fdef;
}
@@ -1355,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name,
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
+ if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
}
// Returns
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index c5a4f661d2..d5c203d276 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -91,6 +91,40 @@ FunctionDef IsZero() {
});
}
+FunctionDef RandomUniform() {
+ const Tensor kZero = test::AsScalar<int64>(0);
+ const Tensor kTen = test::AsScalar<int64>(10);
+
+ return FDH::Define(
+ // Name
+ "RandomUniform",
+ // Args
+ {"x: T"},
+ // Return values
+ {"random_uniform: int64"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ {{{"random_uniform/shape"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/min"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/max"},
+ "Const",
+ {},
+ {{"value", kTen}, {"dtype", DT_INT64}}},
+ {{"random_uniform"},
+ "RandomUniformInt",
+ {},
+ {{"T", DT_INT64},
+ {"Tout", DT_INT64},
+ {"seed", 87654321},
+ {"seed2", 42}}}});
+}
+
FunctionDef XTimesTwo() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index ad61a76f16..a01743423b 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -84,6 +84,9 @@ FunctionDef NonZero();
// x: T -> bool.
FunctionDef IsZero();
+// x: T -> int64
+FunctionDef RandomUniform();
+
// x:T, y:T -> y:T, x:T
FunctionDef Swap();
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c694e10193..80f2b12987 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
// OpKernel ------------------------------------------------------------------
-// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
- : OpKernel(context,
- std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+ : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
OpKernel::OpKernel(OpKernelConstruction* context,
std::unique_ptr<const NodeDef> node_def)
@@ -525,10 +524,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input(
return nullptr;
}
}
- // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
- // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
- // general cleanup of ownership in this code.
- std::unique_ptr<Tensor> output_tensor(new Tensor());
+
+ auto output_tensor = MakeUnique<Tensor>();
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
return output_tensor;
}
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
index dfc5aa7747..75ed4a4eaf 100644
--- a/tensorflow/core/framework/op_segment.cc
+++ b/tensorflow/core/framework/op_segment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) {
delete item;
}
+bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op) {
+ // OpSegment should not own kernel if the node is stateless, or a function.
+ return lib->IsStateful(node_op) &&
+ lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
index 4433a2554f..37d939ea2b 100644
--- a/tensorflow/core/framework/op_segment.h
+++ b/tensorflow/core/framework/op_segment.h
@@ -60,6 +60,10 @@ class OpSegment {
Status FindOrCreate(const string& session_handle, const string& node_name,
OpKernel** kernel, CreateKernelFn create_fn);
+ // Returns true if OpSegment should own the kernel.
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op);
+
private:
// op name -> OpKernel
typedef std::unordered_map<string, OpKernel*> KernelMap;