author	Eugene Brevdo <ebrevdo@google.com>	2018-04-12 10:35:41 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-04-12 10:38:25 -0700
commit	844b8cae970d835850a75f8063324224b2de0df0 (patch)
tree	73cc881d4a3616c8888d9eb64530d6a55a87d01f /tensorflow/core/kernels/list_kernels.h
parent	ffbf77de81d0b7b4b169c92d0d9fbbdef5b8842a (diff)
[TF] Add TensorListPushBackBatch.
Also modify code to ensure aliased forwarding happens whenever possible with DT_VARIANT objects in ResourceVariables and in the new op.

PiperOrigin-RevId: 192632202
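The "aliased forwarding" mentioned above is the pattern the new kernel applies to its DT_VARIANT input: try to reuse the input buffer as the output, and only allocate a fresh host-resident variant tensor when no alias is available. A condensed sketch of that pattern, lifted from the hunk below (the class name here is illustrative and not part of the commit):

    #include <memory>

    #include "tensorflow/core/framework/op_kernel.h"

    namespace tensorflow {

    // Illustrative kernel showing the DT_VARIANT forward-or-allocate pattern.
    class VariantForwardingSketchOp : public OpKernel {
     public:
      explicit VariantForwardingSketchOp(OpKernelConstruction* c) : OpKernel(c) {}

      void Compute(OpKernelContext* c) override {
        const TensorShape& shape = c->input(0).shape();

        // Least restrictive allocator attributes so forwarding can succeed.
        AllocatorAttributes attr;
        std::unique_ptr<Tensor> alias = c->forward_input(
            0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, shape,
            DEVICE_MEMORY, attr);

        Tensor* result;
        if (alias) {
          // The input buffer can be reused: output 0 aliases it directly.
          result = alias.get();
          c->set_output(0, *result);
        } else {
          // No alias available: DT_VARIANT tensors are always allocated on host.
          AllocatorAttributes out_attr;
          out_attr.set_on_host(true);
          OP_REQUIRES_OK(c, c->allocate_output(0, shape, &result, out_attr));
        }
        // ... mutate the TensorList values held in *result in place ...
      }
    };

    }  // namespace tensorflow

When the alias succeeds, the batched push_back mutates the lists in place instead of copying every TensorList in the batch, which is why the commit routes both ResourceVariable assignment and the new op through this path wherever possible.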
Diffstat (limited to 'tensorflow/core/kernels/list_kernels.h')
-rw-r--r--	tensorflow/core/kernels/list_kernels.h	121
1 file changed, 121 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index f3bbf3b6e3..42871c6113 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -34,6 +34,8 @@ limitations under the License.
namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
// Variant compatible type for a list of tensors. This is mutable but instances
// should never be mutated after stored in a variant tensor.
struct TensorList {
@@ -146,6 +148,10 @@ class TensorListFromTensor : public OpKernel {
TensorList output_list;
const Tensor& t = c->input(0);
output_list.element_dtype = t.dtype();
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+ errors::InvalidArgument(
+ "Tensor must be at least a vector, but saw shape: ",
+ t.shape().DebugString()));
TensorShape output_shape(t.shape());
output_shape.RemoveDim(0);
OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
@@ -267,6 +273,121 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
return Status::OK();
}
+template <typename Device, typename T>
+class TensorListPushBackBatch : public OpKernel {
+ public:
+ explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ ~TensorListPushBackBatch() override {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& input = c->input(1);
+ OP_REQUIRES(c, element_dtype_ == input.dtype(),
+ errors::InvalidArgument("Invalid data types; list elements ",
+ DataTypeString(element_dtype_),
+ " but tried to append ",
+ DataTypeString(input.dtype())));
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+ errors::InvalidArgument(
+ "Expected tensor to be at least a vector, but saw shape: ",
+ input.shape().DebugString()));
+
+ const TensorShape& tls_shape = c->input(0).shape();
+
+ // For purposes of input forwarding, we want the least restrictive
+ // AllocatorAttributes possible. If we need to allocate later,
+ // we'll request the DT_VARIANT be allocated on host.
+ AllocatorAttributes attr;
+
+ std::unique_ptr<Tensor> tls_alias = c->forward_input(
+ 0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
+ DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
+
+ const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
+
+ OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
+ errors::InvalidArgument(
+ "Expected input_handles dtype to be Variant, but saw: ",
+ DataTypeString(tls.dtype())));
+ OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
+ errors::InvalidArgument(
+ "Expected input_handles to be a vector, but saw shape: ",
+ tls_shape.DebugString()));
+ const int64 batch_size = tls.NumElements();
+ OP_REQUIRES(c, input.dim_size(0) == batch_size,
+ errors::InvalidArgument(
+ "Expected tensor.shape[0] == input_handles.size, but saw ",
+ input.dim_size(0), " vs. ", batch_size));
+ auto tls_t = tls.vec<Variant>();
+
+ TensorShape input_element_shape = input.shape();
+ input_element_shape.RemoveDim(0);
+ std::vector<const TensorList*> tl_batch;
+ for (int64 b = 0; b < batch_size; ++b) {
+ const TensorList* l = tls_t(b).get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument("Input handle at index ", b,
+ " is not a list. Saw: '",
+ tls_t(b).DebugString(), "'"));
+ OP_REQUIRES(
+ c, l->element_shape.IsCompatibleWith(input_element_shape),
+ errors::InvalidArgument(
+ "Tried to append a tensor with incompatible shape to a "
+ "list at index ",
+ b, ". Op element shape: ", input_element_shape.DebugString(),
+ " list shape: ", l->element_shape.DebugString()));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument(
+ "Invalid data type at index ", b, "; op elements ",
+ DataTypeString(element_dtype_), " but list elements ",
+ DataTypeString(l->element_dtype)));
+ tl_batch.push_back(l);
+ }
+
+ Tensor* result;
+
+ if (tls_alias) {
+ result = tls_alias.get();
+ c->set_output(0, *result);
+ } else {
+ // DT_VARIANT tensors always allocated on host.
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(
+ c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
+ }
+
+ if (batch_size == 0) {
+ return;
+ }
+
+ auto input_t = input.flat_outer_dims<T, 2>();
+ auto result_t = result->vec<Variant>();
+
+ for (int64 b = 0; b < batch_size; ++b) {
+ if (!tls_alias) {
+ result_t(b) = *tl_batch[b];
+ }
+ TensorList* output = result_t(b).get<TensorList>();
+ DCHECK(output != nullptr);
+ Tensor* frame;
+ PersistentTensor tmp;
+ OP_REQUIRES_OK(c, c->allocate_persistent(
+ element_dtype_, input_element_shape, &tmp, &frame));
+ if (input_element_shape.num_elements() > 0) {
+ auto frame_t = frame->flat<T>();
+ frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
+ }
+ output->tensors.push_back(std::move(*frame));
+ }
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_