diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-04-18 16:01:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-18 16:04:37 -0700 |
commit | e9d47fbff0d644a75c6f3dcdcb852685ef515b64 (patch) | |
tree | 87b63378d33241f3fff820b0f70cdd89d2957406 /tensorflow/core/kernels/data/iterator_ops.cc | |
parent | 695da2d928b5927c0a4f73e352a597a19886f2cb (diff) |
Adds dataset transformation function `set_stats_aggregator(..)`, which sets the given `stats_aggregator` for aggregating the input dataset stats.
PiperOrigin-RevId: 193432590
Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/data/iterator_ops.cc | 32 |
1 files changed, 1 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 780f927a4f..4e4997d7b3 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/stats_aggregator.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -203,10 +203,6 @@ class IteratorResource : public ResourceBase { return Status::OK(); } - void set_stats_aggregator(std::shared_ptr<StatsAggregator> stats_aggregator) { - mutex_lock l(mu_); - stats_aggregator_ = std::move(stats_aggregator); - } std::shared_ptr<StatsAggregator> stats_aggregator() { tf_shared_lock l(mu_); @@ -1075,30 +1071,6 @@ class DeserializeIteratorOp : public OpKernel { } }; -class IteratorSetStatsAggregatorOp : public OpKernel { - public: - explicit IteratorSetStatsAggregatorOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - core::ScopedUnref unref_iterator(iterator_resource); - - StatsAggregatorResource* stats_aggregator_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), - &stats_aggregator_resource)); - core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); - // TODO(mrry): Consider allowing multiple StatsAggregator ops to - // subscribe to updates, and/or unsubscribing. - OP_REQUIRES(ctx, !iterator_resource->stats_aggregator(), - errors::FailedPrecondition( - "Iterator already associated with a StatsAggregator")); - iterator_resource->set_stats_aggregator( - stats_aggregator_resource->stats_aggregator()); - } -}; REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU), @@ -1119,8 +1091,6 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), SerializeIteratorOp); REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), DeserializeIteratorOp); -REGISTER_KERNEL_BUILDER(Name("IteratorSetStatsAggregator").Device(DEVICE_CPU), - IteratorSetStatsAggregatorOp); } // namespace |