| author | 2018-04-12 10:35:41 -0700 |
| --- | --- |
| committer | 2018-04-12 10:38:25 -0700 |
| commit | 844b8cae970d835850a75f8063324224b2de0df0 (patch) |
| tree | 73cc881d4a3616c8888d9eb64530d6a55a87d01f /tensorflow/core/kernels/list_kernels.h |
| parent | ffbf77de81d0b7b4b169c92d0d9fbbdef5b8842a (diff) |
[TF] Add TensorListPushBackBatch.
Also modify the code so that aliased forwarding happens whenever possible
with DT_VARIANT objects, both in ResourceVariables and in the new op.
PiperOrigin-RevId: 192632202
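
The new kernel is templated on device and element type, and its registrations live outside this header (the diff below adds only the kernel itself). As a rough sketch of how a per-dtype CPU registration in `list_kernels.cc` could look (the macro name and the dtype list are illustrative assumptions; `REGISTER_KERNEL_BUILDER` and the builder calls are standard TensorFlow):

```cpp
// Hypothetical registration sketch; not part of this diff. Only
// REGISTER_KERNEL_BUILDER, Name(...).TypeConstraint(...).Device(...), and
// the op name "TensorListPushBackBatch" follow TensorFlow conventions; the
// macro name and the dtype list below are assumptions for illustration.
#define REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(T)               \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")         \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU),                \
                          TensorListPushBackBatch<CPUDevice, T>)

REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(float)
REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(double)

#undef REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU
```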
Diffstat (limited to 'tensorflow/core/kernels/list_kernels.h')
| -rw-r--r-- | tensorflow/core/kernels/list_kernels.h | 121 |

1 file changed, 121 insertions, 0 deletions
```diff
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index f3bbf3b6e3..42871c6113 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -34,6 +34,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
 // Variant compatible type for a list of tensors. This is mutable but instances
 // should never be mutated after stored in a variant tensor.
 struct TensorList {
@@ -146,6 +148,10 @@ class TensorListFromTensor : public OpKernel {
     TensorList output_list;
     const Tensor& t = c->input(0);
     output_list.element_dtype = t.dtype();
+    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+                errors::InvalidArgument(
+                    "Tensor must be at least a vector, but saw shape: ",
+                    t.shape().DebugString()));
     TensorShape output_shape(t.shape());
     output_shape.RemoveDim(0);
     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
@@ -267,6 +273,121 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
   return Status::OK();
 }
 
+template <typename Device, typename T>
+class TensorListPushBackBatch : public OpKernel {
+ public:
+  explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
+    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+  }
+
+  ~TensorListPushBackBatch() override {}
+
+  void Compute(OpKernelContext* c) override {
+    const Tensor& input = c->input(1);
+    OP_REQUIRES(c, element_dtype_ == input.dtype(),
+                errors::InvalidArgument("Invalid data types; list elements ",
+                                        DataTypeString(element_dtype_),
+                                        " but tried to append ",
+                                        DataTypeString(input.dtype())));
+    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+                errors::InvalidArgument(
+                    "Expected tensor to be at least a vector, but saw shape: ",
+                    input.shape().DebugString()));
+
+    const TensorShape& tls_shape = c->input(0).shape();
+
+    // For purposes of input forwarding, we want the least restrictive
+    // AllocatorAttributes possible.  If we need to allocate later,
+    // we'll request the DT_VARIANT be allocated on host.
+    AllocatorAttributes attr;
+
+    std::unique_ptr<Tensor> tls_alias = c->forward_input(
+        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
+        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
+
+    const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
+
+    OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
+                errors::InvalidArgument(
+                    "Expected input_handles dtype to be Variant, but saw: ",
+                    DataTypeString(tls.dtype())));
+    OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
+                errors::InvalidArgument(
+                    "Expected input_handles to be a vector, but saw shape: ",
+                    tls_shape.DebugString()));
+    const int64 batch_size = tls.NumElements();
+    OP_REQUIRES(c, input.dim_size(0) == batch_size,
+                errors::InvalidArgument(
+                    "Expected tensor.shape[0] == input_handles.size, but saw ",
+                    input.dim_size(0), " vs. ", batch_size));
+    auto tls_t = tls.vec<Variant>();
+
+    TensorShape input_element_shape = input.shape();
+    input_element_shape.RemoveDim(0);
+    std::vector<const TensorList*> tl_batch;
+    for (int64 b = 0; b < batch_size; ++b) {
+      const TensorList* l = tls_t(b).get<TensorList>();
+      OP_REQUIRES(c, l != nullptr,
+                  errors::InvalidArgument("Input handle at index ", b,
+                                          " is not a list. Saw: '",
+                                          tls_t(b).DebugString(), "'"));
+      OP_REQUIRES(
+          c, l->element_shape.IsCompatibleWith(input_element_shape),
+          errors::InvalidArgument(
+              "Tried to append a tensor with incompatible shape to a "
+              "list at index ",
+              b, ".  Op element shape: ", input_element_shape.DebugString(),
+              " list shape: ", l->element_shape.DebugString()));
+      OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+                  errors::InvalidArgument(
+                      "Invalid data type at index ", b, "; op elements ",
+                      DataTypeString(element_dtype_), " but list elements ",
+                      DataTypeString(l->element_dtype)));
+      tl_batch.push_back(l);
+    }
+
+    Tensor* result;
+
+    if (tls_alias) {
+      result = tls_alias.get();
+      c->set_output(0, *result);
+    } else {
+      // DT_VARIANT tensors always allocated on host.
+      AllocatorAttributes attr;
+      attr.set_on_host(true);
+      OP_REQUIRES_OK(
+          c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
+    }
+
+    if (batch_size == 0) {
+      return;
+    }
+
+    auto input_t = input.flat_outer_dims<T, 2>();
+    auto result_t = result->vec<Variant>();
+
+    for (int64 b = 0; b < batch_size; ++b) {
+      if (!tls_alias) {
+        result_t(b) = *tl_batch[b];
+      }
+      TensorList* output = result_t(b).get<TensorList>();
+      DCHECK(output != nullptr);
+      Tensor* frame;
+      PersistentTensor tmp;
+      OP_REQUIRES_OK(c, c->allocate_persistent(
+                            element_dtype_, input_element_shape, &tmp, &frame));
+      if (input_element_shape.num_elements() > 0) {
+        auto frame_t = frame->flat<T>();
+        frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
+      }
+      output->tensors.push_back(std::move(*frame));
+    }
+  }
+
+ private:
+  DataType element_dtype_;
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
```
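
For context, the kernel above implements an op whose interface registration is not part of this diff. A plausible sketch of that registration, with input, output, and attribute names inferred from the kernel's error messages and a shape function chosen purely for illustration:

```cpp
// Hypothetical sketch; the real registration would live in
// tensorflow/core/ops/list_ops.cc, outside this diff. The shape function
// (shape_inference::UnchangedShape) is an assumption, not taken from the
// commit; the output handles vector has the same shape as the input one.
REGISTER_OP("TensorListPushBackBatch")
    .Input("input_handles: variant")
    .Input("tensor: element_dtype")
    .Output("output_handles: variant")
    .Attr("element_dtype: type")
    .SetShapeFn(shape_inference::UnchangedShape);
```

With an interface along these lines, the kernel's checks line up: `input_handles` must be a vector of `DT_VARIANT` list handles, `tensor` must satisfy `tensor.shape[0] == input_handles.size`, and one slice of `tensor` is appended to each list in the batch.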