diff options
-rw-r--r-- | tensorflow/contrib/eager/python/datasets_test.py | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/kernels/function_ops.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/iterator_ops.cc | 5 |
4 files changed, 14 insertions, 27 deletions
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index cbd195d909..a2da6b28c6 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -61,6 +61,14 @@ class IteratorTest(test.TestCase): got = [x.numpy() for x in it] self.assertAllEqual([0, 4, 16, 36], got) + def testMultipleIteratorsOnADatasetThatUsesFunctions(self): + ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square) + + got1 = [x.numpy() for x in datasets.Iterator(ds)] + self.assertAllEqual([1, 4, 9, 16, 25, 36], got1) + got2 = [x.numpy() for x in datasets.Iterator(ds)] + self.assertAllEqual(got1, got2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index b76ae23cb9..f039497f13 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -240,7 +240,7 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; ProtoShortDebugString(v), " that can't be converted to a Tensor"); }) - +DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); #undef DEFINE_GET_ATTR static const string& kEmptyString = *new string(); @@ -286,26 +286,6 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return Status::OK(); } -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - NameAttrList* value) { - const AttrValue* attr_value; - TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); - TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); - *value = attr_value->func(); - return Status::OK(); -} - -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector<NameAttrList>* value) { - const AttrValue* attr_value; - TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); - TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(func)")); - for (const auto& v : attr_value->list().func()) { - value->emplace_back(v); - } - return Status::OK(); -} - namespace { // Helper for InOutTypesForNode(). Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def, diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 629e29958f..a7206f6258 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -290,7 +290,7 @@ class RemoteCallOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { const Tensor* target; OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); - AttrValueMap attr_values = func_->attr(); + AttrValueMap attr_values = func_.attr(); AttrValue v; const string& target_device = target->scalar<string>()(); v.set_s(target_device); @@ -302,7 +302,7 @@ class RemoteCallOp : public AsyncOpKernel { done); FunctionLibraryRuntime::Handle handle; OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(func_->name(), AttrSlice(&attr_values), &handle), + ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle), done); OpInputList arguments; @@ -336,7 +336,7 @@ class RemoteCallOp : public AsyncOpKernel { private: string target_; - const NameAttrList* func_; + NameAttrList func_; TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); }; diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index 66f08675c1..089f3f7bb4 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -227,9 +227,8 @@ class OneShotIteratorOp : public AsyncOpKernel { OP_REQUIRES(ctx, shared_name.empty(), errors::InvalidArgument("OneShotIteratorOp does not currently " "support the 'shared_name' attr.")); - const NameAttrList* dataset_factory_func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("dataset_factory", &dataset_factory_func)); - dataset_factory_func_ = *dataset_factory_func; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("dataset_factory", &dataset_factory_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } |