diff options
author | 2016-01-19 09:39:34 -0800 | |
---|---|---|
committer | 2016-01-20 07:47:54 -0800 | |
commit | e39629219e748b08177f2c457ba45d51f5370aae (patch) | |
tree | b0861c366bbbdf8330bb1862aa5a32251c012e65 /tensorflow/core/kernels | |
parent | f592f23775e2a6ac75496829db5005d3bb70a3d2 (diff) |
PaddingFIFOQueue is like FIFOQueue but allows dynamic shapes (using padding with DequeueMany)
Change: 112482056
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/fifo_queue.h | 3 | ||||
-rw-r--r-- | tensorflow/core/kernels/fifo_queue_op.cc | 59 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue.cc | 370 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue.h | 88 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue_op.cc | 68 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_base.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_base.h | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_op.h | 105 | ||||
-rw-r--r-- | tensorflow/core/kernels/random_shuffle_queue_op.cc | 57 |
9 files changed, 659 insertions, 112 deletions
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h index c8406e93bb..35ac38777e 100644 --- a/tensorflow/core/kernels/fifo_queue.h +++ b/tensorflow/core/kernels/fifo_queue.h @@ -52,7 +52,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > { return queues_[0].size(); } - private: + protected: ~FIFOQueue() override {} // Helper for dequeuing a single element from queues_. @@ -64,6 +64,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > { OpKernelContext* ctx, PersistentTensor* out_element); + private: TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue); }; diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc index 663250359c..a43c17637c 100644 --- a/tensorflow/core/kernels/fifo_queue_op.cc +++ b/tensorflow/core/kernels/fifo_queue_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -38,70 +39,24 @@ namespace tensorflow { // backed by FIFOQueue) that persists across different graph // executions, and sessions. Running this op produces a single-element // tensor of handles to Queues in the corresponding device. -class FIFOQueueOp : public OpKernel { +class FIFOQueueOp : public QueueOp { public: - explicit FIFOQueueOp(OpKernelConstruction* context) - : OpKernel(context), queue_handle_set_(false) { - OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); - OP_REQUIRES_OK(context, - context->allocate_persistent(DT_STRING, TensorShape({2}), - &queue_handle_, nullptr)); - if (capacity_ < 0) { - capacity_ = FIFOQueue::kUnbounded; - } - OP_REQUIRES_OK(context, - context->GetAttr("component_types", &component_types_)); + explicit FIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); } - ~FIFOQueueOp() override { - // If the queue object was not shared, delete it. - if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>( - cinfo_.container(), cinfo_.name())); - } - } - - void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - if (!queue_handle_set_) { - OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); - } - ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); - } - - private: - Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); - QueueInterface* queue; - auto creator = [this](QueueInterface** ret) { + protected: + CreatorCallback GetCreator() const override { + return [this](QueueInterface** ret) { FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, component_shapes_, cinfo_.name()); *ret = queue; return queue->Initialize(); }; - TF_RETURN_IF_ERROR( - cinfo_.resource_manager()->LookupOrCreate<QueueInterface>( - cinfo_.container(), cinfo_.name(), &queue, creator)); - core::ScopedUnref unref_me(queue); - // Verify that the shared queue is compatible with the requested arguments. - TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); - auto h = queue_handle_.AccessTensor(ctx)->flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - queue_handle_set_ = true; - return Status::OK(); } - int32 capacity_; - DataTypeVector component_types_; + private: std::vector<TensorShape> component_shapes_; - ContainerInfo cinfo_; - - mutex mu_; - PersistentTensor queue_handle_ GUARDED_BY(mu_); - bool queue_handle_set_ GUARDED_BY(mu_); - TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp); }; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc new file mode 100644 index 0000000000..d2f8c06fdb --- /dev/null +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -0,0 +1,370 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/data_flow_ops.cc. + +#include <deque> +#include <vector> + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/padding_fifo_queue.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +PaddingFIFOQueue::PaddingFIFOQueue( + int capacity, const DataTypeVector& component_dtypes, + const std::vector<PartialTensorShape>& partial_shapes, const string& name) + : FIFOQueue(capacity, component_dtypes, + ConvertShapesPartialDimensionsToZero(partial_shapes), name), + partial_shapes_(partial_shapes) {} + +Status PaddingFIFOQueue::Initialize() { + Status s = FIFOQueue::Initialize(); + if (!s.ok()) return s; + + if (component_dtypes_.size() != partial_shapes_.size()) { + return errors::InvalidArgument( + "Shapes must be provided for all components, but received ", + component_dtypes_.size(), " dtypes and ", partial_shapes_.size(), + " shapes."); + } + + return Status::OK(); +} + +/* static */ +Status PaddingFIFOQueue::GetElementComponent( + const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx, + PersistentTensor* out_tensor) { + TensorShape element_shape(tuple[component].shape()); + Tensor* element_access = nullptr; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + tuple[component].dtype(), element_shape, out_tensor, &element_access)); + *element_access = tuple[component]; + return Status::OK(); +} + +void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) { + if (num_elements == 0) { + Tuple tuple; + tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + // TODO(josh11b,misard): Switch to allocate_output(). + // See similar comment in fifo_queue.cc + Tensor element; + // Here, ManyOutShape returns zeros for undetermined shapes, + // which is exactly what we want to use. + ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element); + tuple.emplace_back(element); + } + callback(tuple); + return; + } + + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled; + { + mutex_lock l(mu_); + already_cancelled = !cm->RegisterCallback( + token, [this, token]() { Cancel(kDequeue, token); }); + if (!already_cancelled) { + // TODO(josh11b): This makes two copies of callback, avoid this if possible. + dequeue_attempts_.emplace_back( + num_elements, [callback]() { callback(Tuple()); }, ctx, token, + [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int32 s = queues_[0].size(); + if (closed_ && s < attempt->elements_requested) { + attempt->context->SetStatus(errors::OutOfRange( + "PaddingFIFOQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", + attempt->elements_requested, ", current size ", s, ")")); + + // TODO(mrry): Add support for producing a partial batch as + // output when the queue is closed. + if (!attempt->tuples.empty()) { + // Restore already-dequeued elements to the front of the queue. + for (int64 i = attempt->tuples.size() - 1; i >= 0; --i) { + for (int j = 0; j < num_components(); ++j) { + PersistentTensor element; + Status s = GetElementComponent(attempt->tuples[i], j, + attempt->context, &element); + if (!s.ok()) { + attempt->context->SetStatus( + errors::DataLoss("Failed to restore element from " + "partially-dequeued batch " + "to PaddingFIFOQueue: ", + s.error_message())); + } + queues_[j].push_front(element); + } + } + } + return kComplete; + } + + RunResult result = kNoProgress; + for (; s > 0; --s) { + result = kProgress; + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + attempt->tuples.push_back(tuple); + tuple.clear(); + --attempt->elements_requested; + + if (attempt->elements_requested == 0) { + // Finished. Allocate attempt->tuple and + // copy from attempt->tuples to attempt->tuple. + attempt->tuple.reserve(num_components()); + const std::vector<Tuple>& tuples = attempt->tuples; + + std::vector<bool> dynamic_shape; + const int64 batch_size = tuples.size(); + + for (int i = 0; i < num_components(); ++i) { + const PartialTensorShape partial_shape = + PartialTensorShape({batch_size}) + .Concatenate(partial_shapes_[i]); + TensorShape shape({batch_size}); + + for (int j = 0; j < partial_shape.dims() - 1; ++j) { + if (partial_shape.dim_size(j + 1) > -1) { + shape.AddDim(partial_shape.dim_size(j + 1)); + } else { + // Expand sizes to match. + int64 max_val = 0; + for (const Tuple& t : tuples) { + max_val = max(max_val, t[i].shape().dim_size(j)); + } + shape.AddDim(max_val); + } + } + + Tensor element; + attempt->context->allocate_temp(component_dtypes_[i], shape, + &element); + + bool has_dynamic_shape = !partial_shape.IsFullyDefined(); + if (has_dynamic_shape) { + // Set all values to zero because not all values + // will get written over. + attempt->context->SetStatus(SetElementZero(&element)); + if (!attempt->context->status().ok()) return kComplete; + } + + dynamic_shape.push_back(has_dynamic_shape); + + // TODO(ebrevdo): should this be a persistent tensor? + attempt->tuple.emplace_back(element); + } + + for (int index = 0; index < tuples.size(); ++index) { + for (int i = 0; i < num_components(); ++i) { + if (dynamic_shape[i]) { + // Slightly slower copy operation + attempt->context->SetStatus(CopyElementToLargerSlice( + tuples[index][i], &attempt->tuple[i], index)); + } else { + attempt->context->SetStatus(CopyElementToSlice( + tuples[index][i], &attempt->tuple[i], index)); + } + if (!attempt->context->status().ok()) return kComplete; + } + } + tuple = attempt->tuple; + attempt->tuples.clear(); + attempt->done_callback = [callback, tuple]() { + callback(tuple); + }; + return kComplete; + } + } + return result; + }); + } + } + if (!already_cancelled) { + FlushUnlocked(); + } else { + ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled")); + callback(Tuple()); + } +} + +Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + for (size_t i = 0; i < tuple.size(); ++i) { + if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) { + return errors::InvalidArgument("Shape mismatch in tuple component ", i, + ". Expected ", + partial_shapes_[i].DebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + return Status::OK(); +} + +Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + const int64 batch_size = tuple[0].dim_size(0); + for (size_t i = 0; i < tuple.size(); ++i) { + // Expected shape is [batch_size] + partial_shapes_[i] + const PartialTensorShape expected_shape = + PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]); + if (!expected_shape.IsCompatibleWith(tuple[i].shape())) { + return errors::InvalidArgument("Shape mismatch in tuple component ", i, + ". Expected ", + expected_shape.DebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + return Status::OK(); +} + +Status PaddingFIFOQueue::CompatibleNodeDefShapes( + const NodeDef& node_def) const { + std::vector<PartialTensorShape> requested_shapes; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes)); + if (!PartialTensorShapeUtils::AreCompatible(requested_shapes, + partial_shapes_)) { + return errors::InvalidArgument( + "Shared queue '", name_, "' has component shapes ", + PartialTensorShapeUtils::PartialShapeListString(partial_shapes_), + " but requested component shapes were ", + PartialTensorShapeUtils::PartialShapeListString(requested_shapes)); + } else { + return Status::OK(); + } +} + +Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) { + TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "PaddingFIFOQueue")); + TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); + TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); + TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def)); + return Status::OK(); +} + +template <typename T, int NDIMS> +Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { + DCHECK_NE(parent->dim_size(0), 0); + if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { + TensorShape chip_shape = parent->shape(); + chip_shape.RemoveDim(0); + return errors::Internal( + "HandleElementToLargerSlice Cannot copy slice: number of entries in " + "element is greater than number of elements in parent slice. ", + "Shapes are: [element]: ", element.shape().DebugString(), + ", [parent slice]: ", chip_shape.DebugString()); + } + auto element_t = element.tensor<T, NDIMS>(); + auto parent_t = parent->tensor<T, NDIMS + 1>(); + Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices; + slice_indices[0] = index; + Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size; + slice_size[0] = 1; + for (int i = 1; i < slice_size.size(); ++i) { + slice_size[i] = element_t.dimension(i - 1); + } + parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); + return Status::OK(); +} + +namespace { + +template <int NDIMS> +Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, + int index) { +#define HANDLE_TYPE(T) \ + case DataTypeToEnum<T>::value: { \ + return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \ + } + + switch (element.dtype()) { + TF_CALL_ALL_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + default: + return errors::Unimplemented( + "HandleElementToLargerSliceWithRank Unhandled data type: ", + element.dtype()); + } +} + +} // namespace + +Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element, + Tensor* parent, int index) { + if (parent->dims() != element.dims() + 1) { + return errors::Internal( + "Mismatched ranks. Element's rank is: ", element.dims(), + " but element is meant to be a slice in output Tensor having rank: ", + parent->dims(), " (should be: ", element.dims() + 1, ")"); + } + +#define HANDLE_DIMS(NDIMS) \ + case NDIMS: { \ + TF_RETURN_IF_ERROR( \ + HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \ + return Status::OK(); \ + } + + switch (element.dims()) { + HANDLE_DIMS(0); + HANDLE_DIMS(1); + HANDLE_DIMS(2); + HANDLE_DIMS(3); + HANDLE_DIMS(4); +#undef HANDLE_DIMS + default: + return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", + element.dims()); + } +} + +// Static method +Status PaddingFIFOQueue::SetElementZero(Tensor* element) { +#define HANDLE_TYPE(T) \ + if (element->dtype() == DataTypeToEnum<T>::value) { \ + element->flat<T>().setConstant(T()); \ + return Status::OK(); \ + } + TF_CALL_ALL_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + return errors::Unimplemented("SetElementZero Unhandled data type: ", + element->dtype()); +} + +std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero( + const gtl::ArraySlice<PartialTensorShape>& partial_shapes) { + std::vector<TensorShape> shapes(partial_shapes.size()); + for (int i = 0; i < shapes.size(); ++i) { + const PartialTensorShape& partial = partial_shapes[i]; + TensorShape& shape = shapes[i]; + for (int64 s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s); + } + return shapes; +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h new file mode 100644 index 0000000000..afcbeea7e8 --- /dev/null +++ b/tensorflow/core/kernels/padding_fifo_queue.h @@ -0,0 +1,88 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_ +#define TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_ + +#include <deque> +#include <vector> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/typed_queue.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/partial_tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class PaddingFIFOQueue : public FIFOQueue { + public: + PaddingFIFOQueue(int32 capacity, const DataTypeVector& component_dtypes, + const std::vector<PartialTensorShape>& component_shapes, + const string& name); + + Status Initialize() override; + + // Implementations of QueueInterface methods -------------------------------- + + void TryDequeueMany(int num_elements, OpKernelContext* ctx, + CallbackWithTuple callback) override; + Status MatchesNodeDef(const NodeDef& node_def) override; + + protected: + Status ValidateManyTuple(const Tuple& tuple) override; + Status ValidateTuple(const Tuple& tuple) override; + Status CompatibleNodeDefShapes(const NodeDef& node_def) const; + + // Convert a list of PartialTensorShape to a list of + // TensorShape. + // Any unknown dimension sizes are converted to 0. + // REQUIRED: All the input shapes have well defined rank. + static std::vector<TensorShape> ConvertShapesPartialDimensionsToZero( + const gtl::ArraySlice<PartialTensorShape>& partial_shapes); + + // Sets the values in the given element to zero. + static Status SetElementZero(Tensor* element); + + // Copies element into the index^th slice (in the first dimension) + // of parent. Allows for the parent's slice to have a larger size + // than the element, and copies the element into the upper left hand + // corner of the slice. + static Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index); + + std::vector<PartialTensorShape> partial_shapes_; + + private: + ~PaddingFIFOQueue() override {} + + static Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple, + int component, OpKernelContext* ctx, + PersistentTensor* out_tensor); + + static Status IsSameSizeExceptZerosInFirst(const TensorShape& first, + const TensorShape& second); + + TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueue); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_ diff --git a/tensorflow/core/kernels/padding_fifo_queue_op.cc b/tensorflow/core/kernels/padding_fifo_queue_op.cc new file mode 100644 index 0000000000..08347bcf63 --- /dev/null +++ b/tensorflow/core/kernels/padding_fifo_queue_op.cc @@ -0,0 +1,68 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/data_flow_ops.cc. + +#include <deque> +#include <vector> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/padding_fifo_queue.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/kernels/queue_op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/partial_tensor_shape.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Defines a PaddingFIFOQueueOp, which produces a Queue (specifically, one +// backed by PaddingFIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +class PaddingFIFOQueueOp : public QueueOp { + public: + explicit PaddingFIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); + } + + protected: + CreatorCallback GetCreator() const override { + return [this](QueueInterface** ret) { + PaddingFIFOQueue* queue = new PaddingFIFOQueue( + capacity_, component_types_, component_shapes_, cinfo_.name()); + *ret = queue; + return queue->Initialize(); + }; + } + + private: + std::vector<PartialTensorShape> component_shapes_; + + TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("PaddingFIFOQueue").Device(DEVICE_CPU), + PaddingFIFOQueueOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 590e4e9123..b33af06408 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -32,8 +32,8 @@ Status HandleSliceToElement(const Tensor& parent, Tensor* element, int index) { TensorShape chip_shape = parent.shape(); chip_shape.RemoveDim(0); return errors::Internal( - "Cannot copy slice: number of elements does not match. Shapes are: " - "[element]: ", + "HandleSliceToElement Cannot copy slice: number of elements does not " + "match. Shapes are: [element]: ", element->shape().DebugString(), ", [parent slice]: ", chip_shape.DebugString()); } @@ -50,8 +50,8 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) { TensorShape chip_shape = parent->shape(); chip_shape.RemoveDim(0); return errors::Internal( - "Cannot copy slice: number of elements does not match. Shapes are: " - "[element]: ", + "HandleElementToSlice Cannot copy slice: number of elements does not " + "match. Shapes are: [element]: ", element.shape().DebugString(), ", [parent slice]: ", chip_shape.DebugString()); } @@ -156,7 +156,7 @@ Status QueueBase::ValidateTuple(const Tuple& tuple) { TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); if (specified_shapes()) { for (size_t i = 0; i < tuple.size(); ++i) { - if (!tuple[i].shape().IsSameSize(component_shapes_[i])) { + if (!component_shapes_[i].IsSameSize(tuple[i].shape())) { return errors::InvalidArgument( "Shape mismatch in tuple component ", i, ". Expected ", component_shapes_[i].ShortDebugString(), ", got ", @@ -176,7 +176,7 @@ Status QueueBase::ValidateManyTuple(const Tuple& tuple) { for (size_t i = 0; i < tuple.size(); ++i) { // Expected shape is [batch_size] + component_shapes_[i] const TensorShape expected_shape = ManyOutShape(i, batch_size); - if (!tuple[i].shape().IsSameSize(expected_shape)) { + if (!expected_shape.IsSameSize(tuple[i].shape())) { return errors::InvalidArgument( "Shape mismatch in tuple component ", i, ". Expected ", expected_shape.ShortDebugString(), ", got ", @@ -331,7 +331,6 @@ void QueueBase::FlushUnlocked() { } } -// Static method Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, int index) { #define HANDLE_TYPE(DT) \ @@ -355,7 +354,8 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, HANDLE_TYPE(DT_QINT16); HANDLE_TYPE(DT_QUINT16); #undef HANDLE_TYPE - return errors::Unimplemented("Unhandled data type: ", parent.dtype()); + return errors::Unimplemented("CopySliceToElement Unhandled data type: ", + parent.dtype()); } // Static method @@ -382,7 +382,8 @@ Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, HANDLE_TYPE(DT_QINT16); HANDLE_TYPE(DT_QUINT16); #undef HANDLE_TYPE - return errors::Unimplemented("Unhandled data type: ", element.dtype()); + return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", + element.dtype()); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h index 10a6d65417..1c0e377cda 100644 --- a/tensorflow/core/kernels/queue_base.h +++ b/tensorflow/core/kernels/queue_base.h @@ -146,6 +146,8 @@ class QueueBase : public QueueInterface { RunCallback run_callback; // must be run while holding mu_ bool is_cancelled; Tuple tuple; + // tuples is used by some implementations allowing dynamic shapes. + std::vector<Tuple> tuples; Attempt(int32 elements_requested, DoneCallback done_callback, OpKernelContext* context, CancellationToken cancellation_token, diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h new file mode 100644 index 0000000000..a549439220 --- /dev/null +++ b/tensorflow/core/kernels/queue_op.h @@ -0,0 +1,105 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_ +#define TENSORFLOW_KERNELS_QUEUE_OP_H_ + +#include <deque> +#include <vector> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +// Defines a QueueOp, an abstract class for Queue construction ops. +class QueueOp : public OpKernel { + public: + QueueOp(OpKernelConstruction* context) + : OpKernel(context), queue_handle_set_(false) { + OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &queue_handle_, nullptr)); + if (capacity_ < 0) { + capacity_ = QueueBase::kUnbounded; + } + OP_REQUIRES_OK(context, + context->GetAttr("component_types", &component_types_)); + } + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + if (!queue_handle_set_) { + OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); + } + ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); + } + + protected: + ~QueueOp() override { + // If the queue object was not shared, delete it. + if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>( + cinfo_.container(), cinfo_.name())); + } + } + + protected: + typedef std::function<Status(QueueInterface**)> CreatorCallback; + + // Subclasses must override this + virtual CreatorCallback GetCreator() const = 0; + + // Variables accessible by subclasses + int32 capacity_; + DataTypeVector component_types_; + ContainerInfo cinfo_; + + private: + Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); + CreatorCallback creator = GetCreator(); + QueueInterface* queue; + TF_RETURN_IF_ERROR( + cinfo_.resource_manager()->LookupOrCreate<QueueInterface>( + cinfo_.container(), cinfo_.name(), &queue, creator)); + core::ScopedUnref unref_me(queue); + // Verify that the shared queue is compatible with the requested arguments. + TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); + auto h = queue_handle_.AccessTensor(ctx)->flat<string>(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + queue_handle_set_ = true; + return Status::OK(); + } + + mutex mu_; + PersistentTensor queue_handle_ GUARDED_BY(mu_); + bool queue_handle_set_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_ diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index da391224af..e9b6ead381 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/typed_queue.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/random/philox_random.h" @@ -404,17 +405,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { // backed by RandomShuffleQueue) that persists across different graph // executions, and sessions. Running this op produces a single-element // tensor of handles to Queues in the corresponding device. -class RandomShuffleQueueOp : public OpKernel { +class RandomShuffleQueueOp : public QueueOp { public: explicit RandomShuffleQueueOp(OpKernelConstruction* context) - : OpKernel(context), queue_handle_set_(false) { - OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); - OP_REQUIRES_OK(context, - context->allocate_persistent(DT_STRING, TensorShape({2}), - &queue_handle_, nullptr)); - if (capacity_ < 0) { - capacity_ = RandomShuffleQueue::kUnbounded; - } + : QueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("min_after_dequeue", &min_after_dequeue_)); OP_REQUIRES(context, min_after_dequeue_ >= 0, @@ -427,32 +421,12 @@ class RandomShuffleQueueOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); - OP_REQUIRES_OK(context, - context->GetAttr("component_types", &component_types_)); OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); } - ~RandomShuffleQueueOp() override { - // If the queue object was not shared, delete it. - if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>( - cinfo_.container(), cinfo_.name())); - } - } - - void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - if (!queue_handle_set_) { - OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); - } - ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); - } - - private: - Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); - QueueInterface* queue; - auto creator = [this](QueueInterface** ret) { + protected: + CreatorCallback GetCreator() const override { + return [this](QueueInterface** ret) { auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_, seed2_, component_types_, component_shapes_, cinfo_.name()); @@ -464,30 +438,13 @@ class RandomShuffleQueueOp : public OpKernel { } return s; }; - TF_RETURN_IF_ERROR( - cinfo_.resource_manager()->LookupOrCreate<QueueInterface>( - cinfo_.container(), cinfo_.name(), &queue, creator)); - core::ScopedUnref unref_me(queue); - // Verify that the shared queue is compatible with the requested arguments. - TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); - auto h = queue_handle_.AccessTensor(ctx)->flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - queue_handle_set_ = true; - return Status::OK(); } - int32 capacity_; + private: int32 min_after_dequeue_; int64 seed_; int64 seed2_; - DataTypeVector component_types_; std::vector<TensorShape> component_shapes_; - ContainerInfo cinfo_; - - mutex mu_; - PersistentTensor queue_handle_ GUARDED_BY(mu_); - bool queue_handle_set_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp); }; |