diff options
author | 2018-08-10 15:57:45 -0700 | |
---|---|---|
committer | 2018-08-10 16:01:41 -0700 | |
commit | 8d532ac40f4db7f5293610fd3c6e92a3f7409b76 (patch) | |
tree | f0a57897cac3baa4259ff8a9293befec3dcf1d35 /tensorflow/core | |
parent | 84af5e7061f82240828f72c7b484a1a66b8c4f7f (diff) |
[tf.data] Optimization checkpointing improvements.
This CL:
- changes the `OptimizeDataset` checkpointing logic to checkpoint the optimized dataset (as opposed to the original dataset + the optimizations, re-running optimization every time a checkpoint is restored)
- replaces `OpKernelContext` with newly introduced `SerializationContext` in the signature of `AsGraphDefInternal` to reduce the scope of the context and also simplify the logic for overriding the `FunctionLibraryDefinition` when optimizations take place
PiperOrigin-RevId: 208282562
Diffstat (limited to 'tensorflow/core')
40 files changed, 187 insertions, 124 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 6510f81ab7..e886ef7b8e 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -134,24 +134,22 @@ Status GraphDefBuilderWrapper::AddDataset( return Status::OK(); } -Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx, - const string& function_name) { +Status GraphDefBuilderWrapper::AddFunction( + const FunctionLibraryDefinition& flib_def, const string& function_name) { if (b_->HasFunction(function_name)) { - LOG(INFO) << "Function with name " << function_name << "already exists in" - << " the graph. It will not be added again."; + VLOG(1) << "Function with name " << function_name << "already exists in" + << " the graph. It will not be added again."; return Status::OK(); } - TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); - const FunctionLibraryDefinition* flib_def = - ctx->function_library()->GetFunctionLibraryDefinition(); - const FunctionDef* f_def = flib_def->Find(function_name); + TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name)); + const FunctionDef* f_def = flib_def.Find(function_name); if (f_def == nullptr) { return errors::InvalidArgument("Unable to find FunctionDef for ", function_name, " in the registry."); } FunctionDefLibrary def; *def.add_function() = *f_def; - const string gradient_func = flib_def->FindGradient(function_name); + const string gradient_func = flib_def.FindGradient(function_name); if (!gradient_func.empty()) { GradientDef* g_def = def.add_gradient(); g_def->set_function_name(function_name); @@ -162,19 +160,19 @@ Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx, // Recursively add functions in inputs of function_name. for (const NodeDef& node_def : f_def->node_def()) { const OpRegistrationData* op_reg_data = nullptr; - TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); + TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data)); if (op_reg_data->is_function_op) { - TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); + TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name())); } // Recursively add functions in attrs of this NodeDef. for (const auto& pair : node_def.attr()) { - TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx)); + TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def)); } } // Recursively add functions in attrs of function_name. for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { - TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx)); + TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def)); } return Status::OK(); } @@ -196,7 +194,7 @@ bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name, return HasAttr(op_def, attr_name); } -Status GraphDatasetBase::Serialize(OpKernelContext* ctx, +Status GraphDatasetBase::Serialize(SerializationContext* ctx, string* serialized_graph_def, string* output_node) const { GraphDefBuilder b; diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 0e3d5adecd..66e836f9a6 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -157,7 +157,8 @@ class GraphDefBuilderWrapper { // name `function_name` is not found in the FunctionLibraryDefinition, returns // an InvalidArgumentError. If the function with name `function_name` or any // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(OpKernelContext* ctx, const string& function_name); + Status AddFunction(const FunctionLibraryDefinition& flib_def, + const string& function_name); template <typename T> void BuildAttrValue(const T& value, AttrValue* attr) { @@ -167,18 +168,16 @@ class GraphDefBuilderWrapper { private: void AddTensorInternal(const Tensor& val, Node** output); - Status EnsureFunctionIsStateless(OpKernelContext* ctx, + Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def, const string& function_name) const { - const FunctionLibraryDefinition* lib_def = - ctx->function_library()->GetFunctionLibraryDefinition(); - const FunctionDef* function_def = lib_def->Find(function_name); + const FunctionDef* function_def = flib_def.Find(function_name); if (!function_def) { return errors::InvalidArgument("Unable to find FunctionDef for ", function_name, " in registry."); } for (const NodeDef& node_def : function_def->node_def()) { const OpDef* op_def; - TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); + TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def)); // TODO(b/65524810): Hack to allow functions to capture Dataset op // nodes needed for FlatMap. Currently, source datasets nodes have been // marked stateful to avoid constant folding since we do not have a @@ -220,12 +219,13 @@ class GraphDefBuilderWrapper { return false; } - Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) { + Status AddAttrFunctions(const AttrValue& attr_value, + const FunctionLibraryDefinition& flib_def) { if (attr_value.has_func()) { - TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); + TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name())); } else if (attr_value.has_list()) { for (const NameAttrList& name_attr_list : attr_value.list().func()) { - TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); + TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name())); } } return Status::OK(); @@ -236,21 +236,17 @@ class GraphDefBuilderWrapper { class StatsAggregator; -// A cut-down version of OpKernelContext for running computations in -// iterators. Note that we cannot simply use OpKernelContext here -// because we might run computation in an iterator whose lifetime is -// not nested within the lifetime of a single OpKernelContext -// (e.g. asynchronous prefetching). +// A cut-down version of `OpKernelContext` for running computations in +// iterators. Note that we cannot simply use `OpKernelContext` here because we +// might run computation in an iterator whose lifetime is not nested within the +// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching). // -// TODO(mrry): We will probably need to support more of -// OpKernelContext here. For example, should allocation be handled by -// the IteratorContext? -// TODO(mrry): We're making some daring assumptions about the lifetime -// of the runner passed in here. A runner will be deleted when the original -// step ends, but all existing runners only close over session-lifetime (or -// longer-lived) state, so we can make a copy of the function. There's nothing -// in the definition of the API from which we took the runner to guarantee that -// what we are doing is safe. We should formalize the properties here. +// TODO(mrry): We're making some daring assumptions about the lifetime of the +// runner passed in here. A runner will be deleted when the original step ends, +// but all existing runners only close over session-lifetime (or longer-lived) +// state, so we can make a copy of the function. There's nothing in the +// definition of the API from which we took the runner to guarantee that what we +// are doing is safe. We should formalize the properties here. class IteratorContext { public: struct Params { @@ -318,6 +314,23 @@ class IteratorContext { Params params_; }; +// Aggregates runtime support needed for dataset and iterator serialization. +class SerializationContext { + public: + struct Params { + const FunctionLibraryDefinition* flib_def; // Not owned. + }; + + explicit SerializationContext(Params params) : params_(std::move(params)) {} + + const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; } + + private: + Params params_; + + TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext); +}; + // Represents the current position in a range of outputs, where the // range of outputs is typically represented by an `DatasetBase`, // defined below. @@ -357,7 +370,7 @@ class IteratorBase { virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); } // Saves the state of this iterator. - virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { + virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { return SaveInternal(writer); } @@ -427,8 +440,10 @@ class DatasetBase : public core::RefCounted { virtual string DebugString() const = 0; // Serializes the dataset and writes it to the `writer`. - virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { - return errors::Unimplemented("DatasetBase::Save"); + virtual Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) const { + return errors::Unimplemented("%s does not support serialization", + DebugString()); } protected: @@ -439,13 +454,14 @@ class DatasetBase : public core::RefCounted { class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} - Status AddInputDataset(OpKernelContext* ctx, const DatasetBase* dataset, - Node** output) { + Status AddInputDataset(SerializationContext* ctx, + const DatasetBase* dataset, Node** output) { return dataset->AsGraphDefInternal(ctx, this, output); } }; - virtual Status AsGraphDefInternal(OpKernelContext* ctx, + // TODO(jsimsa): Consolidate overloading into a single method. + virtual Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** node) const { return AsGraphDefInternal(b, node); @@ -453,7 +469,8 @@ class DatasetBase : public core::RefCounted { virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** node) const { - return errors::Unimplemented("AsGraphDefInternal"); + return errors::Unimplemented("%s does not support serialization", + DebugString()); } virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( @@ -470,7 +487,7 @@ class GraphDatasetBase : public DatasetBase { const string op_name() const { return op_name_; } - Status Save(OpKernelContext* ctx, + Status Save(SerializationContext* ctx, IteratorStateWriter* writer) const override { string serialized_graph_def; string output_node; @@ -490,7 +507,7 @@ class GraphDatasetBase : public DatasetBase { TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; private: - Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, + Status Serialize(SerializationContext* ctx, string* serialized_graph_def, string* output_node) const; const string op_name_; @@ -539,7 +556,7 @@ class DatasetBaseIterator : public IteratorBase { return s; } - Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { + Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final { TF_RETURN_IF_ERROR(params_.dataset->Save(ctx, writer)); return IteratorBase::Save(ctx, writer); } diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 2ee4548621..5295c9d2a6 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -96,7 +96,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 4f23d07aae..3762e403a9 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -85,7 +85,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); @@ -566,7 +567,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 98282d74a9..6393005cdc 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -80,7 +80,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index d15cdf215e..c71d027f23 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -32,7 +32,11 @@ class DatasetToGraphOp : public OpKernel { GraphDefBuilder b; DatasetBase::DatasetGraphDefBuilder db(&b); Node* input_node = nullptr; - OP_REQUIRES_OK(ctx, db.AddInputDataset(ctx, dataset, &input_node)); + SerializationContext::Params params; + params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + SerializationContext serialization_ctx(params); + OP_REQUIRES_OK( + ctx, db.AddInputDataset(&serialization_ctx, dataset, &input_node)); GraphDef graph_def; OP_REQUIRES_OK(ctx, b.ToGraphDef(&graph_def)); Tensor* result; diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 77a04ef3f1..9105587cf4 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -115,7 +115,8 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc index 2f6479de16..4b6d808af0 100644 --- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -80,7 +80,8 @@ class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index aebc2f065f..b11d7cf2ef 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -109,9 +109,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "FilterDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); Node* input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index aae3f19c0d..3419eed6c6 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -91,9 +91,10 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "FlatMapDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index f245fc402d..bcf0adacc7 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -106,12 +106,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), init_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func().name())); + TF_RETURN_IF_ERROR( + b->AddFunction(ctx->flib_def(), finalize_func().name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index f6664fc5fb..683a50e71c 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -136,11 +136,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func_.name())); + TF_RETURN_IF_ERROR( + b->AddFunction(ctx->flib_def(), window_size_func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 20096a6590..8fee29d4d0 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -114,9 +114,10 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index e2df14337c..da9d29dd76 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -116,7 +116,7 @@ class IteratorResource : public ResourceBase { } } - Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { + Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { std::shared_ptr<IteratorBase> captured_iterator(iterator_); if (captured_iterator) { return captured_iterator->Save(ctx, writer); @@ -386,10 +386,13 @@ class IteratorStateVariant { // that it can be written on the next call to Encode(). Status InitializeFromIterator(OpKernelContext* ctx, IteratorResource* iterator_resource) { + SerializationContext::Params params; + params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + SerializationContext serialization_ctx(params); data_.reset(new VariantTensorData()); data_->set_type_name(TypeName()); VariantTensorDataWriter writer(data_.get()); - TF_RETURN_IF_ERROR(iterator_resource->Save(ctx, &writer)); + TF_RETURN_IF_ERROR(iterator_resource->Save(&serialization_ctx, &writer)); TF_RETURN_IF_ERROR(writer.Flush()); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index e0f164e784..51a7fd23a8 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -144,9 +144,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 852c942e99..ec9e12453b 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -89,9 +89,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "MapDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 14a0aa7844..8add049123 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -80,15 +80,22 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Optimize")})); + // We do not add a token for the optimization dataset to the prefix. The + // prefix is used to identify checkpoint elements and since the + // optimization dataset is excluded from the checkpoint, adding a token + // here would result in invalid checkpoint identifiers. + return std::unique_ptr<IteratorBase>(new Iterator({this, prefix})); } Status Optimize(OpKernelContext* ctx) { GraphDefBuilder b; DatasetGraphDefBuilder db(&b); Node* input_node = nullptr; - TF_RETURN_IF_ERROR(db.AddInputDataset(ctx, input_, &input_node)); + SerializationContext::Params params; + params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + SerializationContext serialization_ctx(params); + TF_RETURN_IF_ERROR( + db.AddInputDataset(&serialization_ctx, input_, &input_node)); string output_node = input_node->name(); GraphDef graph_def; TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); @@ -119,14 +126,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "OptimizeDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* optimizations_node = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(optimizations_, &optimizations_node)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, optimizations_node}, output)); + // We only serialize the optimized dataset to avoid re-running + // optimizations when the input pipeline is restored from a checkpoint. + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index dd4b0f9f4c..755d46dac2 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -153,7 +153,8 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 1995708732..d2b83f9eab 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -134,9 +134,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); + TF_RETURN_IF_ERROR( + b->AddFunction(ctx->flib_def(), interleave_func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 2cbdde01ec..c56a7ea808 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -113,7 +113,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { // Input: input_dataset Node* input_graph_node = nullptr; @@ -137,7 +138,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { b->AddScalar(num_parallel_calls_, &num_parallel_calls)); // Attr: f - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); AttrValue f; b->BuildAttrValue(func_, &f); diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index ccdb7c1479..20148a4378 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -51,7 +51,8 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase { string DebugString() const override { return "PrefetchDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index ff166c3be7..7e48428b3f 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -77,7 +77,8 @@ class RandomDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* seed = nullptr; Node* seed2 = nullptr; diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index 0b5c814767..50bd3dac4e 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -71,7 +71,8 @@ class RangeDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* start = nullptr; Node* stop = nullptr; diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index 29654b9bca..6a71a7af1d 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -109,7 +109,8 @@ class TextLineDatasetOp : public DatasetOpKernel { string DebugString() const override { return "TextLineDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* filenames = nullptr; Node* compression_type = nullptr; @@ -345,7 +346,8 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* filenames = nullptr; Node* header_bytes = nullptr; @@ -563,7 +565,8 @@ class TFRecordDatasetOp : public DatasetOpKernel { string DebugString() const override { return "TFRecordDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* filenames = nullptr; TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 002b0ee596..093ea563b4 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -72,7 +72,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "RepeatDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index b0874a5509..7c59874d96 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -106,9 +106,10 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "ScanDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); std::vector<Node*> initial_state_nodes; diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index caf1e2af3f..603c3feb79 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -429,7 +429,8 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { } }; - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { mutex_lock l(mu_); Node* input_graph_node = nullptr; @@ -499,7 +500,8 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -584,7 +586,8 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 2da496c423..61db6a0a54 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -68,7 +68,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "SkipDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index 6c0384a537..fd8c5ccd92 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -104,7 +104,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index b5dff48d2d..9bb86e76a2 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -55,7 +55,8 @@ class Dataset : public GraphDatasetBase { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* indices_node; TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node)); diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 16652e792c..9b0190e3fc 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -105,7 +105,8 @@ class SqlDatasetOp : public DatasetOpKernel { string DebugString() const override { return "SqlDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* driver_name_node; TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node)); diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index 61a1937e4c..8465a1d2c0 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -70,14 +70,6 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { return "SetStatsAggregatorDatasetOp::Dataset"; } - protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, - Node** output) const override { - return errors::Unimplemented( - "Cannot currently serialize the `stats_aggregator` for a " - "SetStatsAggregatorDataset."); - } - private: class Iterator : public DatasetIterator<Dataset> { public: diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index ba5ed39d83..85fed31773 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -76,7 +76,8 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); @@ -175,7 +176,8 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); @@ -280,7 +282,8 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 0114629ef6..d4a3c7a978 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -69,7 +69,8 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "TakeDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 36fc434d8f..ac2015c865 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -67,7 +67,8 @@ class TensorDatasetOp : public DatasetOpKernel { string DebugString() const override { return "TensorDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { std::vector<Node*> components; components.reserve(tensors_.size()); diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index d728f1ab14..ea472e2b79 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -99,7 +99,8 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 68ce324081..8f18d38f83 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -86,7 +86,8 @@ class TensorSliceDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { std::vector<Node*> components; components.reserve(tensors_.size()); diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 692b5d8819..02c3c5315a 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -65,7 +65,8 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { string DebugString() const override { return "UnbatchDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index c87214b3ef..f9fd5b5a83 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -74,7 +74,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index 8cc21cd2bc..63e9b99d4b 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -77,7 +77,8 @@ class ZipDatasetOp : public DatasetOpKernel { string DebugString() const override { return "ZipDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { std::vector<Node*> input_graph_nodes; input_graph_nodes.reserve(inputs_.size()); |