aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 75af73df54..7e528a71be 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
@@ -33,16 +34,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
- *output = new Dataset(ctx, input, stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource);
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const Tensor& resource_handle,
StatsAggregatorResource* stats_aggregator_resource)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ resource_handle_(resource_handle),
stats_aggregator_resource_(stats_aggregator_resource) {
input_->Ref();
stats_aggregator_resource_->Ref();
@@ -74,8 +77,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* resource_handle_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, resource_handle_node}, output));
+ return Status::OK();
}
private:
@@ -128,6 +136,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
};
};
@@ -135,4 +144,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow