aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/dataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/dataset.h')
-rw-r--r--tensorflow/core/framework/dataset.h19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 23dc903caf..d8618f391e 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -459,6 +459,8 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
+
+ friend class DatasetToGraphOp; // For access to graph related members.
};
// Base-class for datasets that are built by ops.
@@ -584,6 +586,23 @@ class DatasetOpKernel : public OpKernel {
*output = argument_t->scalar<T>()();
return Status::OK();
}
+
+ template <typename T>
+ Status ParseVectorArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name,
+ std::vector<T>* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsVector(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a vector");
+ }
+ int size = argument_t->vec<T>().size();
+ output->reserve(size);
+ for (int i = 0; i < size; ++i) {
+ output->push_back(argument_t->vec<T>()(i));
+ }
+ return Status::OK();
+ }
};
// Encapsulates the work required to plug unary Datasets into the core