aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/tensor_array.h
blob: 384a63e945306637bcf074d1f3709eea055bffe9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
#define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_

#include <limits.h>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/aggregate_ops.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Short aliases for the Eigen device types.  They are passed as the `Device`
// template argument when the per-device specializations of the tensor_array
// helpers are declared further down in this header.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace tensor_array {

// Primary (fallback) template for AddToTensor.  It is selected for every
// (Device, T) pair that has no explicit specialization and always fails.
// The explicit specializations are declared right below; their full
// implementations are in tensor_array.cc.
//
// Parameters (all unused by this fallback):
//   ctx     - kernel context of the calling op.
//   sum     - destination tensor.  Callers in this header pass either a
//             freshly allocated tensor or `current` itself.
//   current - first operand.
//   add     - second operand.
//
// Returns:
//   InvalidArgument naming the unsupported element type T.
template <typename Device, typename T>
Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
                   const Tensor* add) {
  return errors::InvalidArgument(
      "tensor_array::AddToTensor type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
}

// Declares (without defining) the explicit specialization
// AddToTensor<Device, T>.  Declaring the specializations in the header keeps
// callers from implicitly instantiating the failing primary template for
// these (Device, T) pairs.
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T)                         \
  template <>                                                        \
  Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum, \
                                const Tensor* current, const Tensor* add);

// CPU specializations: one declaration per type enumerated by
// TF_CALL_NUMBER_TYPES.
#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU

#if GOOGLE_CUDA

// GPU specializations, only declared in CUDA builds: the types enumerated by
// TF_CALL_GPU_NUMBER_TYPES plus complex64 and complex128.
#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_complex64(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

#endif  // GOOGLE_CUDA

// The helper macro is local to this header; do not leak it to includers.
#undef TENSOR_ARRAY_WRITE_OR_ADD

// Primary (fallback) template for TensorSetZero.  It is selected for every
// (Device, T) pair that has no explicit specialization and always fails.
// The explicit specializations are declared right below.
//
// Parameters (both unused by this fallback):
//   ctx   - kernel context of the calling op.
//   value - tensor that a specialization is expected to fill with zeros.
//
// Returns:
//   InvalidArgument naming the unsupported element type T.
template <typename Device, typename T>
Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
  return errors::InvalidArgument(
      "tensor_array::TensorSetZero type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
}

// Declares (without defining) the explicit specialization
// TensorSetZero<Device, T>, so callers do not implicitly instantiate the
// failing primary template for these (Device, T) pairs.
#define TENSOR_ARRAY_SET_ZERO(Device, T) \
  template <>                            \
  Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);

// CPU specializations: the types enumerated by TF_CALL_NUMBER_TYPES plus
// bool.
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU

#if GOOGLE_CUDA

// GPU specializations, only declared in CUDA builds: the types enumerated by
// TF_CALL_GPU_NUMBER_TYPES plus complex64 and complex128.
#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_complex64(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
#undef TENSOR_ARRAY_SET_ZERO_GPU

#endif  // GOOGLE_CUDA

// The helper macro is local to this header; do not leak it to includers.
#undef TENSOR_ARRAY_SET_ZERO

}  // namespace tensor_array

// The TensorArray object keeps an array of PersistentTensors.  It
// allows reading from the array and writing to the array.
//
// Important properties:
//   * Usually, writing to a particular index in the TensorArray is allowed at
//     most once per index.  In a special case, writes with the flag
//     multiple_writes_aggregate allow multiple writes to the same
//     index.  In this case, the writes are summed.
//   * Multiple reads are supported.
//   * Deep copies of PersistentTensors are rarely made.  The only
//     time they are made is when WriteOrAggregate is called at least twice
//     on the same index with the flag multiple_writes_aggregate = True.
//   * Reading and Writing to the array is protected by a mutex.
//     All operations on a TensorArray are thread-safe.
//   * A TensorArray may be preemptively closed, which releases all
//     memory associated with it.
//
// These properties together allow the TensorArray to work as a
// functional object and make gradient computation easy.  For
// example:
//   * Write-Once semantics mean the gradient of a TensorArray Read never has to
//     worry which of multiple writes to that index the gradient value
//     is meant for.
//   * Read-Many semantics (when using clear_after_read=false) allow the
//     TensorArray to be read, packed, or concatenated multiple times;
//     and the gradient operations use the multiple_writes_aggregate
//     flag to aggregate the backprop writes.  Multiple backprop writes to
//     the same index are partial gradients corresponding to the
//     multiple reads of that index in the forward phase.
//
class TensorArray : public ResourceBase {
 public:
  static std::atomic<int64> tensor_array_counter;

