path: root/tensorflow/core
author     Jiri Simsa <jsimsa@google.com>  2018-08-10 15:57:45 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-10 16:01:41 -0700
commit     8d532ac40f4db7f5293610fd3c6e92a3f7409b76 (patch)
tree       f0a57897cac3baa4259ff8a9293befec3dcf1d35 /tensorflow/core
parent     84af5e7061f82240828f72c7b484a1a66b8c4f7f (diff)
[tf.data] Optimization checkpointing improvements.
This CL:
- changes the `OptimizeDataset` checkpointing logic to checkpoint the optimized dataset (as opposed to checkpointing the original dataset plus the optimizations, which re-ran optimization every time a checkpoint was restored)
- replaces `OpKernelContext` with the newly introduced `SerializationContext` in the signature of `AsGraphDefInternal`, reducing the scope of the context and simplifying the logic for overriding the `FunctionLibraryDefinition` when optimizations take place

PiperOrigin-RevId: 208282562
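To illustrate the second change, here is a minimal sketch of the calling pattern the diff below introduces in `dataset_ops.cc` and `optimize_dataset_op.cc`. It is not an excerpt from this CL; the helper name `SerializeDatasetToGraphDef` is hypothetical. A `SerializationContext` is built from the kernel's `FunctionLibraryDefinition` and threaded through `AddInputDataset`, instead of passing the `OpKernelContext` itself.

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_def_builder.h"

namespace tensorflow {

// Hypothetical helper: serializes `dataset` to a GraphDef using the new
// SerializationContext-based path, mirroring DatasetToGraphOp::Compute below.
Status SerializeDatasetToGraphDef(OpKernelContext* ctx,
                                  const DatasetBase* dataset,
                                  GraphDef* graph_def) {
  GraphDefBuilder b;
  DatasetBase::DatasetGraphDefBuilder db(&b);
  // The SerializationContext carries only the FunctionLibraryDefinition,
  // which is all that AsGraphDefInternal needs from the kernel context.
  SerializationContext::Params params;
  params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  SerializationContext serialization_ctx(params);
  Node* output_node = nullptr;
  TF_RETURN_IF_ERROR(
      db.AddInputDataset(&serialization_ctx, dataset, &output_node));
  return b.ToGraphDef(graph_def);
}

}  // namespace tensorflow

Dataset kernels that capture functions then resolve them through the context, e.g. `b->AddFunction(ctx->flib_def(), func_.name())`, rather than reaching into the `OpKernelContext`'s function library runtime.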
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/framework/dataset.cc | 26
-rw-r--r--  tensorflow/core/framework/dataset.h | 83
-rw-r--r--  tensorflow/core/kernels/data/batch_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/cache_dataset_ops.cc | 6
-rw-r--r--  tensorflow/core/kernels/data/concatenate_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/dataset_ops.cc | 6
-rw-r--r--  tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/filter_by_component_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/filter_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/flat_map_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/data/group_by_window_dataset_op.cc | 10
-rw-r--r--  tensorflow/core/kernels/data/interleave_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc | 7
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/map_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/optimize_dataset_op.cc | 25
-rw-r--r--  tensorflow/core/kernels/data/padded_batch_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/random_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/range_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/reader_dataset_ops.cc | 9
-rw-r--r--  tensorflow/core/kernels/data/repeat_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/scan_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/shuffle_dataset_op.cc | 9
-rw-r--r--  tensorflow/core/kernels/data/skip_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/slide_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/sql_dataset_ops.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 8
-rw-r--r--  tensorflow/core/kernels/data/stats_dataset_ops.cc | 9
-rw-r--r--  tensorflow/core/kernels/data/take_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/tensor_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/tensor_queue_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/tensor_slice_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/unbatch_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/window_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/zip_dataset_op.cc | 3
40 files changed, 187 insertions(+), 124 deletions(-)
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 6510f81ab7..e886ef7b8e 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -134,24 +134,22 @@ Status GraphDefBuilderWrapper::AddDataset(
return Status::OK();
}
-Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx,
- const string& function_name) {
+Status GraphDefBuilderWrapper::AddFunction(
+ const FunctionLibraryDefinition& flib_def, const string& function_name) {
if (b_->HasFunction(function_name)) {
- LOG(INFO) << "Function with name " << function_name << "already exists in"
- << " the graph. It will not be added again.";
+ VLOG(1) << "Function with name " << function_name << "already exists in"
+ << " the graph. It will not be added again.";
return Status::OK();
}
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
- const FunctionLibraryDefinition* flib_def =
- ctx->function_library()->GetFunctionLibraryDefinition();
- const FunctionDef* f_def = flib_def->Find(function_name);
+ TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name));
+ const FunctionDef* f_def = flib_def.Find(function_name);
if (f_def == nullptr) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in the registry.");
}
FunctionDefLibrary def;
*def.add_function() = *f_def;
- const string gradient_func = flib_def->FindGradient(function_name);
+ const string gradient_func = flib_def.FindGradient(function_name);
if (!gradient_func.empty()) {
GradientDef* g_def = def.add_gradient();
g_def->set_function_name(function_name);
@@ -162,19 +160,19 @@ Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx,
// Recursively add functions in inputs of function_name.
for (const NodeDef& node_def : f_def->node_def()) {
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data));
if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
+ TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name()));
}
// Recursively add functions in attrs of this NodeDef.
for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def));
}
}
// Recursively add functions in attrs of function_name.
for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def));
}
return Status::OK();
}
@@ -196,7 +194,7 @@ bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
return HasAttr(op_def, attr_name);
}
-Status GraphDatasetBase::Serialize(OpKernelContext* ctx,
+Status GraphDatasetBase::Serialize(SerializationContext* ctx,
string* serialized_graph_def,
string* output_node) const {
GraphDefBuilder b;
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 0e3d5adecd..66e836f9a6 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -157,7 +157,8 @@ class GraphDefBuilderWrapper {
// name `function_name` is not found in the FunctionLibraryDefinition, returns
// an InvalidArgumentError. If the function with name `function_name` or any
// of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(OpKernelContext* ctx, const string& function_name);
+ Status AddFunction(const FunctionLibraryDefinition& flib_def,
+ const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -167,18 +168,16 @@ class GraphDefBuilderWrapper {
private:
void AddTensorInternal(const Tensor& val, Node** output);
- Status EnsureFunctionIsStateless(OpKernelContext* ctx,
+ Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
const string& function_name) const {
- const FunctionLibraryDefinition* lib_def =
- ctx->function_library()->GetFunctionLibraryDefinition();
- const FunctionDef* function_def = lib_def->Find(function_name);
+ const FunctionDef* function_def = flib_def.Find(function_name);
if (!function_def) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in registry.");
}
for (const NodeDef& node_def : function_def->node_def()) {
const OpDef* op_def;
- TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def));
+ TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def));
// TODO(b/65524810): Hack to allow functions to capture Dataset op
// nodes needed for FlatMap. Currently, source datasets nodes have been
// marked stateful to avoid constant folding since we do not have a
@@ -220,12 +219,13 @@ class GraphDefBuilderWrapper {
return false;
}
- Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) {
+ Status AddAttrFunctions(const AttrValue& attr_value,
+ const FunctionLibraryDefinition& flib_def) {
if (attr_value.has_func()) {
- TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
+ TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name()));
} else if (attr_value.has_list()) {
for (const NameAttrList& name_attr_list : attr_value.list().func()) {
- TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
+ TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name()));
}
}
return Status::OK();
@@ -236,21 +236,17 @@ class GraphDefBuilderWrapper {
class StatsAggregator;
-// A cut-down version of OpKernelContext for running computations in
-// iterators. Note that we cannot simply use OpKernelContext here
-// because we might run computation in an iterator whose lifetime is
-// not nested within the lifetime of a single OpKernelContext
-// (e.g. asynchronous prefetching).
+// A cut-down version of `OpKernelContext` for running computations in
+// iterators. Note that we cannot simply use `OpKernelContext` here because we
+// might run computation in an iterator whose lifetime is not nested within the
+// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
-// TODO(mrry): We will probably need to support more of
-// OpKernelContext here. For example, should allocation be handled by
-// the IteratorContext?
-// TODO(mrry): We're making some daring assumptions about the lifetime
-// of the runner passed in here. A runner will be deleted when the original
-// step ends, but all existing runners only close over session-lifetime (or
-// longer-lived) state, so we can make a copy of the function. There's nothing
-// in the definition of the API from which we took the runner to guarantee that
-// what we are doing is safe. We should formalize the properties here.
+// TODO(mrry): We're making some daring assumptions about the lifetime of the
+// runner passed in here. A runner will be deleted when the original step ends,
+// but all existing runners only close over session-lifetime (or longer-lived)
+// state, so we can make a copy of the function. There's nothing in the
+// definition of the API from which we took the runner to guarantee that what we
+// are doing is safe. We should formalize the properties here.
class IteratorContext {
public:
struct Params {
@@ -318,6 +314,23 @@ class IteratorContext {
Params params_;
};
+// Aggregates runtime support needed for dataset and iterator serialization.
+class SerializationContext {
+ public:
+ struct Params {
+ const FunctionLibraryDefinition* flib_def; // Not owned.
+ };
+
+ explicit SerializationContext(Params params) : params_(std::move(params)) {}
+
+ const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
+
+ private:
+ Params params_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
+};
+
// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by an `DatasetBase`,
// defined below.
@@ -357,7 +370,7 @@ class IteratorBase {
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
// Saves the state of this iterator.
- virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
+ virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
return SaveInternal(writer);
}
@@ -427,8 +440,10 @@ class DatasetBase : public core::RefCounted {
virtual string DebugString() const = 0;
// Serializes the dataset and writes it to the `writer`.
- virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const {
- return errors::Unimplemented("DatasetBase::Save");
+ virtual Status Save(SerializationContext* ctx,
+ IteratorStateWriter* writer) const {
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
}
protected:
@@ -439,13 +454,14 @@ class DatasetBase : public core::RefCounted {
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
- Status AddInputDataset(OpKernelContext* ctx, const DatasetBase* dataset,
- Node** output) {
+ Status AddInputDataset(SerializationContext* ctx,
+ const DatasetBase* dataset, Node** output) {
return dataset->AsGraphDefInternal(ctx, this, output);
}
};
- virtual Status AsGraphDefInternal(OpKernelContext* ctx,
+ // TODO(jsimsa): Consolidate overloading into a single method.
+ virtual Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** node) const {
return AsGraphDefInternal(b, node);
@@ -453,7 +469,8 @@ class DatasetBase : public core::RefCounted {
virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** node) const {
- return errors::Unimplemented("AsGraphDefInternal");
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
}
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
@@ -470,7 +487,7 @@ class GraphDatasetBase : public DatasetBase {
const string op_name() const { return op_name_; }
- Status Save(OpKernelContext* ctx,
+ Status Save(SerializationContext* ctx,
IteratorStateWriter* writer) const override {
string serialized_graph_def;
string output_node;
@@ -490,7 +507,7 @@ class GraphDatasetBase : public DatasetBase {
TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
private:
- Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
+ Status Serialize(SerializationContext* ctx, string* serialized_graph_def,
string* output_node) const;
const string op_name_;
@@ -539,7 +556,7 @@ class DatasetBaseIterator : public IteratorBase {
return s;
}
- Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final {
+ Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
TF_RETURN_IF_ERROR(params_.dataset->Save(ctx, writer));
return IteratorBase::Save(ctx, writer);
}
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 2ee4548621..5295c9d2a6 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -96,7 +96,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 4f23d07aae..3762e403a9 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -85,7 +85,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
@@ -566,7 +567,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index 98282d74a9..6393005cdc 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -80,7 +80,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
index d15cdf215e..c71d027f23 100644
--- a/tensorflow/core/kernels/data/dataset_ops.cc
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -32,7 +32,11 @@ class DatasetToGraphOp : public OpKernel {
GraphDefBuilder b;
DatasetBase::DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
- OP_REQUIRES_OK(ctx, db.AddInputDataset(ctx, dataset, &input_node));
+ SerializationContext::Params params;
+ params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ SerializationContext serialization_ctx(params);
+ OP_REQUIRES_OK(
+ ctx, db.AddInputDataset(&serialization_ctx, dataset, &input_node));
GraphDef graph_def;
OP_REQUIRES_OK(ctx, b.ToGraphDef(&graph_def));
Tensor* result;
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 77a04ef3f1..9105587cf4 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -115,7 +115,8 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
index 2f6479de16..4b6d808af0 100644
--- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -80,7 +80,8 @@ class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index aebc2f065f..b11d7cf2ef 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -109,9 +109,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "FilterDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
Node* input_graph_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index aae3f19c0d..3419eed6c6 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -91,9 +91,10 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "FlatMapDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index f245fc402d..bcf0adacc7 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -106,12 +106,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), init_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func().name()));
+ TF_RETURN_IF_ERROR(
+ b->AddFunction(ctx->flib_def(), finalize_func().name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index f6664fc5fb..683a50e71c 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -136,11 +136,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func_.name()));
+ TF_RETURN_IF_ERROR(
+ b->AddFunction(ctx->flib_def(), window_size_func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 20096a6590..8fee29d4d0 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -114,9 +114,10 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index e2df14337c..da9d29dd76 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -116,7 +116,7 @@ class IteratorResource : public ResourceBase {
}
}
- Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
+ Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
return captured_iterator->Save(ctx, writer);
@@ -386,10 +386,13 @@ class IteratorStateVariant {
// that it can be written on the next call to Encode().
Status InitializeFromIterator(OpKernelContext* ctx,
IteratorResource* iterator_resource) {
+ SerializationContext::Params params;
+ params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ SerializationContext serialization_ctx(params);
data_.reset(new VariantTensorData());
data_->set_type_name(TypeName());
VariantTensorDataWriter writer(data_.get());
- TF_RETURN_IF_ERROR(iterator_resource->Save(ctx, &writer));
+ TF_RETURN_IF_ERROR(iterator_resource->Save(&serialization_ctx, &writer));
TF_RETURN_IF_ERROR(writer.Flush());
return Status::OK();
}
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index e0f164e784..51a7fd23a8 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -144,9 +144,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 852c942e99..ec9e12453b 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -89,9 +89,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "MapDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 14a0aa7844..8add049123 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -80,15 +80,22 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
+ // We do not add a token for the optimization dataset to the prefix. The
+ // prefix is used to identify checkpoint elements and since the
+ // optimization dataset is excluded from the checkpoint, adding a token
+ // here would result in invalid checkpoint identifiers.
+ return std::unique_ptr<IteratorBase>(new Iterator({this, prefix}));
}
Status Optimize(OpKernelContext* ctx) {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
- TF_RETURN_IF_ERROR(db.AddInputDataset(ctx, input_, &input_node));
+ SerializationContext::Params params;
+ params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ SerializationContext serialization_ctx(params);
+ TF_RETURN_IF_ERROR(
+ db.AddInputDataset(&serialization_ctx, input_, &input_node));
string output_node = input_node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
@@ -119,14 +126,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* optimizations_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddVector(optimizations_, &optimizations_node));
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, optimizations_node}, output));
+ // We only serialize the optimized dataset to avoid re-running
+ // optimizations when the input pipeline is restored from a checkpoint.
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output));
return Status::OK();
}
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index dd4b0f9f4c..755d46dac2 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -153,7 +153,8 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 1995708732..d2b83f9eab 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -134,9 +134,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ TF_RETURN_IF_ERROR(
+ b->AddFunction(ctx->flib_def(), interleave_func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 2cbdde01ec..c56a7ea808 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -113,7 +113,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
// Input: input_dataset
Node* input_graph_node = nullptr;
@@ -137,7 +138,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
b->AddScalar(num_parallel_calls_, &num_parallel_calls));
// Attr: f
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
AttrValue f;
b->BuildAttrValue(func_, &f);
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index ccdb7c1479..20148a4378 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -51,7 +51,8 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
string DebugString() const override { return "PrefetchDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index ff166c3be7..7e48428b3f 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -77,7 +77,8 @@ class RandomDatasetOp : public DatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* seed = nullptr;
Node* seed2 = nullptr;
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index 0b5c814767..50bd3dac4e 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -71,7 +71,8 @@ class RangeDatasetOp : public DatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* start = nullptr;
Node* stop = nullptr;
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 29654b9bca..6a71a7af1d 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -109,7 +109,8 @@ class TextLineDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "TextLineDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
Node* compression_type = nullptr;
@@ -345,7 +346,8 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
Node* header_bytes = nullptr;
@@ -563,7 +565,8 @@ class TFRecordDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "TFRecordDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 002b0ee596..093ea563b4 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -72,7 +72,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "RepeatDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index b0874a5509..7c59874d96 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -106,9 +106,10 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "ScanDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
std::vector<Node*> initial_state_nodes;
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index caf1e2af3f..603c3feb79 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -429,7 +429,8 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
}
};
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
mutex_lock l(mu_);
Node* input_graph_node = nullptr;
@@ -499,7 +500,8 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -584,7 +586,8 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index 2da496c423..61db6a0a54 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -68,7 +68,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "SkipDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index 6c0384a537..fd8c5ccd92 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -104,7 +104,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index b5dff48d2d..9bb86e76a2 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -55,7 +55,8 @@ class Dataset : public GraphDatasetBase {
}
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* indices_node;
TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node));
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 16652e792c..9b0190e3fc 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -105,7 +105,8 @@ class SqlDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "SqlDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* driver_name_node;
TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node));
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 61a1937e4c..8465a1d2c0 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -70,14 +70,6 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
return "SetStatsAggregatorDatasetOp::Dataset";
}
- protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
- Node** output) const override {
- return errors::Unimplemented(
- "Cannot currently serialize the `stats_aggregator` for a "
- "SetStatsAggregatorDataset.");
- }
-
private:
class Iterator : public DatasetIterator<Dataset> {
public:
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index ba5ed39d83..85fed31773 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -76,7 +76,8 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
@@ -175,7 +176,8 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
@@ -280,7 +282,8 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index 0114629ef6..d4a3c7a978 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -69,7 +69,8 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "TakeDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index 36fc434d8f..ac2015c865 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -67,7 +67,8 @@ class TensorDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "TensorDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<Node*> components;
components.reserve(tensors_.size());
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index d728f1ab14..ea472e2b79 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -99,7 +99,8 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 68ce324081..8f18d38f83 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -86,7 +86,8 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<Node*> components;
components.reserve(tensors_.size());
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 692b5d8819..02c3c5315a 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -65,7 +65,8 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "UnbatchDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index c87214b3ef..f9fd5b5a83 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -74,7 +74,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index 8cc21cd2bc..63e9b99d4b 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -77,7 +77,8 @@ class ZipDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "ZipDatasetOp::Dataset"; }
protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<Node*> input_graph_nodes;
input_graph_nodes.reserve(inputs_.size());