/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/batch_util.h"

#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

// Element types accepted by the padded-copy helpers in this file
// (CopyElementToLargerSlice and SetElementZero).
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {
namespace batch_util {

namespace {

// Checks that `element` can be copied to or from slice `index` of `parent`
// along dimension 0.
//
// Returns an Internal error if the dtypes differ, if `parent` is a scalar, if
// `index` is not in [0, parent.dim_size(0)), or if `element` does not hold
// exactly as many values as one slice of `parent`. Only element counts are
// compared, not shapes, because the copy helpers work on flattened views.
Status ValidateInput(const Tensor& parent, const Tensor& element,
                     int64 index) {
  // The copy helpers view both tensors through flat<T>() of a single dtype,
  // so the dtypes must agree.
  if (parent.dtype() != element.dtype()) {
    return errors::Internal(
        "ValidateInput Cannot perform copy: dtypes do not match. [element]: ",
        DataTypeString(element.dtype()),
        ", [parent]: ", DataTypeString(parent.dtype()));
  }
  // These are real checks rather than DCHECKs: in optimized builds a
  // violation would divide by zero below or index past the end of `parent`
  // in the copy helpers.
  if (parent.dims() == 0) {
    return errors::Internal(
        "ValidateInput Cannot perform copy: parent must have rank >= 1, but "
        "has shape: ",
        parent.shape().DebugString());
  }
  if (index < 0 || index >= parent.dim_size(0)) {
    return errors::Internal("ValidateInput Cannot perform copy: index ", index,
                            " is out of range for parent with shape: ",
                            parent.shape().DebugString());
  }
  if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
    TensorShape chip_shape = parent.shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "ValidateInput Cannot perform copy: number of elements does not match. "
        " Shapes are: [element]: ",
        element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

// Copies `element` into slice `index` (along dimension 0) of `parent`.
//
// `can_move` may only be true when nothing else observes `element`'s buffer
// (CopyElementToSlice passes element.RefCountIsOne()). In that case, for
// DT_STRING and DT_VARIANT values, which can be costly to copy, each value is
// moved out of `element`. Every other case uses a single Eigen assignment.
template <typename T>
Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index,
                            bool can_move) {
  constexpr bool kMoveValues = DataTypeToEnum<T>::value == DT_STRING ||
                               DataTypeToEnum<T>::value == DT_VARIANT;
  auto parent_as_matrix = parent->flat_outer_dims<T>();
  auto element_flat = element.flat<T>();
  if (kMoveValues && can_move) {
    for (int64 i = 0; i < element.NumElements(); ++i) {
      parent_as_matrix(index, i) = std::move(element_flat(i));
    }
  } else {
    parent_as_matrix.chip(index, 0) = element_flat;
  }
  return Status::OK();
}

// TODO(b/78245576): Consider removing this overload.
template void HandleSliceToElement(const Tensor& parent, Tensor* element, int64 index) { element->flat() = parent.flat_outer_dims().chip(index, 0); } template void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index, bool can_move) { element->flat() = parent->flat_outer_dims().chip(index, 0); } template <> void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index, bool can_move) { auto parent_as_matrix = parent->flat_outer_dims(); auto element_flat = element->flat(); if (can_move) { for (int64 i = 0; i < element->NumElements(); ++i) { element_flat(i) = std::move(parent_as_matrix(index, i)); } } else { element_flat = parent_as_matrix.chip(index, 0); } } template <> void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index, bool can_move) { auto parent_as_matrix = parent->flat_outer_dims(); auto element_flat = element->flat(); if (can_move) { for (int64 i = 0; i < element->NumElements(); ++i) { element_flat(i) = std::move(parent_as_matrix(index, i)); } } else { element_flat = parent_as_matrix.chip(index, 0); } } } // namespace // Copies element into the index^th slice of parent (in the 0th dimension). Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index)); bool can_move = element.RefCountIsOne(); #define HANDLE_TYPE(T) \ case DataTypeToEnum::value: { \ return HandleElementToSlice(std::move(element), parent, index, \ can_move); \ } switch (element.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); TF_CALL_uint32(HANDLE_TYPE); TF_CALL_uint64(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", element.dtype()); } } // Copies the index^th slice of parent (in the 0th dimension) into element. 
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index)); #define HANDLE_TYPE(T) \ case DataTypeToEnum::value: { \ HandleSliceToElement(parent, element, index); \ return Status::OK(); \ } switch (parent.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopySliceToElement Unhandled data type: ", element->dtype()); } } // Copies the index^th slice of parent (in the 0th dimension) into element. // // NOTE(mrry): The implementation may be able to optimize the copy to a move. // This is particularly important for DT_STRING tensors. Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) { TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index)); bool can_move = parent->RefCountIsOne(); #define HANDLE_TYPE(T) \ case DataTypeToEnum::value: { \ HandleSliceToElement(parent, element, index, can_move); \ return Status::OK(); \ } switch (parent->dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented( "MaybeMoveSliceToElement Unhandled data type: ", element->dtype()); } } // The following five functions are copied from padding_fifo_queue.cc. // TODO(mrry): Reconcile these functions with the similar methods in the // queue implementation. Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { DCHECK_NE(parent->dim_size(0), 0); if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { TensorShape chip_shape = parent->shape(); chip_shape.RemoveDim(0); return errors::Internal( "HandleElementToLargerSlice Cannot copy slice: number of entries in " "element is greater than number of elements in parent slice. 
", "Shapes are: [element]: ", element.shape().DebugString(), ", [parent slice]: ", chip_shape.DebugString()); } return Status::OK(); } template Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, int index) { TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); if (element.NumElements() == 0) { return Status::OK(); } auto element_t = element.tensor(); auto parent_t = parent->tensor(); Eigen::DSizes slice_indices; slice_indices[0] = index; Eigen::DSizes slice_size; slice_size[0] = 1; for (size_t i = 1; i < slice_size.size(); ++i) { slice_size[i] = element_t.dimension(i - 1); } parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); return Status::OK(); } template Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, int index) { #define HANDLE_TYPE(T) \ case DataTypeToEnum::value: { \ return HandleElementToLargerSlice(element, parent, index); \ } switch (element.dtype()) { TF_CALL_DATASET_TYPES(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented( "HandleElementToLargerSliceWithRank Unhandled data type: ", element.dtype()); } } Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, int index) { if (parent->dims() != element.dims() + 1) { return errors::Internal( "Mismatched ranks. 
Element's rank is: ", element.dims(), " but element is meant to be a slice in output Tensor having rank: ", parent->dims(), " (should be: ", element.dims() + 1, ")"); } #define HANDLE_DIMS(NDIMS) \ case NDIMS: { \ TF_RETURN_IF_ERROR( \ HandleElementToLargerSliceWithRank(element, parent, index)); \ return Status::OK(); \ } switch (element.dims()) { HANDLE_DIMS(0); HANDLE_DIMS(1); HANDLE_DIMS(2); HANDLE_DIMS(3); HANDLE_DIMS(4); HANDLE_DIMS(5); #undef HANDLE_DIMS default: return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", element.dims()); } } Status SetElementZero(Tensor* element, const Tensor& padding) { #define HANDLE_TYPE(T) \ if (element->dtype() == DataTypeToEnum::value) { \ element->flat().setConstant(padding.scalar()()); \ return Status::OK(); \ } TF_CALL_DATASET_TYPES(HANDLE_TYPE); #undef HANDLE_TYPE return errors::Unimplemented("SetElementZero Unhandled data type: ", element->dtype()); } } // namespace batch_util } // namespace tensorflow