diff options
Diffstat (limited to 'tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index 75af73df54..7e528a71be 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { @@ -33,16 +34,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { &stats_aggregator_resource)); core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); - *output = new Dataset(ctx, input, stats_aggregator_resource); + *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource); } private: class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + const Tensor& resource_handle, StatsAggregatorResource* stats_aggregator_resource) : DatasetBase(DatasetContext(ctx)), input_(input), + resource_handle_(resource_handle), stats_aggregator_resource_(stats_aggregator_resource) { input_->Ref(); stats_aggregator_resource_->Ref(); @@ -74,8 +77,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* resource_handle_node = nullptr; + TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, resource_handle_node}, output)); + return Status::OK(); } private: @@ -128,6 +136,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; + const Tensor resource_handle_; StatsAggregatorResource* stats_aggregator_resource_; }; }; @@ -135,4 +144,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU), SetStatsAggregatorDatasetOp); } // namespace +} // namespace data } // namespace tensorflow |