aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-17 16:41:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:50:56 -0700
commit6805a8b27759a530f0ebab0670593a05455a64a0 (patch)
tree1ea29c728b29b5b5641ee00997debb7737bcb13c /tensorflow/core/framework
parent0cdf60ff8239a68326af9610e715f42c773be731 (diff)
Changing `OpInputList` so that it is a forward iterator and taking advantage of the fact in the tf.data kernels.
PiperOrigin-RevId: 213361953
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/op_kernel.h31
1 files changed, 25 insertions, 6 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e752599de1..4bbd6c3d7d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -372,18 +372,37 @@ class OpKernelConstruction {
template <typename ListType, typename ElementType>
class OpArgIterator {
public:
- typedef OpArgIterator<ListType, ElementType> ME;
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = ElementType;
+ using pointer = ElementType*;
+ using reference = ElementType&;
+ using difference_type = ptrdiff_t;
+
OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
- bool operator==(const ME& rhs) {
+
+ bool operator==(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ == rhs.i_;
}
- bool operator!=(const ME& rhs) {
+
+ bool operator!=(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ != rhs.i_;
}
- void operator++() { ++i_; }
- ElementType& operator*() { return (*list_)[i_]; }
+
+ OpArgIterator operator++() { // prefix ++it
+ ++i_;
+ return *this;
+ }
+
+ OpArgIterator operator++(int) { // postfix it++
+ OpArgIterator old_value = *this;
+ ++i_;
+ return old_value;
+ }
+
+ reference operator*() { return (*list_)[i_]; }
+ pointer operator->() { return &(*list_)[i_]; }
private:
const ListType* const list_;
@@ -394,7 +413,7 @@ class OpArgIterator {
// that are passed to the op as a single named argument.
class OpInputList {
public:
- typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext* ctx, int start, int stop)
: ctx_(ctx), start_(start), stop_(stop) {}