  // Construct a TensorArray for holding Tensors of type 'dtype' with
  // 'N' elements.  While the underlying storage is a std::vector and
  // can hold more than MAX_INT entries, in practice we do not expect
  // users to construct this many Tensors for storage in a TensorArray.
  TensorArray(const string& key, const DataType& dtype, const Tensor& handle,
              int32 N, const PartialTensorShape& element_shape,
              bool identical_element_shapes, bool dynamic_size,
              bool multiple_writes_aggregate, bool is_grad, int32 marked_size,
              bool clear_after_read)
      : key_(key),
        dtype_(dtype),
        handle_(handle),
        closed_(false),
        dynamic_size_(dynamic_size),
        multiple_writes_aggregate_(multiple_writes_aggregate),
        gradients_disallowed_(false),
        clear_after_read_(clear_after_read),
        is_grad_(is_grad),
        marked_size_(marked_size),
        element_shape_(element_shape),
        identical_element_shapes_(identical_element_shapes),
        tensors_(N) {}

  // Write PersistentTensor 'value' to index 'index'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * If the array has dynamic size:
  //      The index is >= 0
  //    Otherwise:
  //      The index is in [0, N) where N == Size()
  //  * The dtype of the Tensor in 'value' matches the TensorArray's dtype.
  //  * If multiple_writes_aggregate is false:
  //    The Tensor at 'index' has not yet been written to.
  //  * If multiple_writes_aggregate is true:
  //    The Tensor at 'index' has the same shape as value.
  //
  // Side effects:
  //  * On the first write to 'index':
  //    - The underlying Tensor in 'value' has a new reference to it.
  //    - The index 'index' is marked as written.
  //  * If multiple_writes_aggregate is false, subsequent writes to 'index'
  //    raise an InvalidArgument error.
  //  * If multiple_writes_aggregate is true, subsequent writes to 'index':
  //    - The underlying Tensors in 'value' and from the first write
  //      are released and a local PersistentTensor is created.
  //    - Index 'index' is also marked as local_copy.
  //    - The gradients_disallowed flag is set true (GradientsAllowed()
  //      will now return false).
  //
  // Note, value is passed as a pointer because we its underlying
  // Tensor's shape is accessed.  Otherwise it is not modified.
  template <typename Device, typename T>
  Status WriteOrAggregate(OpKernelContext* ctx, const int32 index,
                          PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedWriteOrAggregate<Device, T>(ctx, index, value);
  }

