aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-09-20 13:54:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-20 14:01:40 -0700
commit1ad7cb6f05c221ff0df5532e4101e99250dec33f (patch)
tree3ab50ebe20ef959aeee45503d821d33cc818a87a
parent73a8755cdfd7bd74d626eee35cc58a3ddb2198e6 (diff)
function_ops: Do not hold on to references to the NodeDef in the kernel.
NodeDef's might not outlive the OpKernel (happens often when eager execution is enabled). Holding on to references to the NodeDef was thus unsafe. PiperOrigin-RevId: 169445422
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py8
-rw-r--r--tensorflow/core/framework/node_def_util.cc22
-rw-r--r--tensorflow/core/kernels/function_ops.cc6
-rw-r--r--tensorflow/core/kernels/iterator_ops.cc5
4 files changed, 14 insertions, 27 deletions
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index cbd195d909..a2da6b28c6 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -61,6 +61,14 @@ class IteratorTest(test.TestCase):
got = [x.numpy() for x in it]
self.assertAllEqual([0, 4, 16, 36], got)
+ def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
+ ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)
+
+ got1 = [x.numpy() for x in datasets.Iterator(ds)]
+ self.assertAllEqual([1, 4, 9, 16, 25, 36], got1)
+ got2 = [x.numpy() for x in datasets.Iterator(ds)]
+ self.assertAllEqual(got1, got2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index b76ae23cb9..f039497f13 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -240,7 +240,7 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
ProtoShortDebugString(v),
" that can't be converted to a Tensor");
})
-
+DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
#undef DEFINE_GET_ATTR
static const string& kEmptyString = *new string();
@@ -286,26 +286,6 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
return Status::OK();
}
-Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
- NameAttrList* value) {
- const AttrValue* attr_value;
- TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
- TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
- *value = attr_value->func();
- return Status::OK();
-}
-
-Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
- std::vector<NameAttrList>* value) {
- const AttrValue* attr_value;
- TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
- TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(func)"));
- for (const auto& v : attr_value->list().func()) {
- value->emplace_back(v);
- }
- return Status::OK();
-}
-
namespace { // Helper for InOutTypesForNode().
Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 629e29958f..a7206f6258 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -290,7 +290,7 @@ class RemoteCallOp : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor* target;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
- AttrValueMap attr_values = func_->attr();
+ AttrValueMap attr_values = func_.attr();
AttrValue v;
const string& target_device = target->scalar<string>()();
v.set_s(target_device);
@@ -302,7 +302,7 @@ class RemoteCallOp : public AsyncOpKernel {
done);
FunctionLibraryRuntime::Handle handle;
OP_REQUIRES_OK_ASYNC(
- ctx, lib->Instantiate(func_->name(), AttrSlice(&attr_values), &handle),
+ ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle),
done);
OpInputList arguments;
@@ -336,7 +336,7 @@ class RemoteCallOp : public AsyncOpKernel {
private:
string target_;
- const NameAttrList* func_;
+ NameAttrList func_;
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
};
diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc
index 66f08675c1..089f3f7bb4 100644
--- a/tensorflow/core/kernels/iterator_ops.cc
+++ b/tensorflow/core/kernels/iterator_ops.cc
@@ -227,9 +227,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
OP_REQUIRES(ctx, shared_name.empty(),
errors::InvalidArgument("OneShotIteratorOp does not currently "
"support the 'shared_name' attr."));
- const NameAttrList* dataset_factory_func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dataset_factory", &dataset_factory_func));
- dataset_factory_func_ = *dataset_factory_func;
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr("dataset_factory", &dataset_factory_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}