aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/generator_dataset_op.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-08-07 16:42:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 16:51:57 -0700
commit75ee9f5e9bc04c312363d9c0836dd9c3851ead64 (patch)
treef17ca719eb10753eeb45206c7d7ae4303a7379cd /tensorflow/core/kernels/data/generator_dataset_op.cc
parent6f07c3f37425ca0f893cd7f2f457830662ffed02 (diff)
Making PrefetchToDevice work on XLA Compile on Demand mode. Also adds a bunch of dataset / iterator kernel registrations for XLA.
PiperOrigin-RevId: 207802858
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc292
1 files changed, 137 insertions, 155 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 0981e42ba1..c4dd849b8b 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -15,192 +15,174 @@ limitations under the License.
#include <iterator>
#include <vector>
-#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/generator_dataset_op.h"
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-namespace {
-
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class GeneratorDatasetOp : public DatasetOpKernel {
+class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
public:
- explicit GeneratorDatasetOp(OpKernelConstruction* ctx)
- : DatasetOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
+ std::unique_ptr<CapturedFunction> next_func,
+ std::unique_ptr<CapturedFunction> finalize_func,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ init_func_(std::move(init_func)),
+ next_func_(std::move(next_func)),
+ finalize_func_(std::move(finalize_func)),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Generator")}));
}
- void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- init_func_, std::move(init_func_other_args), &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- next_func_, std::move(next_func_other_args), &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
-
- *output =
- new Dataset(ctx, std::move(init_func), std::move(next_func),
- std::move(finalize_func), output_types_, output_shapes_);
+ const DataTypeVector& output_dtypes() const override { return output_types_; }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
}
+ string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
+
private:
- class Dataset : public GraphDatasetBase {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
- std::unique_ptr<CapturedFunction> next_func,
- std::unique_ptr<CapturedFunction> finalize_func,
- const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
- init_func_(std::move(init_func)),
- next_func_(std::move(next_func)),
- finalize_func_(std::move(finalize_func)),
- output_types_(output_types),
- output_shapes_(output_shapes) {}
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Generator")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return output_types_;
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
- }
-
- string DebugString() const override {
- return "GeneratorDatasetOp::Dataset";
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- ~Iterator() override {
- if (!finalized_) {
- std::vector<Tensor> ignored;
- Status s =
- dataset()->finalize_func_->RunInstantiated(state_, &ignored);
- if (!s.ok()) {
- LOG(WARNING)
- << "Error occurred when finalizing GeneratorDataset iterator: "
- << s;
- }
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ ~Iterator() override {
+ if (!finalized_) {
+ std::vector<Tensor> ignored;
+ Status s = dataset()->finalize_func_->RunInstantiated(state_, &ignored);
+ if (!s.ok()) {
+ LOG(WARNING)
+ << "Error occurred when finalizing GeneratorDataset iterator: "
+ << s;
}
}
+ }
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
-
- if (!initialized_) {
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
- // Explicitly instantiate the finalize function here so that
- // we can invoke it in the destructor.
- TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- initialized_ = true;
- }
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ // Explicitly instantiate the finalize function here so that
+ // we can invoke it in the destructor.
+ TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+ initialized_ = true;
+ }
- if (finalized_) {
- *end_of_sequence = true;
- return Status::OK();
- }
+ if (finalized_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
- Status s = dataset()->next_func_->RunWithBorrowedArgs(ctx, state_,
- out_tensors);
- if (s.ok()) {
- *end_of_sequence = false;
- } else if (errors::IsOutOfRange(s)) {
- // `next_func` may deliberately raise `errors::OutOfRange`
- // to indicate that we should terminate the iteration.
- s = Status::OK();
- *end_of_sequence = true;
-
- // NOTE(mrry): We ignore any tensors returned by the
- // finalize function.
- std::vector<Tensor> ignored;
- TF_RETURN_IF_ERROR(
- dataset()->finalize_func_->RunInstantiated(state_, &ignored));
- finalized_ = true;
- }
- return s;
+ Status s =
+ dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, out_tensors);
+ if (s.ok()) {
+ *end_of_sequence = false;
+ } else if (errors::IsOutOfRange(s)) {
+ // `next_func` may deliberately raise `errors::OutOfRange`
+ // to indicate that we should terminate the iteration.
+ s = Status::OK();
+ *end_of_sequence = true;
+
+ // NOTE(mrry): We ignore any tensors returned by the
+ // finalize function.
+ std::vector<Tensor> ignored;
+ TF_RETURN_IF_ERROR(
+ dataset()->finalize_func_->RunInstantiated(state_, &ignored));
+ finalized_ = true;
}
+ return s;
+ }
- private:
- mutex mu_;
- bool initialized_ GUARDED_BY(mu_) = false;
- bool finalized_ GUARDED_BY(mu_) = false;
- std::vector<Tensor> state_ GUARDED_BY(mu_);
- };
-
- const std::unique_ptr<CapturedFunction> init_func_;
- const std::unique_ptr<CapturedFunction> next_func_;
- const std::unique_ptr<CapturedFunction> finalize_func_;
- const DataTypeVector output_types_;
- const std::vector<PartialTensorShape> output_shapes_;
+ private:
+ mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
+ bool finalized_ GUARDED_BY(mu_) = false;
+ std::vector<Tensor> state_ GUARDED_BY(mu_);
};
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
- NameAttrList init_func_;
- NameAttrList next_func_;
- NameAttrList finalize_func_;
+ const std::unique_ptr<CapturedFunction> init_func_;
+ const std::unique_ptr<CapturedFunction> next_func_;
+ const std::unique_ptr<CapturedFunction> finalize_func_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
};
+GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+}
+
+void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
+ DatasetBase** output) {
+ OpInputList init_func_other_args_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
+ &init_func_other_args_input));
+ std::vector<Tensor> init_func_other_args;
+ init_func_other_args.reserve(init_func_other_args_input.size());
+ for (const Tensor& t : init_func_other_args_input) {
+ init_func_other_args.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> init_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
+ &init_func));
+
+ OpInputList next_func_other_args_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
+ &next_func_other_args_input));
+ std::vector<Tensor> next_func_other_args;
+ next_func_other_args.reserve(next_func_other_args_input.size());
+ for (const Tensor& t : next_func_other_args_input) {
+ next_func_other_args.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> next_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
+ &next_func));
+
+ OpInputList finalize_func_other_args_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
+ &finalize_func_other_args_input));
+ std::vector<Tensor> finalize_func_other_args;
+ finalize_func_other_args.reserve(finalize_func_other_args_input.size());
+ for (const Tensor& t : finalize_func_other_args_input) {
+ finalize_func_other_args.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ finalize_func_, std::move(finalize_func_other_args),
+ &finalize_func));
+
+ *output =
+ new Dataset(ctx, std::move(init_func), std::move(next_func),
+ std::move(finalize_func), output_types_, output_shapes_);
+}
+
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
GeneratorDatasetOp);
-} // namespace
-
} // namespace tensorflow