  template <typename Device, typename T>
  Status WriteOrAggregateMany(OpKernelContext* ctx,
                              const std::vector<int32>& indices,
                              std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedWriteOrAggregate<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      TF_RETURN_IF_ERROR(s);
    }
    return Status::OK();
  }

  // Read from index 'index' into PersistentTensor 'value'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * The index is in [0, N)
  //  * The Tensor at 'index' has been written to.
  //  * The Tensor at 'index' has not been read from with flag
  //    clear_after_read = true.
  //
  // Side effects:
  //  * If clear_after_read is true, the reference to the underlying
  //    Tensor is deleted.
  //  * The reference to the underlying Tensor at 'index' is copied to
  //    the returned '*value'.
  //  * The index is marked as read (it cannot be rewritten to).
  template <typename Device, typename T>
  Status Read(OpKernelContext* ctx, const int32 index,
              PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedRead<Device, T>(ctx, index, value);
  }

  template <typename Device, typename T>
  Status ReadMany(OpKernelContext* ctx, const std::vector<int32>& indices,
                  std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    values->clear();
    values->resize(indices.size());
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedRead<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      if (!s.ok()) return s;
    }
    return Status::OK();
  }

  DataType ElemType() const { return dtype_; }

  PartialTensorShape ElemShape() {
    mutex_lock l(mu_);
    return element_shape_;
  }

  Status SetElemShape(const PartialTensorShape& candidate) {
    mutex_lock l(mu_);
    PartialTensorShape new_element_shape_;
    Status s = element_shape_.MergeWith(candidate, &new_element_shape_);
    if (!s.ok()) {
      return s;
    }
    element_shape_ = new_element_shape_;
    return Status::OK();
  }

  string DebugString() override {
    mutex_lock l(mu_);
    CHECK(!closed_);
    return strings::StrCat("TensorArray[", tensors_.size(), "]");
  }

  bool IsClosed() {
    mutex_lock l(mu_);
    return closed_;
  }

  // Return the size of the TensorArray.
  Status Size(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = tensors_.size();
    return Status::OK();
  }

  // Record the size of the TensorArray after an unpack or split.
  Status SetMarkedSize(int32 size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    if (!is_grad_) {
      marked_size_ = size;
    }
    return Status::OK();
  }

  // Return the marked size of the TensorArray.
  Status MarkedSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = marked_size_;
    return Status::OK();
  }

  // Return the size that should be used by pack or concat op.
  Status PackOrConcatSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = is_grad_ ? marked_size_ : tensors_.size();
    return Status::OK();
  }

  // Once a TensorArray is being used for gradient calculations, it
  // should be marked as no longer resizeable.
  void DisableDynamicSize() {
    mutex_lock l(mu_);
    dynamic_size_ = false;
  }

  bool HasDynamicSize() {
    mutex_lock l(mu_);
    return dynamic_size_;
  }

  bool GradientsAllowed() {
    mutex_lock l(mu_);
    return !gradients_disallowed_;
  }

  bool HasIdenticalElementShapes() const { return identical_element_shapes_; }

  // Copy the TensorShapes from another TensorArray into this one.
  // If `shapes_to_prepend` is set, expands the rank of the copied shape by
  // prepending the passed in shape prefix to the shape values in `rhs`.
  // The sizes of the two TensorArrays must match and this one
  // may not have any entries filled in.  This performs a "soft copy",
  // essentially filling the current TensorArray with virtual
  // zero-tensors, which will be replaced by future aggregate writes,
  // or instantiated by future reads.  Requires a non-const pointer
  // to the rhs to access its mutex.
  Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend);

  // Clear the TensorArray, including any Tensor references, and mark as closed.
  void ClearAndMarkClosed() {
    mutex_lock l(mu_);
    tensors_.clear();
    closed_ = true;
  }

  mutex* mu() { return &mu_; }
  Tensor* handle() { return &handle_; }

  ResourceHandle resource_handle(OpKernelContext* ctx) {
    return MakePerStepResourceHandle<TensorArray>(ctx, key_);
  }

 private:
  Status LockedWrite(OpKernelContext* ctx, const int32 index,
                     PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32 index,
                                PersistentTensor* value)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedRead(OpKernelContext* ctx, const int32 index,
                    PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status LockedReturnIfClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (closed_) {
      return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                     " has already been closed.");
    }
    return Status::OK();
  }

  const string key_;

  const DataType dtype_;
  Tensor handle_;

  mutex mu_;

  // Marks that the tensor_array_ has been cleared.
  bool closed_ GUARDED_BY(mu_);

  // Writes are allowed to grow the array.
  bool dynamic_size_;

  // Multiple writes to the same index will result in summation of the
  // values (used by backprop)
  const bool multiple_writes_aggregate_;

  // If multiple Writes were attempted (e.g. via attribute
  // multiple_writes_aggregate), then gradients are disallowed.
  bool gradients_disallowed_ GUARDED_BY(mu_);

  // After a read at an index, clear away its PersistentTensor to
  // release memory.
  const bool clear_after_read_;

  // True iff this is a gradient tensor array.
  const bool is_grad_;

  // The size of the TensorArray after a (legacy) unpack or split is performed.
  // -1 if there has been no unpack or split performed on the TensorArray.
  int32 marked_size_;

  // The shape of each element in the TensorArray, may be partially known or not
  // known at all.
  PartialTensorShape element_shape_ GUARDED_BY(mu_);

  // Whether all elements in the TensorArray have identical shapes.
  // This allows certain behaviors, like dynamically checking for
  // consistent shapes on write, and being able to fill in properly
  // shaped zero tensors on stack -- even if the initial element_shape
  // was not fully defined.
  const bool identical_element_shapes_;

  // TensorAndState is used to keep track of the PersistentTensors
  // stored in the TensorArray, along with their shapes, and a boolean
  // that determines whether they have already been read or not.
  struct TensorAndState {
    TensorAndState()
        : written(false), read(false), cleared(false), local_copy(false) {}
    PersistentTensor tensor;
    TensorShape shape;
    bool written;  // True if a Tensor has been written to the index.
    bool read;  // True if a Tensor has been written to and read from the index.
    bool cleared;  // True if a tensor has been read with
                   // clear_after_read = true;

    // Used by writes when multiple_writes_aggregate is true.  In this
    // case, the first time a value is written, it is a shallow copy.
    // The second time a value is written, it is aggregated.  However,
    // in this case a new Tensor must be constructed to hold the
    // aggregated value.  This flag marks that such a Tensor is being
    // used.  All future writes will aggregate to the existing local Tensor.
    bool local_copy;
  };
  // The list of underlying PersistentTensors and states.
  std::vector<TensorAndState> tensors_ GUARDED_BY(mu_);
};

