diff options
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 29 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 26 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_reference.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_reference.h | 5 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_shape.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_shape.h | 15 | ||||
-rw-r--r-- | tensorflow/core/framework/unique_tensor_references.cc | 76 | ||||
-rw-r--r-- | tensorflow/core/framework/unique_tensor_references.h | 70 | ||||
-rw-r--r-- | tensorflow/core/graph/node_builder.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/graph/node_builder.h | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/reduction_ops_common.cc | 127 | ||||
-rw-r--r-- | tensorflow/core/kernels/reduction_ops_common.h | 125 |
13 files changed, 279 insertions, 231 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index fb75443413..4392dbe5ff 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -90,6 +90,8 @@ OpKernel::OpKernel(OpKernelConstruction* context) &output_name_map_)); } +OpKernel::~OpKernel() {} + Status OpKernel::InputRange(const string& input_name, int* start, int* stop) const { const auto result = input_name_map_.find(input_name); @@ -172,6 +174,10 @@ Status OpKernelConstruction::allocate_persistent( return s; } +void OpKernelConstruction::SetStatus(const Status& status) { + status_->Update(status); +} + // OpKernelContext ----------------------------------------------------------- OpKernelContext::OpKernelContext(Params* params) @@ -194,6 +200,29 @@ OpKernelContext::~OpKernelContext() { } } +Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { + Allocator* allocator = + params_->device->GetStepAllocator(attr, step_resource_manager()); + if (params_->track_allocations) { + mutex_lock lock(mu_); + for (const auto& wrapped : wrapped_allocators_) { + if (wrapped.first == allocator) { + return wrapped.second; + } + } + TrackingAllocator* wrapped_allocator = + new TrackingAllocator(allocator, attr.track_sizes()); + wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator)); + return wrapped_allocator; + } else { + return allocator; + } +} + +void OpKernelContext::SetStatus(const Status& status) { + status_.Update(status); +} + Status OpKernelContext::input(const string& name, const Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 618c931eb3..e4ffcd7352 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -69,7 +69,7 @@ class OpKernel { // OpKernel won't be instantiated by the scheduler, so you may perform // expensive initialization in the descendant's constructor. explicit OpKernel(OpKernelConstruction* context); - virtual ~OpKernel() {} + virtual ~OpKernel(); // An OpKernel's computation can be either synchronous or // asynchronous. @@ -287,7 +287,7 @@ class OpKernelConstruction { const DataTypeSlice expected_outputs); // For recording configuration errors during construction. - void SetStatus(const Status& status) { status_->Update(status); } + void SetStatus(const Status& status); const Status& status() const { return *status_; } // Look up the attr with name attr_name and set *value to its value. If no @@ -874,7 +874,7 @@ class OpKernelContext { // An OpKernel should call SetStatus() if Compute() encounters an // error. - void SetStatus(const Status& status) { status_.Update(status); } + void SetStatus(const Status& status); const Status& status() const { return status_; } // Cancellation. @@ -907,25 +907,7 @@ class OpKernelContext { } private: - Allocator* get_allocator(AllocatorAttributes attr) { - Allocator* allocator = - params_->device->GetStepAllocator(attr, step_resource_manager()); - if (params_->track_allocations) { - mutex_lock lock(mu_); - for (const auto& wrapped : wrapped_allocators_) { - if (wrapped.first == allocator) { - return wrapped.second; - } - } - TrackingAllocator* wrapped_allocator = - new TrackingAllocator(allocator, attr.track_sizes()); - wrapped_allocators_.push_back( - std::make_pair(allocator, wrapped_allocator)); - return wrapped_allocator; - } else { - return allocator; - } - } + Allocator* get_allocator(AllocatorAttributes attr); // Internal method to add a tensor's buffer to the list of buffers // referenced during the execution of the Op, so that GPUs may diff --git a/tensorflow/core/framework/tensor_reference.cc b/tensorflow/core/framework/tensor_reference.cc new file mode 100644 index 0000000000..52e5592030 --- /dev/null +++ b/tensorflow/core/framework/tensor_reference.cc @@ -0,0 +1,10 @@ +#include "tensorflow/core/framework/tensor_reference.h" + +namespace tensorflow { + +TensorReference::TensorReference(const Tensor& tensor) + : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) { + if (buf_) buf_->Ref(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_reference.h b/tensorflow/core/framework/tensor_reference.h index 56b516e694..abffccadd6 100644 --- a/tensorflow/core/framework/tensor_reference.h +++ b/tensorflow/core/framework/tensor_reference.h @@ -31,10 +31,7 @@ namespace tensorflow { class TensorReference { public: // Take the reference of the root buffer so the size will be more accurate - explicit TensorReference(const Tensor& tensor) - : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) { - if (buf_) buf_->Ref(); - } + explicit TensorReference(const Tensor& tensor); ~TensorReference() {} diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index f8daf49d89..f608cc815d 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -109,6 +109,18 @@ void TensorShape::SlowCopyFrom(const TensorShape& b) { } } +int64 TensorShape::dim_size(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + if (tag() == REP16) { + return as16()->dims_[d]; + } else if (tag() == REP32) { + return as32()->dims_[d]; + } else { + return (*as64()->dims_)[d]; + } +} + void TensorShape::Clear() { ClearAllButDataType(); set_data_type(DT_INVALID); diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 8d4e728535..ab4979c902 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -87,24 +87,15 @@ class TensorShape { /// Return the number of dimensions in the tensor. int dims() const { - return (tag() == REP_OUT_OF_LINE) ? (*as64()->dims_).size() : ndims_byte(); + DCHECK(tag() != REP_OUT_OF_LINE || (*as64()->dims_).size() == ndims_byte()); + return ndims_byte(); } /// \brief Returns the number of elements in dimension `d`. /// REQUIRES: `0 <= d < dims()` // TODO(touts): Rename to `dimension()` to match // `Eigen::Tensor::dimension()`? - int64 dim_size(int d) const { - DCHECK_GE(d, 0); - DCHECK_LT(d, dims()); - if (tag() == REP16) { - return as16()->dims_[d]; - } else if (tag() == REP32) { - return as32()->dims_[d]; - } else { - return (*as64()->dims_)[d]; - } - } + int64 dim_size(int d) const; /// Returns sizes of all dimensions. gtl::InlinedVector<int64, 4> dim_sizes() const; diff --git a/tensorflow/core/framework/unique_tensor_references.cc b/tensorflow/core/framework/unique_tensor_references.cc new file mode 100644 index 0000000000..2901feb531 --- /dev/null +++ b/tensorflow/core/framework/unique_tensor_references.cc @@ -0,0 +1,76 @@ +#include "tensorflow/core/framework/unique_tensor_references.h" + +namespace tensorflow { + +UniqueTensorReferences::~UniqueTensorReferences() { + if (!frozen_) { + // The references were not retrieved so discard them to avoid + // leaking memory. + TensorReferenceVector refs; + FreezeAndReturnReferences(&refs); + for (auto& tensor : refs) { + tensor.Unref(); + } + } + delete referenced_tensors_set_; +} + +void UniqueTensorReferences::Add(const Tensor& tensor) { + DCHECK(!frozen_); + // Do nothing if the tensor has a null buffer. + if (tensor.IsInitialized()) { + if (referenced_tensors_set_ != nullptr) { + // There are enough tensors that we are using a hash set to + // de-duplicate. + const TensorReference tensor_ref(tensor); + if (!referenced_tensors_set_->insert(tensor_ref).second) { + // The tensor was a duplicate, so discard the reference. + tensor_ref.Unref(); + } + } else { + for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) { + if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) { + // tensor is a duplicate, so nothing to do. + return; + } + } + referenced_tensors_vector_.push_back(TensorReference(tensor)); + if (kInVector == referenced_tensors_vector_.size()) { + // There are too many tensors to keep using the N^2 algorithm + // so start de-duplicating using a set. + // Transfer the refs from the vector to the set. + DCHECK(referenced_tensors_set_ == nullptr); + referenced_tensors_set_ = new ReferencedTensorsSet; + referenced_tensors_set_->reserve(kInVector); + referenced_tensors_set_->insert(referenced_tensors_vector_.begin(), + referenced_tensors_vector_.end()); + DCHECK_EQ(kInVector, referenced_tensors_set_->size()); + referenced_tensors_vector_.clear(); + } + } + } +} + +void UniqueTensorReferences::FreezeAndReturnReferences( + TensorReferenceVector* out_vector) { + // Prevent any further additions. + frozen_ = true; + if (referenced_tensors_set_ != nullptr) { + DCHECK(referenced_tensors_vector_.empty()); + out_vector->reserve(referenced_tensors_set_->size()); + for (const auto& ref : *referenced_tensors_set_) { + out_vector->push_back(ref); + } + referenced_tensors_set_->clear(); + delete referenced_tensors_set_; + referenced_tensors_set_ = nullptr; + } else { + out_vector->reserve(referenced_tensors_vector_.size()); + for (const auto& ref : referenced_tensors_vector_) { + out_vector->push_back(ref); + } + referenced_tensors_vector_.clear(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/unique_tensor_references.h b/tensorflow/core/framework/unique_tensor_references.h index 2e08a9825f..f587348d9b 100644 --- a/tensorflow/core/framework/unique_tensor_references.h +++ b/tensorflow/core/framework/unique_tensor_references.h @@ -35,78 +35,14 @@ class UniqueTensorReferences { public: UniqueTensorReferences() : frozen_(false), referenced_tensors_set_(nullptr) {} - ~UniqueTensorReferences() { - if (!frozen_) { - // The references were not retrieved so discard them to avoid - // leaking memory. - TensorReferenceVector refs; - FreezeAndReturnReferences(&refs); - for (auto& tensor : refs) { - tensor.Unref(); - } - } - delete referenced_tensors_set_; - } + ~UniqueTensorReferences(); // Adds a reference to tensor if its buffer is not already referenced. - void Add(const Tensor& tensor) { - DCHECK(!frozen_); - // Do nothing if the tensor has a null buffer. - if (tensor.IsInitialized()) { - if (referenced_tensors_set_ != nullptr) { - // There are enough tensors that we are using a hash set to - // de-duplicate. - const TensorReference tensor_ref(tensor); - if (!referenced_tensors_set_->insert(tensor_ref).second) { - // The tensor was a duplicate, so discard the reference. - tensor_ref.Unref(); - } - } else { - for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) { - if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) { - // tensor is a duplicate, so nothing to do. - return; - } - } - referenced_tensors_vector_.push_back(TensorReference(tensor)); - if (kInVector == referenced_tensors_vector_.size()) { - // There are too many tensors to keep using the N^2 algorithm - // so start de-duplicating using a set. - // Transfer the refs from the vector to the set. - DCHECK(referenced_tensors_set_ == nullptr); - referenced_tensors_set_ = new ReferencedTensorsSet; - referenced_tensors_set_->reserve(kInVector); - referenced_tensors_set_->insert(referenced_tensors_vector_.begin(), - referenced_tensors_vector_.end()); - DCHECK_EQ(kInVector, referenced_tensors_set_->size()); - referenced_tensors_vector_.clear(); - } - } - } - } + void Add(const Tensor& tensor); // No more references may be added after this is called. The unique // references are returning in out_vector. - void FreezeAndReturnReferences(TensorReferenceVector* out_vector) { - // Prevent any further additions. - frozen_ = true; - if (referenced_tensors_set_ != nullptr) { - DCHECK(referenced_tensors_vector_.empty()); - out_vector->reserve(referenced_tensors_set_->size()); - for (const auto& ref : *referenced_tensors_set_) { - out_vector->push_back(ref); - } - referenced_tensors_set_->clear(); - delete referenced_tensors_set_; - referenced_tensors_set_ = nullptr; - } else { - out_vector->reserve(referenced_tensors_vector_.size()); - for (const auto& ref : referenced_tensors_vector_) { - out_vector->push_back(ref); - } - referenced_tensors_vector_.clear(); - } - } + void FreezeAndReturnReferences(TensorReferenceVector* out_vector); private: // Up to kInVector elements are stored in reference_tensors_vector_ diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index fbea8d03cd..e0a71a0856 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -21,6 +21,13 @@ limitations under the License. namespace tensorflow { +NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit) + : node(n), + error(false), + name(node != nullptr ? node->name() : (error = true, "")), + index(i), + dt(SafeGetOutput(node, i, &error)) {} + NodeBuilder::NodeBuilder(const string& name, const string& op_name, const OpRegistryInterface* op_registry) : def_builder_(name, op_name, op_registry) {} diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index c8b4056974..51f00f6449 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -48,12 +48,7 @@ class NodeBuilder { // ArraySlice. struct NodeOut { // For referencing an existing Node. - NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit) - : node(n), - error(false), - name(node != nullptr ? node->name() : (error = true, "")), - index(i), - dt(SafeGetOutput(node, i, &error)) {} + NodeOut(Node* n, int i = 0); // For referencing Nodes not in the graph being built. It is // useful when preparing a graph for ExtendSession or creating a diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 4e5073b8be..a590558492 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1074,6 +1074,7 @@ filegroup( "io.cc", "lrn_op.cc", "maxpooling_op.cc", + "reduction_ops_common.cc", "reduction_ops_max.cc", "reduction_ops_mean.cc", "reduction_ops_min.cc", diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc new file mode 100644 index 0000000000..50e0791872 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -0,0 +1,127 @@ +#include "tensorflow/core/kernels/reduction_ops_common.h" + +namespace tensorflow { + +TensorShape ReductionHelper::out_reshape() const { + TensorShape shape; + for (auto size : out_reshape_) shape.AddDim(size); + return shape; +} + +// The final output shape must be allocated with this shape. +TensorShape ReductionHelper::out_shape() const { + TensorShape shape; + for (auto size : out_shape_) shape.AddDim(size); + return shape; +} + +TensorShape ReductionHelper::shuffled_shape() { + const int dims = data_reshape_.size(); + TensorShape shape; + for (int i = reduce_first_axis_; i < dims; i += 2) { + shape.AddDim(data_reshape_[i]); + } + for (int i = !reduce_first_axis_; i < dims; i += 2) { + shape.AddDim(data_reshape_[i]); + } + return shape; +} + +gtl::InlinedVector<int32, 8> ReductionHelper::permutation() { + const int dims = data_reshape_.size(); + const int unreduced_dims = (dims + !reduce_first_axis_) / 2; + gtl::InlinedVector<int32, 8> perm(dims); + for (int i = 0; i < unreduced_dims; i++) { + perm[i] = 2 * i + reduce_first_axis_; + } + for (int i = unreduced_dims; i < dims; i++) { + perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_; + } + return perm; +} + +Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, + const bool keep_dims) { + // bitmap[i] indicates whether to reduce data along i-th axis. + gtl::InlinedVector<bool, 4> bitmap(data.dims(), false); + auto axis_vec = axis.flat<int32>(); + for (int64 i = 0; i < axis.NumElements(); ++i) { + const int32 index = axis_vec(i); + if (index < 0 || index >= data.dims()) { + return errors::OutOfRange("Invalid reduction dimension (", index, + " for input with ", data.dims(), + " dimension(s)"); + } + bitmap[index] = true; + } + + // Output tensor's dim sizes. + out_shape_.clear(); + for (int i = 0; i < data.dims(); ++i) { + if (!bitmap[i]) { + // If we are not reducing along dimension i. + out_shape_.push_back(data.dim_size(i)); + } else if (keep_dims) { + // We are reducing along dimension i, but we want to keep the + // same number of dimensions, so we set the dimension of i to + // '1'. + out_shape_.push_back(1); + } + } + + // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of + // the input data before doing the reduction on the resulting + // tensor. The shape of the reduction is a reshape of the final + // output. + + // We'll skip the leading 1s. + int dim_index = 0; + for (; dim_index < data.dims(); ++dim_index) { + if (data.dim_size(dim_index) != 1) break; + } + if (dim_index >= data.dims()) { + // Special case. The input is essentially a scalar. + reduce_first_axis_ = true; + } else { + // Starting from the (dim_index)-th dimension, dimensions + // alternates between runs that need to be reduced and runs that + // don't. + // + // NOTE: If a dimension has size 1, we group it as the current + // run so that we can minimize the number of runs. + // + // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1, + // 5] by axes = [1, 4], we should treat the tensor as a [6, 5] + // and reduce by axes = [1] (i.e., the output is shape [6]). + reduce_first_axis_ = bitmap[dim_index]; + data_reshape_.push_back(data.dim_size(dim_index)); + ++dim_index; + for (; dim_index < data.dims(); ++dim_index) { + const auto size = data.dim_size(dim_index); + if (size == 1) { + bitmap[dim_index] = bitmap[dim_index - 1]; + } + if (bitmap[dim_index - 1] != bitmap[dim_index]) { + // Starts a new run of reduce or !reduce. + data_reshape_.push_back(size); + } else { + // Continue a run of reduce or !reduce. + data_reshape_.back() *= size; + } + } + // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc + // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_, + // otherwise, data_reshape_[0, 2, 4, ...] is. + for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size(); + i += 2) { + out_reshape_.push_back(data_reshape_[i]); + } + } + + VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ","); + VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ","); + VLOG(1) << "out shape: " << str_util::Join(out_shape_, ","); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index 19cedb1332..31cf55c78b 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -68,95 +68,11 @@ struct Constants<CPUDevice> { }; #endif -namespace { - class ReductionHelper { public: ReductionHelper() : reduce_first_axis_(false) {} - Status Simplify(const Tensor& data, const Tensor& axis, - const bool keep_dims) { - // bitmap[i] indicates whether to reduce data along i-th axis. - gtl::InlinedVector<bool, 4> bitmap(data.dims(), false); - auto axis_vec = axis.flat<int32>(); - for (int64 i = 0; i < axis.NumElements(); ++i) { - const int32 index = axis_vec(i); - if (index < 0 || index >= data.dims()) { - return errors::OutOfRange("Invalid reduction dimension (", index, - " for input with ", data.dims(), - " dimension(s)"); - } - bitmap[index] = true; - } - - // Output tensor's dim sizes. - out_shape_.clear(); - for (int i = 0; i < data.dims(); ++i) { - if (!bitmap[i]) { - // If we are not reducing along dimension i. - out_shape_.push_back(data.dim_size(i)); - } else if (keep_dims) { - // We are reducing along dimension i, but we want to keep the - // same number of dimensions, so we set the dimension of i to - // '1'. - out_shape_.push_back(1); - } - } - - // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of - // the input data before doing the reduction on the resulting - // tensor. The shape of the reduction is a reshape of the final - // output. - - // We'll skip the leading 1s. - int dim_index = 0; - for (; dim_index < data.dims(); ++dim_index) { - if (data.dim_size(dim_index) != 1) break; - } - if (dim_index >= data.dims()) { - // Special case. The input is essentially a scalar. - reduce_first_axis_ = true; - } else { - // Starting from the (dim_index)-th dimension, dimensions - // alternates between runs that need to be reduced and runs that - // don't. - // - // NOTE: If a dimension has size 1, we group it as the current - // run so that we can minimize the number of runs. - // - // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1, - // 5] by axes = [1, 4], we should treat the tensor as a [6, 5] - // and reduce by axes = [1] (i.e., the output is shape [6]). - reduce_first_axis_ = bitmap[dim_index]; - data_reshape_.push_back(data.dim_size(dim_index)); - ++dim_index; - for (; dim_index < data.dims(); ++dim_index) { - const auto size = data.dim_size(dim_index); - if (size == 1) { - bitmap[dim_index] = bitmap[dim_index - 1]; - } - if (bitmap[dim_index - 1] != bitmap[dim_index]) { - // Starts a new run of reduce or !reduce. - data_reshape_.push_back(size); - } else { - // Continue a run of reduce or !reduce. - data_reshape_.back() *= size; - } - } - // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc - // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_, - // otherwise, data_reshape_[0, 2, 4, ...] is. - for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size(); - i += 2) { - out_reshape_.push_back(data_reshape_[i]); - } - } - - VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ","); - VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ","); - VLOG(1) << "out shape: " << str_util::Join(out_shape_, ","); - return Status::OK(); - } + Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims); // We need to do roughly: // tmp_out = allocate(out_reshape()) @@ -164,18 +80,10 @@ class ReductionHelper { // out = tmp_out.reshape(out_shape) // The reduction result must be allocated with this shape. - TensorShape out_reshape() const { - TensorShape shape; - for (auto size : out_reshape_) shape.AddDim(size); - return shape; - } + TensorShape out_reshape() const; // The final output shape must be allocated with this shape. - TensorShape out_shape() const { - TensorShape shape; - for (auto size : out_shape_) shape.AddDim(size); - return shape; - } + TensorShape out_shape() const; // The reduction is on a reshaped tensor of this rank. int ndims() const { return data_reshape_.size(); } @@ -203,31 +111,10 @@ class ReductionHelper { } // Shape with all reduction dimensions at the end - TensorShape shuffled_shape() { - const int dims = data_reshape_.size(); - TensorShape shape; - for (int i = reduce_first_axis_; i < dims; i += 2) { - shape.AddDim(data_reshape_[i]); - } - for (int i = !reduce_first_axis_; i < dims; i += 2) { - shape.AddDim(data_reshape_[i]); - } - return shape; - } + TensorShape shuffled_shape(); // Permutation of reduced dims needed to put reduction dimensions at the end - gtl::InlinedVector<int32, 8> permutation() { - const int dims = data_reshape_.size(); - const int unreduced_dims = (dims + !reduce_first_axis_) / 2; - gtl::InlinedVector<int32, 8> perm(dims); - for (int i = 0; i < unreduced_dims; i++) { - perm[i] = 2 * i + reduce_first_axis_; - } - for (int i = unreduced_dims; i < dims; i++) { - perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_; - } - return perm; - } + gtl::InlinedVector<int32, 8> permutation(); private: bool reduce_first_axis_; // True if need to reduce the 0-th dimension. @@ -236,8 +123,6 @@ class ReductionHelper { gtl::InlinedVector<int64, 4> out_reshape_; // Reshape output for reduction. }; -} // end namespace - // For operations where the output is a reduction function along some // dimensions of the input. template <typename Device, class T, typename Reducer> |