aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py8
-rw-r--r--tensorflow/core/framework/node_def_util.cc22
-rw-r--r--tensorflow/core/kernels/function_ops.cc6
-rw-r--r--tensorflow/core/kernels/iterator_ops.cc5
4 files changed, 14 insertions, 27 deletions
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index cbd195d909..a2da6b28c6 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -61,6 +61,14 @@ class IteratorTest(test.TestCase):
got = [x.numpy() for x in it]
self.assertAllEqual([0, 4, 16, 36], got)
+ def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
+ ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)
+
+ got1 = [x.numpy() for x in datasets.Iterator(ds)]
+ self.assertAllEqual([1, 4, 9, 16, 25, 36], got1)
+ got2 = [x.numpy() for x in datasets.Iterator(ds)]
+ self.assertAllEqual(got1, got2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index b76ae23cb9..f039497f13 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -240,7 +240,7 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
ProtoShortDebugString(v),
" that can't be converted to a Tensor");
})
-
+DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
#undef DEFINE_GET_ATTR
static const string& kEmptyString = *new string();
@@ -286,26 +286,6 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
return Status::OK();
}
-Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
- NameAttrList* value) {
- const AttrValue* attr_value;
- TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
- TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
- *value = attr_value->func();
- return Status::OK();
-}
-
-Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
- std::vector<NameAttrList>* value) {
- const AttrValue* attr_value;
- TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
- TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(func)"));
- for (const auto& v : attr_value->list().func()) {
- value->emplace_back(v);
- }
- return Status::OK();
-}
-
namespace { // Helper for InOutTypesForNode().
Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 629e29958f..a7206f6258 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -290,7 +290,7 @@ class RemoteCallOp : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor* target;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
- AttrValueMap attr_values = func_->attr();
+ AttrValueMap attr_values = func_.attr();
AttrValue v;
const string& target_device = target->scalar<string>()();
v.set_s(target_device);
@@ -302,7 +302,7 @@ class RemoteCallOp : public AsyncOpKernel {
done);
FunctionLibraryRuntime::Handle handle;
OP_REQUIRES_OK_ASYNC(
- ctx, lib->Instantiate(func_->name(), AttrSlice(&attr_values), &handle),
+ ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle),
done);
OpInputList arguments;
@@ -336,7 +336,7 @@ class RemoteCallOp : public AsyncOpKernel {
private:
string target_;
- const NameAttrList* func_;
+ NameAttrList func_;
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
};
diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc
index 66f08675c1..089f3f7bb4 100644
--- a/tensorflow/core/kernels/iterator_ops.cc
+++ b/tensorflow/core/kernels/iterator_ops.cc
@@ -227,9 +227,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
OP_REQUIRES(ctx, shared_name.empty(),
errors::InvalidArgument("OneShotIteratorOp does not currently "
"support the 'shared_name' attr."));
- const NameAttrList* dataset_factory_func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dataset_factory", &dataset_factory_func));
- dataset_factory_func_ = *dataset_factory_func;
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr("dataset_factory", &dataset_factory_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}