// Write '*value' to 'index', or sum it into the value already stored there
// when multiple_writes_aggregate_ is set.  Must be called with mu_ held.
//
// Checks are performed in this order, each returning InvalidArgument:
//   1. the TensorArray is closed;
//   2. 'index' is negative, or past the end of a non-resizeable array;
//   3. the value's dtype differs from dtype_;
//   4. the value's shape is incompatible with element_shape_;
//   5. the entry at 'index' has already been read;
//   6. the entry has already been written and aggregation is disabled;
//   7. (aggregating) the value's shape differs from the stored shape.
// Errors from allocate_persistent and AddToTensor are propagated as is.
//
// The value is validated (steps 3 and 4) before a dynamic-size array is
// grown, so a write rejected for its dtype or shape does not change the
// array's size.  Entries created by growth are unread and unwritten, so
// steps 5-7 cannot reject a write that grew the array.
template <typename Device, typename T>
Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
                                           const int32 index,
                                           PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  const size_t index_size = static_cast<size_t>(index);
  if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1), ": Tried to write to index ",
        index, " but array is not resizeable and size is: ", tensors_.size());
  }

  Tensor* value_t = value->AccessTensor(ctx);
  if (value_t->dtype() != dtype_) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value dtype is ", DataTypeString(value_t->dtype()),
        " but TensorArray dtype is ", DataTypeString(dtype_), ".");
  }
  if (!element_shape_.IsCompatibleWith(value_t->shape())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value shape is ", value_t->shape().DebugString(),
        " which is incompatible with the TensorArray's inferred element "
        "shape: ",
        element_shape_.DebugString(), " (consider setting infer_shape=False).");
  } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) {
    // All elements share one shape, so the first concrete value pins it.
    element_shape_ = PartialTensorShape(value_t->shape().dim_sizes());
  }

  if (dynamic_size_) {
    // We must grow the internal TensorArray
    if (index_size >= tensors_.capacity()) {
      tensors_.reserve(2 * (index_size + 1));
    }
    if (index_size >= tensors_.size()) {
      tensors_.resize(index_size + 1);
    }
  }
  TensorAndState& t = tensors_[index_size];

  if (t.read) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not write to TensorArray index ",
                                   index, " because it has already been read.");
  }

  if (!multiple_writes_aggregate_ && t.written) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not write to TensorArray index ",
                                   index,
                                   " because it has already been written to.");
  }

  if (t.written) {
    DCHECK(multiple_writes_aggregate_);

    // Check that value_t shape matches t.shape
    if (value_t->shape() != t.shape) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<string>()(1),
          ": Could not aggregate to TensorArray index ", index,
          " because the existing shape is ", t.shape.DebugString(),
          " but the new input shape is ", value_t->shape().DebugString(), ".");
    }

    if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
      // If existing_t == nullptr but written == true, then what was stored
      // was just a shape, which just means zeros.  So all we must do in this
      // case is copy the reference over and return early.
      t.tensor = *value;
      return Status::OK();
    }

    Tensor* existing_t = t.tensor.AccessTensor(ctx);

    if (t.local_copy) {
      // The stored tensor is already owned by this array: accumulate into it.
      Status s = tensor_array::AddToTensor<Device, T>(ctx, existing_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
    } else {
      // The stored tensor is still a shared reference to the first written
      // value, so the sum goes into a freshly allocated tensor instead.
      PersistentTensor local_tensor;
      Tensor* local_tensor_t;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          dtype_, existing_t->shape(), &local_tensor, &local_tensor_t));
      Status s = tensor_array::AddToTensor<Device, T>(ctx, local_tensor_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
      t.tensor = local_tensor;
      t.local_copy = true;
    }

    // We've aggregated the values, so disallow backprop on this
    // TensorArray.
    gradients_disallowed_ = true;
  } else {
    // First write to this index: keep a reference to the value and remember
    // its shape.
    t.tensor = *value;
    t.shape = value_t->shape();
    t.written = true;
  }
  return Status::OK();
}

