diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-17 16:41:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 16:50:56 -0700 |
commit | 6805a8b27759a530f0ebab0670593a05455a64a0 (patch) | |
tree | 1ea29c728b29b5b5641ee00997debb7737bcb13c /tensorflow/core/framework | |
parent | 0cdf60ff8239a68326af9610e715f42c773be731 (diff) |
Changing `OpInputList` so that it is a forward iterator and taking advantage of the fact in the tf.data kernels.
PiperOrigin-RevId: 213361953
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index e752599de1..4bbd6c3d7d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -372,18 +372,37 @@ class OpKernelConstruction { template <typename ListType, typename ElementType> class OpArgIterator { public: - typedef OpArgIterator<ListType, ElementType> ME; + using iterator_category = std::forward_iterator_tag; + using value_type = ElementType; + using pointer = ElementType*; + using reference = ElementType&; + using difference_type = ptrdiff_t; + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} - bool operator==(const ME& rhs) { + + bool operator==(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ == rhs.i_; } - bool operator!=(const ME& rhs) { + + bool operator!=(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ != rhs.i_; } - void operator++() { ++i_; } - ElementType& operator*() { return (*list_)[i_]; } + + OpArgIterator operator++() { // prefix ++it + ++i_; + return *this; + } + + OpArgIterator operator++(int) { // postfix it++ + OpArgIterator old_value = *this; + ++i_; + return old_value; + } + + reference operator*() { return (*list_)[i_]; } + pointer operator->() { return &(*list_)[i_]; } private: const ListType* const list_; @@ -394,7 +413,7 @@ class OpArgIterator { // that are passed to the op as a single named argument. class OpInputList { public: - typedef OpArgIterator<OpInputList, const Tensor&> Iterator; + typedef OpArgIterator<OpInputList, const Tensor> Iterator; OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} OpInputList(OpKernelContext* ctx, int start, int stop) : ctx_(ctx), start_(start), stop_(stop) {} |