aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar Adam Roberts <adarob@google.com>2018-01-29 11:51:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 11:55:12 -0800
commit5fb13ffc145af5a9c707a1388c38dd45f793b0a0 (patch)
tree9a69d41edd301751c3d286b9581ca24eb892f651 /tensorflow/contrib/seq2seq
parent0050ce16d425b3367010d58a5a3dca30aab894a4 (diff)
Add input and sequence_length accessor to TrainingHelper.
PiperOrigin-RevId: 183701716
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index ef3722ee41..3245cc5e72 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -184,6 +184,7 @@ class TrainingHelper(Helper):
"""
with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
inputs = ops.convert_to_tensor(inputs, name="inputs")
+ self._inputs = inputs
if not time_major:
inputs = nest.map_structure(_transpose_batch_time, inputs)
@@ -201,6 +202,14 @@ class TrainingHelper(Helper):
self._batch_size = array_ops.size(sequence_length)
@property
+ def inputs(self):
+ return self._inputs
+
+ @property
+ def sequence_length(self):
+ return self._sequence_length
+
+ @property
def batch_size(self):
return self._batch_size