// Read the entry at 'index' into '*value'.  Must be called with mu_ held.
//
// Returns InvalidArgument if the TensorArray is closed, if 'index' is
// negative, if 'index' is past the end of a non-gradient array, if a
// zero-filled result is needed but no fully defined shape is available for
// it, or if the entry was already cleared by an earlier read.  Errors from
// allocate_persistent and TensorSetZero are propagated as is.
//
// An entry that has not been written is treated as zeros of a known shape:
// it is marked as written with that shape, and a zero tensor is materialized
// below.  For a gradient array the index may lie past the current end, in
// which case the array is grown and every new entry is marked the same way.
//
// On success the entry is marked as read, and additionally dropped and marked
// as cleared when clear_after_read_ is set.
template <typename Device, typename T>
Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
                               PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  if ((index < 0) ||
      (!is_grad_ && (static_cast<size_t>(index) >= tensors_.size()))) {
    return errors::InvalidArgument("Tried to read from index ", index,
                                   " but array size is: ", tensors_.size());
  }
  const size_t index_t = static_cast<size_t>(index);
  // Past this point an out-of-range index is only possible for a gradient
  // array, so a single condition covers both the gradient and forward cases.
  if (index_t >= tensors_.size() || !tensors_[index_t].written) {
    // Special case returning zeros if this is a gradient read that happens
    // after a stop_gradients call with dynamic forward TensorArrays.
    // There is sometimes a race condition where the gradient is not
    // written due to stop_gradients, but is later read.
    TensorShape element_shape;
    if (is_grad_ && index_t < tensors_.size() &&
        tensors_[index_t].shape.dims() > 0) {
      // A gradient TensorArray has more specific gradient information
      // available for each entry.  A forward TensorArray must rely on
      // the global element_shape_ to fill in zeros on read.
      element_shape = tensors_[index_t].shape;
    } else if (!element_shape_.IsFullyDefined()) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<string>()(1),
          ": Could not read from TensorArray index ", index,
          ".  Furthermore, the element shape is not fully defined: ",
          element_shape_.DebugString(),
          ".  It is possible you are working with a resizeable TensorArray and "
          "stop_gradients is not allowing the gradients to be written.  If you "
          "set the full "
          "element_shape property on the forward TensorArray, the proper "
          "all-zeros tensor "
          "will be returned instead of incurring this error.");
    } else {
      element_shape_.AsTensorShape(&element_shape);  // Always succeeds.
    }
    if (index_t >= tensors_.size()) {
      // Fill in tensors_ up to index to have known shape.  The new size is
      // computed in size_t: `index + 1` in int32 would overflow when index is
      // INT32_MAX.
      const size_t old_tensors_size = tensors_.size();
      const size_t new_tensors_size = index_t + 1;
      tensors_.resize(new_tensors_size);
      for (size_t i = old_tensors_size; i < new_tensors_size; ++i) {
        tensors_[i].shape = element_shape;
        tensors_[i].written = true;
      }
    } else {
      tensors_[index_t].shape = element_shape;
      tensors_[index_t].written = true;
    }
  }

  TensorAndState& t = tensors_[index_t];

  if (t.cleared) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not read index ", index,
                                   " twice because it was cleared after a "
                                   "previous read (perhaps try setting "
                                   "clear_after_read = false?).");
  }

  if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
    // We stored just a shape, but no value.  This means create and
    // return zeros of the appropriate shape.
    Tensor* tensor_t;
    TF_RETURN_IF_ERROR(
        ctx->allocate_persistent(dtype_, t.shape, &t.tensor, &tensor_t));
    if (t.shape.num_elements() > 0) {
      Status s = tensor_array::TensorSetZero<Device, T>(ctx, tensor_t);
      if (!s.ok()) return s;
    }
  }

  // Data is available inside the tensor, copy the reference over.
  *value = t.tensor;

  if (clear_after_read_) {
    // Drop this array's reference; the caller's copy keeps the data alive.
    t.tensor = PersistentTensor();
    t.cleared = true;
  }
  t.read = true;
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_