-rw-r--r--  tensorflow/core/framework/op_kernel.cc                    29
-rw-r--r--  tensorflow/core/framework/op_kernel.h                     26
-rw-r--r--  tensorflow/core/framework/tensor_reference.cc             10
-rw-r--r--  tensorflow/core/framework/tensor_reference.h               5
-rw-r--r--  tensorflow/core/framework/tensor_shape.cc                 12
-rw-r--r--  tensorflow/core/framework/tensor_shape.h                  15
-rw-r--r--  tensorflow/core/framework/unique_tensor_references.cc     76
-rw-r--r--  tensorflow/core/framework/unique_tensor_references.h      70
-rw-r--r--  tensorflow/core/graph/node_builder.cc                      7
-rw-r--r--  tensorflow/core/graph/node_builder.h                       7
-rw-r--r--  tensorflow/core/kernels/BUILD                              1
-rw-r--r--  tensorflow/core/kernels/reduction_ops_common.cc           127
-rw-r--r--  tensorflow/core/kernels/reduction_ops_common.h            125
13 files changed, 279 insertions(+), 231 deletions(-)
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index fb75443413..4392dbe5ff 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -90,6 +90,8 @@ OpKernel::OpKernel(OpKernelConstruction* context)
&output_name_map_));
}
+OpKernel::~OpKernel() {}
+
Status OpKernel::InputRange(const string& input_name, int* start,
int* stop) const {
const auto result = input_name_map_.find(input_name);
@@ -172,6 +174,10 @@ Status OpKernelConstruction::allocate_persistent(
return s;
}
+void OpKernelConstruction::SetStatus(const Status& status) {
+ status_->Update(status);
+}
+
// OpKernelContext -----------------------------------------------------------
OpKernelContext::OpKernelContext(Params* params)
@@ -194,6 +200,29 @@ OpKernelContext::~OpKernelContext() {
}
}
+Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
+ Allocator* allocator =
+ params_->device->GetStepAllocator(attr, step_resource_manager());
+ if (params_->track_allocations) {
+ mutex_lock lock(mu_);
+ for (const auto& wrapped : wrapped_allocators_) {
+ if (wrapped.first == allocator) {
+ return wrapped.second;
+ }
+ }
+ TrackingAllocator* wrapped_allocator =
+ new TrackingAllocator(allocator, attr.track_sizes());
+ wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
+ return wrapped_allocator;
+ } else {
+ return allocator;
+ }
+}
+
+void OpKernelContext::SetStatus(const Status& status) {
+ status_.Update(status);
+}
+
Status OpKernelContext::input(const string& name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 618c931eb3..e4ffcd7352 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -69,7 +69,7 @@ class OpKernel {
// OpKernel won't be instantiated by the scheduler, so you may perform
// expensive initialization in the descendant's constructor.
explicit OpKernel(OpKernelConstruction* context);
- virtual ~OpKernel() {}
+ virtual ~OpKernel();
// An OpKernel's computation can be either synchronous or
// asynchronous.
@@ -287,7 +287,7 @@ class OpKernelConstruction {
const DataTypeSlice expected_outputs);
// For recording configuration errors during construction.
- void SetStatus(const Status& status) { status_->Update(status); }
+ void SetStatus(const Status& status);
const Status& status() const { return *status_; }
// Look up the attr with name attr_name and set *value to its value. If no
@@ -874,7 +874,7 @@ class OpKernelContext {
// An OpKernel should call SetStatus() if Compute() encounters an
// error.
- void SetStatus(const Status& status) { status_.Update(status); }
+ void SetStatus(const Status& status);
const Status& status() const { return status_; }
// Cancellation.
@@ -907,25 +907,7 @@ class OpKernelContext {
}
private:
- Allocator* get_allocator(AllocatorAttributes attr) {
- Allocator* allocator =
- params_->device->GetStepAllocator(attr, step_resource_manager());
- if (params_->track_allocations) {
- mutex_lock lock(mu_);
- for (const auto& wrapped : wrapped_allocators_) {
- if (wrapped.first == allocator) {
- return wrapped.second;
- }
- }
- TrackingAllocator* wrapped_allocator =
- new TrackingAllocator(allocator, attr.track_sizes());
- wrapped_allocators_.push_back(
- std::make_pair(allocator, wrapped_allocator));
- return wrapped_allocator;
- } else {
- return allocator;
- }
- }
+ Allocator* get_allocator(AllocatorAttributes attr);
// Internal method to add a tensor's buffer to the list of buffers
// referenced during the execution of the Op, so that GPUs may
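The get_allocator() body moved out of the header above wraps each step allocator in a TrackingAllocator at most once and caches the (allocator, wrapper) pair under the context mutex. Below is a minimal standalone sketch of that wrap-once-and-cache pattern, using simplified stand-in types that are not the TensorFlow classes:

    // Sketch only: simplified stand-ins, not tensorflow::Allocator.
    #include <cstddef>
    #include <mutex>
    #include <new>
    #include <utility>
    #include <vector>

    struct Allocator {
      virtual ~Allocator() = default;
      virtual void* Allocate(size_t n) { return ::operator new(n); }
    };

    // Wraps a base allocator and records how many bytes it hands out.
    struct TrackingAllocator : Allocator {
      explicit TrackingAllocator(Allocator* base) : base_(base) {}
      void* Allocate(size_t n) override {
        bytes_ += n;
        return base_->Allocate(n);
      }
      Allocator* base_;
      size_t bytes_ = 0;
    };

    class StepContext {
     public:
      // Returns `base` unchanged when tracking is off; otherwise returns a
      // TrackingAllocator for it, creating and caching one on first use so
      // that later requests for the same base allocator reuse the wrapper.
      Allocator* GetAllocator(Allocator* base, bool track_allocations) {
        if (!track_allocations) return base;
        std::lock_guard<std::mutex> lock(mu_);
        for (const auto& wrapped : wrapped_allocators_) {
          if (wrapped.first == base) return wrapped.second;
        }
        TrackingAllocator* tracker = new TrackingAllocator(base);
        wrapped_allocators_.push_back(std::make_pair(base, tracker));
        return tracker;
      }

     private:
      std::mutex mu_;
      std::vector<std::pair<Allocator*, TrackingAllocator*>> wrapped_allocators_;
    };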
diff --git a/tensorflow/core/framework/tensor_reference.cc b/tensorflow/core/framework/tensor_reference.cc
new file mode 100644
index 0000000000..52e5592030
--- /dev/null
+++ b/tensorflow/core/framework/tensor_reference.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/framework/tensor_reference.h"
+
+namespace tensorflow {
+
+TensorReference::TensorReference(const Tensor& tensor)
+ : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) {
+ if (buf_) buf_->Ref();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_reference.h b/tensorflow/core/framework/tensor_reference.h
index 56b516e694..abffccadd6 100644
--- a/tensorflow/core/framework/tensor_reference.h
+++ b/tensorflow/core/framework/tensor_reference.h
@@ -31,10 +31,7 @@ namespace tensorflow {
class TensorReference {
public:
// Take the reference of the root buffer so the size will be more accurate
- explicit TensorReference(const Tensor& tensor)
- : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) {
- if (buf_) buf_->Ref();
- }
+ explicit TensorReference(const Tensor& tensor);
~TensorReference() {}
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
index f8daf49d89..f608cc815d 100644
--- a/tensorflow/core/framework/tensor_shape.cc
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -109,6 +109,18 @@ void TensorShape::SlowCopyFrom(const TensorShape& b) {
}
}
+int64 TensorShape::dim_size(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ if (tag() == REP16) {
+ return as16()->dims_[d];
+ } else if (tag() == REP32) {
+ return as32()->dims_[d];
+ } else {
+ return (*as64()->dims_)[d];
+ }
+}
+
void TensorShape::Clear() {
ClearAllButDataType();
set_data_type(DT_INVALID);
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index 8d4e728535..ab4979c902 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -87,24 +87,15 @@ class TensorShape {
/// Return the number of dimensions in the tensor.
int dims() const {
- return (tag() == REP_OUT_OF_LINE) ? (*as64()->dims_).size() : ndims_byte();
+ DCHECK(tag() != REP_OUT_OF_LINE || (*as64()->dims_).size() == ndims_byte());
+ return ndims_byte();
}
/// \brief Returns the number of elements in dimension `d`.
/// REQUIRES: `0 <= d < dims()`
// TODO(touts): Rename to `dimension()` to match
// `Eigen::Tensor::dimension()`?
- int64 dim_size(int d) const {
- DCHECK_GE(d, 0);
- DCHECK_LT(d, dims());
- if (tag() == REP16) {
- return as16()->dims_[d];
- } else if (tag() == REP32) {
- return as32()->dims_[d];
- } else {
- return (*as64()->dims_)[d];
- }
- }
+ int64 dim_size(int d) const;
/// Returns sizes of all dimensions.
gtl::InlinedVector<int64, 4> dim_sizes() const;
diff --git a/tensorflow/core/framework/unique_tensor_references.cc b/tensorflow/core/framework/unique_tensor_references.cc
new file mode 100644
index 0000000000..2901feb531
--- /dev/null
+++ b/tensorflow/core/framework/unique_tensor_references.cc
@@ -0,0 +1,76 @@
+#include "tensorflow/core/framework/unique_tensor_references.h"
+
+namespace tensorflow {
+
+UniqueTensorReferences::~UniqueTensorReferences() {
+ if (!frozen_) {
+ // The references were not retrieved so discard them to avoid
+ // leaking memory.
+ TensorReferenceVector refs;
+ FreezeAndReturnReferences(&refs);
+ for (auto& tensor : refs) {
+ tensor.Unref();
+ }
+ }
+ delete referenced_tensors_set_;
+}
+
+void UniqueTensorReferences::Add(const Tensor& tensor) {
+ DCHECK(!frozen_);
+ // Do nothing if the tensor has a null buffer.
+ if (tensor.IsInitialized()) {
+ if (referenced_tensors_set_ != nullptr) {
+ // There are enough tensors that we are using a hash set to
+ // de-duplicate.
+ const TensorReference tensor_ref(tensor);
+ if (!referenced_tensors_set_->insert(tensor_ref).second) {
+ // The tensor was a duplicate, so discard the reference.
+ tensor_ref.Unref();
+ }
+ } else {
+ for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) {
+ if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) {
+ // tensor is a duplicate, so nothing to do.
+ return;
+ }
+ }
+ referenced_tensors_vector_.push_back(TensorReference(tensor));
+ if (kInVector == referenced_tensors_vector_.size()) {
+ // There are too many tensors to keep using the N^2 algorithm
+ // so start de-duplicating using a set.
+ // Transfer the refs from the vector to the set.
+ DCHECK(referenced_tensors_set_ == nullptr);
+ referenced_tensors_set_ = new ReferencedTensorsSet;
+ referenced_tensors_set_->reserve(kInVector);
+ referenced_tensors_set_->insert(referenced_tensors_vector_.begin(),
+ referenced_tensors_vector_.end());
+ DCHECK_EQ(kInVector, referenced_tensors_set_->size());
+ referenced_tensors_vector_.clear();
+ }
+ }
+ }
+}
+
+void UniqueTensorReferences::FreezeAndReturnReferences(
+ TensorReferenceVector* out_vector) {
+ // Prevent any further additions.
+ frozen_ = true;
+ if (referenced_tensors_set_ != nullptr) {
+ DCHECK(referenced_tensors_vector_.empty());
+ out_vector->reserve(referenced_tensors_set_->size());
+ for (const auto& ref : *referenced_tensors_set_) {
+ out_vector->push_back(ref);
+ }
+ referenced_tensors_set_->clear();
+ delete referenced_tensors_set_;
+ referenced_tensors_set_ = nullptr;
+ } else {
+ out_vector->reserve(referenced_tensors_vector_.size());
+ for (const auto& ref : referenced_tensors_vector_) {
+ out_vector->push_back(ref);
+ }
+ referenced_tensors_vector_.clear();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/unique_tensor_references.h b/tensorflow/core/framework/unique_tensor_references.h
index 2e08a9825f..f587348d9b 100644
--- a/tensorflow/core/framework/unique_tensor_references.h
+++ b/tensorflow/core/framework/unique_tensor_references.h
@@ -35,78 +35,14 @@ class UniqueTensorReferences {
public:
UniqueTensorReferences() : frozen_(false), referenced_tensors_set_(nullptr) {}
- ~UniqueTensorReferences() {
- if (!frozen_) {
- // The references were not retrieved so discard them to avoid
- // leaking memory.
- TensorReferenceVector refs;
- FreezeAndReturnReferences(&refs);
- for (auto& tensor : refs) {
- tensor.Unref();
- }
- }
- delete referenced_tensors_set_;
- }
+ ~UniqueTensorReferences();
// Adds a reference to tensor if its buffer is not already referenced.
- void Add(const Tensor& tensor) {
- DCHECK(!frozen_);
- // Do nothing if the tensor has a null buffer.
- if (tensor.IsInitialized()) {
- if (referenced_tensors_set_ != nullptr) {
- // There are enough tensors that we are using a hash set to
- // de-duplicate.
- const TensorReference tensor_ref(tensor);
- if (!referenced_tensors_set_->insert(tensor_ref).second) {
- // The tensor was a duplicate, so discard the reference.
- tensor_ref.Unref();
- }
- } else {
- for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) {
- if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) {
- // tensor is a duplicate, so nothing to do.
- return;
- }
- }
- referenced_tensors_vector_.push_back(TensorReference(tensor));
- if (kInVector == referenced_tensors_vector_.size()) {
- // There are too many tensors to keep using the N^2 algorithm
- // so start de-duplicating using a set.
- // Transfer the refs from the vector to the set.
- DCHECK(referenced_tensors_set_ == nullptr);
- referenced_tensors_set_ = new ReferencedTensorsSet;
- referenced_tensors_set_->reserve(kInVector);
- referenced_tensors_set_->insert(referenced_tensors_vector_.begin(),
- referenced_tensors_vector_.end());
- DCHECK_EQ(kInVector, referenced_tensors_set_->size());
- referenced_tensors_vector_.clear();
- }
- }
- }
- }
+ void Add(const Tensor& tensor);
// No more references may be added after this is called. The unique
// references are returned in out_vector.
- void FreezeAndReturnReferences(TensorReferenceVector* out_vector) {
- // Prevent any further additions.
- frozen_ = true;
- if (referenced_tensors_set_ != nullptr) {
- DCHECK(referenced_tensors_vector_.empty());
- out_vector->reserve(referenced_tensors_set_->size());
- for (const auto& ref : *referenced_tensors_set_) {
- out_vector->push_back(ref);
- }
- referenced_tensors_set_->clear();
- delete referenced_tensors_set_;
- referenced_tensors_set_ = nullptr;
- } else {
- out_vector->reserve(referenced_tensors_vector_.size());
- for (const auto& ref : referenced_tensors_vector_) {
- out_vector->push_back(ref);
- }
- referenced_tensors_vector_.clear();
- }
- }
+ void FreezeAndReturnReferences(TensorReferenceVector* out_vector);
private:
// Up to kInVector elements are stored in reference_tensors_vector_
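The Add() body now defined in unique_tensor_references.cc de-duplicates by scanning a small vector linearly and switching to a hash set once kInVector entries have been added. A standalone sketch of that small-vector-then-set strategy, using plain ints and a made-up threshold rather than the TensorFlow class:

    // Sketch only: linear scan over a small vector, hash set past a threshold.
    #include <cstddef>
    #include <memory>
    #include <unordered_set>
    #include <vector>

    class UniqueInts {
     public:
      static constexpr size_t kInVector = 4;  // hypothetical threshold

      // Returns true if `v` had not been added before.
      bool Add(int v) {
        if (set_) return set_->insert(v).second;
        for (int seen : vec_) {
          if (seen == v) return false;  // duplicate; O(N) is cheap for small N
        }
        vec_.push_back(v);
        if (vec_.size() == kInVector) {
          // Too many elements for the quadratic scan: move them into a set.
          set_.reset(new std::unordered_set<int>(vec_.begin(), vec_.end()));
          vec_.clear();
        }
        return true;
      }

     private:
      std::vector<int> vec_;
      std::unique_ptr<std::unordered_set<int>> set_;
    };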
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index fbea8d03cd..e0a71a0856 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -21,6 +21,13 @@ limitations under the License.
namespace tensorflow {
+NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit)
+ : node(n),
+ error(false),
+ name(node != nullptr ? node->name() : (error = true, "")),
+ index(i),
+ dt(SafeGetOutput(node, i, &error)) {}
+
NodeBuilder::NodeBuilder(const string& name, const string& op_name,
const OpRegistryInterface* op_registry)
: def_builder_(name, op_name, op_registry) {}
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index c8b4056974..51f00f6449 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -48,12 +48,7 @@ class NodeBuilder {
// ArraySlice.
struct NodeOut {
// For referencing an existing Node.
- NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit)
- : node(n),
- error(false),
- name(node != nullptr ? node->name() : (error = true, "")),
- index(i),
- dt(SafeGetOutput(node, i, &error)) {}
+ NodeOut(Node* n, int i = 0);
// For referencing Nodes not in the graph being built. It is
// useful when preparing a graph for ExtendSession or creating a
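The NodeOut constructor moved to node_builder.cc above leans on the comma operator inside its initializer list: when the node pointer is null, (error = true, "") first sets the error flag and then yields the empty string used to initialize name. A tiny self-contained sketch of the same idiom with hypothetical types; note that declaration order matters, since error must be initialized before name:

    // Sketch only: comma-operator fallback inside a member initializer list.
    #include <iostream>
    #include <string>

    struct Named {
      std::string name;
    };

    struct Ref {
      explicit Ref(const Named* n)
          : error(false),
            // If n is null, (error = true, std::string()) sets the flag and
            // then the whole expression evaluates to "" for initializing name.
            name(n != nullptr ? n->name : (error = true, std::string())) {}

      bool error;        // declared, and therefore initialized, before name
      std::string name;
    };

    int main() {
      Named relu{"relu"};
      Ref ok(&relu);
      Ref bad(nullptr);
      std::cout << ok.name << " error=" << ok.error << "\n";   // relu error=0
      std::cout << bad.name << " error=" << bad.error << "\n";  // empty name, error=1
    }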
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 4e5073b8be..a590558492 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1074,6 +1074,7 @@ filegroup(
"io.cc",
"lrn_op.cc",
"maxpooling_op.cc",
+ "reduction_ops_common.cc",
"reduction_ops_max.cc",
"reduction_ops_mean.cc",
"reduction_ops_min.cc",
diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc
new file mode 100644
index 0000000000..50e0791872
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_common.cc
@@ -0,0 +1,127 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+TensorShape ReductionHelper::out_reshape() const {
+ TensorShape shape;
+ for (auto size : out_reshape_) shape.AddDim(size);
+ return shape;
+}
+
+// The final output shape must be allocated with this shape.
+TensorShape ReductionHelper::out_shape() const {
+ TensorShape shape;
+ for (auto size : out_shape_) shape.AddDim(size);
+ return shape;
+}
+
+TensorShape ReductionHelper::shuffled_shape() {
+ const int dims = data_reshape_.size();
+ TensorShape shape;
+ for (int i = reduce_first_axis_; i < dims; i += 2) {
+ shape.AddDim(data_reshape_[i]);
+ }
+ for (int i = !reduce_first_axis_; i < dims; i += 2) {
+ shape.AddDim(data_reshape_[i]);
+ }
+ return shape;
+}
+
+gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
+ const int dims = data_reshape_.size();
+ const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
+ gtl::InlinedVector<int32, 8> perm(dims);
+ for (int i = 0; i < unreduced_dims; i++) {
+ perm[i] = 2 * i + reduce_first_axis_;
+ }
+ for (int i = unreduced_dims; i < dims; i++) {
+ perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
+ }
+ return perm;
+}
+
+Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
+ const bool keep_dims) {
+ // bitmap[i] indicates whether to reduce data along i-th axis.
+ gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
+ auto axis_vec = axis.flat<int32>();
+ for (int64 i = 0; i < axis.NumElements(); ++i) {
+ const int32 index = axis_vec(i);
+ if (index < 0 || index >= data.dims()) {
+ return errors::OutOfRange("Invalid reduction dimension (", index,
+ " for input with ", data.dims(),
+ " dimension(s)");
+ }
+ bitmap[index] = true;
+ }
+
+ // Output tensor's dim sizes.
+ out_shape_.clear();
+ for (int i = 0; i < data.dims(); ++i) {
+ if (!bitmap[i]) {
+ // If we are not reducing along dimension i.
+ out_shape_.push_back(data.dim_size(i));
+ } else if (keep_dims) {
+ // We are reducing along dimension i, but we want to keep the
+ // same number of dimensions, so we set the dimension of i to
+ // '1'.
+ out_shape_.push_back(1);
+ }
+ }
+
+ // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
+ // the input data before doing the reduction on the resulting
+ // tensor. The shape of the reduction is a reshape of the final
+ // output.
+
+ // We'll skip the leading 1s.
+ int dim_index = 0;
+ for (; dim_index < data.dims(); ++dim_index) {
+ if (data.dim_size(dim_index) != 1) break;
+ }
+ if (dim_index >= data.dims()) {
+ // Special case. The input is essentially a scalar.
+ reduce_first_axis_ = true;
+ } else {
+ // Starting from the (dim_index)-th dimension, dimensions
+ // alternate between runs that need to be reduced and runs that
+ // don't.
+ //
+ // NOTE: If a dimension has size 1, we group it as the current
+ // run so that we can minimize the number of runs.
+ //
+ // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
+ // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
+ // and reduce by axes = [1] (i.e., the output is shape [6]).
+ reduce_first_axis_ = bitmap[dim_index];
+ data_reshape_.push_back(data.dim_size(dim_index));
+ ++dim_index;
+ for (; dim_index < data.dims(); ++dim_index) {
+ const auto size = data.dim_size(dim_index);
+ if (size == 1) {
+ bitmap[dim_index] = bitmap[dim_index - 1];
+ }
+ if (bitmap[dim_index - 1] != bitmap[dim_index]) {
+ // Starts a new run of reduce or !reduce.
+ data_reshape_.push_back(size);
+ } else {
+ // Continue a run of reduce or !reduce.
+ data_reshape_.back() *= size;
+ }
+ }
+ // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
+ // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
+ // otherwise, data_reshape_[0, 2, 4, ...] is.
+ for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
+ i += 2) {
+ out_reshape_.push_back(data_reshape_[i]);
+ }
+ }
+
+ VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
+ VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
+ VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
+ return Status::OK();
+}
+
+} // namespace tensorflow
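The comments in Simplify() above describe collapsing adjacent dimensions into alternating keep/reduce runs: a shape [2, 1, 3, 1, 5] reduced over axes {1, 4} is treated as a [6, 5] tensor reduced over its last axis. Below is a standalone sketch of just that run-collapsing step (not the TensorFlow helper), hard-coding that example:

    // Sketch only: collapse adjacent dims into alternating keep/reduce runs.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const std::vector<int64_t> dims = {2, 1, 3, 1, 5};
      std::vector<bool> reduce = {false, true, false, false, true};

      // Skip leading dimensions of size 1.
      size_t d = 0;
      while (d < dims.size() && dims[d] == 1) ++d;

      bool reduce_first_axis = false;
      std::vector<int64_t> data_reshape;
      if (d == dims.size()) {
        reduce_first_axis = true;  // effectively a scalar input
      } else {
        reduce_first_axis = reduce[d];
        data_reshape.push_back(dims[d]);
        for (++d; d < dims.size(); ++d) {
          if (dims[d] == 1) reduce[d] = reduce[d - 1];  // size-1 dims join the current run
          if (reduce[d] != reduce[d - 1]) {
            data_reshape.push_back(dims[d]);  // start a new run
          } else {
            data_reshape.back() *= dims[d];   // extend the current run
          }
        }
      }

      std::cout << "reduce_first_axis = " << reduce_first_axis << "\n";
      std::cout << "data_reshape =";
      for (int64_t s : data_reshape) std::cout << " " << s;
      std::cout << "\n";  // prints: data_reshape = 6 5
    }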
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 19cedb1332..31cf55c78b 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -68,95 +68,11 @@ struct Constants<CPUDevice> {
};
#endif
-namespace {
-
class ReductionHelper {
public:
ReductionHelper() : reduce_first_axis_(false) {}
- Status Simplify(const Tensor& data, const Tensor& axis,
- const bool keep_dims) {
- // bitmap[i] indicates whether to reduce data along i-th axis.
- gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
- auto axis_vec = axis.flat<int32>();
- for (int64 i = 0; i < axis.NumElements(); ++i) {
- const int32 index = axis_vec(i);
- if (index < 0 || index >= data.dims()) {
- return errors::OutOfRange("Invalid reduction dimension (", index,
- " for input with ", data.dims(),
- " dimension(s)");
- }
- bitmap[index] = true;
- }
-
- // Output tensor's dim sizes.
- out_shape_.clear();
- for (int i = 0; i < data.dims(); ++i) {
- if (!bitmap[i]) {
- // If we are not reducing along dimension i.
- out_shape_.push_back(data.dim_size(i));
- } else if (keep_dims) {
- // We are reducing along dimension i, but we want to keep the
- // same number of dimensions, so we set the dimension of i to
- // '1'.
- out_shape_.push_back(1);
- }
- }
-
- // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
- // the input data before doing the reduction on the resulting
- // tensor. The shape of the reduction is a reshape of the final
- // output.
-
- // We'll skip the leading 1s.
- int dim_index = 0;
- for (; dim_index < data.dims(); ++dim_index) {
- if (data.dim_size(dim_index) != 1) break;
- }
- if (dim_index >= data.dims()) {
- // Special case. The input is essentially a scalar.
- reduce_first_axis_ = true;
- } else {
- // Starting from the (dim_index)-th dimension, dimensions
- // alternates between runs that need to be reduced and runs that
- // don't.
- //
- // NOTE: If a dimension has size 1, we group it as the current
- // run so that we can minimize the number of runs.
- //
- // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
- // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
- // and reduce by axes = [1] (i.e., the output is shape [6]).
- reduce_first_axis_ = bitmap[dim_index];
- data_reshape_.push_back(data.dim_size(dim_index));
- ++dim_index;
- for (; dim_index < data.dims(); ++dim_index) {
- const auto size = data.dim_size(dim_index);
- if (size == 1) {
- bitmap[dim_index] = bitmap[dim_index - 1];
- }
- if (bitmap[dim_index - 1] != bitmap[dim_index]) {
- // Starts a new run of reduce or !reduce.
- data_reshape_.push_back(size);
- } else {
- // Continue a run of reduce or !reduce.
- data_reshape_.back() *= size;
- }
- }
- // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
- // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
- // otherwise, data_reshape_[0, 2, 4, ...] is.
- for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
- i += 2) {
- out_reshape_.push_back(data_reshape_[i]);
- }
- }
-
- VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
- VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
- VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
- return Status::OK();
- }
+ Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims);
// We need to do roughly:
// tmp_out = allocate(out_reshape())
@@ -164,18 +80,10 @@ class ReductionHelper {
// out = tmp_out.reshape(out_shape)
// The reduction result must be allocated with this shape.
- TensorShape out_reshape() const {
- TensorShape shape;
- for (auto size : out_reshape_) shape.AddDim(size);
- return shape;
- }
+ TensorShape out_reshape() const;
// The final output shape must be allocated with this shape.
- TensorShape out_shape() const {
- TensorShape shape;
- for (auto size : out_shape_) shape.AddDim(size);
- return shape;
- }
+ TensorShape out_shape() const;
// The reduction is on a reshaped tensor of this rank.
int ndims() const { return data_reshape_.size(); }
@@ -203,31 +111,10 @@ class ReductionHelper {
}
// Shape with all reduction dimensions at the end
- TensorShape shuffled_shape() {
- const int dims = data_reshape_.size();
- TensorShape shape;
- for (int i = reduce_first_axis_; i < dims; i += 2) {
- shape.AddDim(data_reshape_[i]);
- }
- for (int i = !reduce_first_axis_; i < dims; i += 2) {
- shape.AddDim(data_reshape_[i]);
- }
- return shape;
- }
+ TensorShape shuffled_shape();
// Permutation of reduced dims needed to put reduction dimensions at the end
- gtl::InlinedVector<int32, 8> permutation() {
- const int dims = data_reshape_.size();
- const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
- gtl::InlinedVector<int32, 8> perm(dims);
- for (int i = 0; i < unreduced_dims; i++) {
- perm[i] = 2 * i + reduce_first_axis_;
- }
- for (int i = unreduced_dims; i < dims; i++) {
- perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
- }
- return perm;
- }
+ gtl::InlinedVector<int32, 8> permutation();
private:
bool reduce_first_axis_; // True if need to reduce the 0-th dimension.
@@ -236,8 +123,6 @@ class ReductionHelper {
gtl::InlinedVector<int64, 4> out_reshape_; // Reshape output for reduction.
};
-} // end namespace
-
// For operations where the output is a reduction function along some
// dimensions of the input.
template <typename Device, class T, typename Reducer>