aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/iterator_ops.cc
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-04-18 16:01:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 16:04:37 -0700
commite9d47fbff0d644a75c6f3dcdcb852685ef515b64 (patch)
tree87b63378d33241f3fff820b0f70cdd89d2957406 /tensorflow/core/kernels/data/iterator_ops.cc
parent695da2d928b5927c0a4f73e352a597a19886f2cb (diff)
Adds dataset transformation function `set_stats_aggregator(..)`, which sets the given `stats_aggregator` for aggregating the input dataset stats.
PiperOrigin-RevId: 193432590
Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc32
1 files changed, 1 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 780f927a4f..4e4997d7b3 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/stats_aggregator.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -203,10 +203,6 @@ class IteratorResource : public ResourceBase {
return Status::OK();
}
- void set_stats_aggregator(std::shared_ptr<StatsAggregator> stats_aggregator) {
- mutex_lock l(mu_);
- stats_aggregator_ = std::move(stats_aggregator);
- }
std::shared_ptr<StatsAggregator> stats_aggregator() {
tf_shared_lock l(mu_);
@@ -1075,30 +1071,6 @@ class DeserializeIteratorOp : public OpKernel {
}
};
-class IteratorSetStatsAggregatorOp : public OpKernel {
- public:
- explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IteratorResource* iterator_resource;
- OP_REQUIRES_OK(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
- core::ScopedUnref unref_iterator(iterator_resource);
-
- StatsAggregatorResource* stats_aggregator_resource;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
- &stats_aggregator_resource));
- core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
- // TODO(mrry): Consider allowing multiple StatsAggregator ops to
- // subscribe to updates, and/or unsubscribing.
- OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(),
- errors::FailedPrecondition(
- "Iterator already associated with a StatsAggregator"));
- iterator_resource->set_stats_aggregator(
- stats_aggregator_resource->stats_aggregator());
- }
-};
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
@@ -1119,8 +1091,6 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
-REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU),
- IteratorSetStatsAggregatorOp);
} // namespace