diff options
Diffstat (limited to 'tensorflow/core/framework/dataset.h')
-rw-r--r-- | tensorflow/core/framework/dataset.h | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 23dc903caf..d8618f391e 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -459,6 +459,8 @@ class DatasetBase : public core::RefCounted { virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const = 0; + + friend class DatasetToGraphOp; // For access to graph related members. }; // Base-class for datasets that are built by ops. @@ -584,6 +586,23 @@ class DatasetOpKernel : public OpKernel { *output = argument_t->scalar<T>()(); return Status::OK(); } + + template <typename T> + Status ParseVectorArgument(OpKernelContext* ctx, + const StringPiece& argument_name, + std::vector<T>* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsVector(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a vector"); + } + int size = argument_t->vec<T>().size(); + output->reserve(size); + for (int i = 0; i < size; ++i) { + output->push_back(argument_t->vec<T>()(i)); + } + return Status::OK(); + } }; // Encapsulates the work required to plug unary Datasets into the core |