author    Adam Roberts <adarob@google.com>  2018-01-25 17:58:48 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-25 18:02:44 -0800
commit    642454bd3296959f0025e1fb1730cdd95c36713f (patch)
tree      3b06d0c5f609a761b793e02c967a829e5fca71cb /tensorflow/contrib/seq2seq
parent    c8f1a34ba6f9bbde7e3fbb106158661ded98f0a0 (diff)
Add input_shape to seq2seq helpers.
PiperOrigin-RevId: 183321394
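
For context: the new `input_shape` property reports the static shape of a single input element, with the batch dimension (and, for `TrainingHelper`, the time dimension) stripped off, so downstream code can size layers without inspecting the raw inputs. A minimal sketch of how a caller might query it through the public contrib API; the placeholder setup and tensor shapes below are illustrative, not part of this change:

    import tensorflow as tf  # TF 1.x-era API, matching contrib at this commit
    from tensorflow.contrib.seq2seq import TrainingHelper

    # Illustrative decoder inputs: batch of 32, 10 time steps, 128-dim embeddings.
    inputs = tf.placeholder(tf.float32, shape=[32, 10, 128])
    lengths = tf.fill([32], 10)

    helper = TrainingHelper(inputs=inputs, sequence_length=lengths)
    print(helper.batch_size)   # scalar int32 tensor
    print(helper.input_shape)  # TensorShape([128]): one element, no batch/time dims
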
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/helper.py  36
1 file changed, 36 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index ef3722ee41..6d8f786223 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -73,6 +73,14 @@ class Helper(object):
     raise NotImplementedError("batch_size has not been implemented")
 
   @abc.abstractproperty
+  def input_shape(self):
+    """Shape of each input element in batch.
+
+    Returns a `TensorShape`.
+    """
+    raise NotImplementedError("input_shape has not been implemented")
+
+  @abc.abstractproperty
   def sample_ids_shape(self):
     """Shape of tensor returned by `sample`, excluding the batch dimension.
 
@@ -127,6 +135,7 @@ class CustomHelper(Helper):
     self._sample_fn = sample_fn
     self._next_inputs_fn = next_inputs_fn
     self._batch_size = None
+    self._input_shape = None
     self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or [])
     self._sample_ids_dtype = sample_ids_dtype or dtypes.int32
 
@@ -149,6 +158,8 @@
       (finished, next_inputs) = self._initialize_fn()
       if self._batch_size is None:
         self._batch_size = array_ops.size(finished)
+      if self._input_shape is None:
+        self._input_shape = next_inputs.shape[1:]
     return (finished, next_inputs)
 
   def sample(self, time, outputs, state, name=None):
@@ -184,6 +195,7 @@ class TrainingHelper(Helper):
     """
     with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
       inputs = ops.convert_to_tensor(inputs, name="inputs")
+      self._inputs = inputs
       if not time_major:
         inputs = nest.map_structure(_transpose_batch_time, inputs)
 
@@ -199,12 +211,17 @@
           lambda inp: array_ops.zeros_like(inp[0, :]), inputs)
 
       self._batch_size = array_ops.size(sequence_length)
+      self._input_shape = inputs.shape[2:]
 
   @property
   def batch_size(self):
     return self._batch_size
 
   @property
+  def input_shape(self):
+    return self._input_shape
+
+  @property
   def sample_ids_shape(self):
     return tensor_shape.TensorShape([])
 
@@ -212,6 +229,14 @@
   def sample_ids_dtype(self):
     return dtypes.int32
 
+  @property
+  def inputs(self):
+    return self._inputs
+
+  @property
+  def sequence_length(self):
+    return self._sequence_length
+
   def initialize(self, name=None):
     with ops.name_scope(name, "TrainingHelperInitialize"):
       finished = math_ops.equal(0, self._sequence_length)
@@ -516,12 +541,17 @@ class GreedyEmbeddingHelper(Helper):
     if self._end_token.get_shape().ndims != 0:
       raise ValueError("end_token must be a scalar")
     self._start_inputs = self._embedding_fn(self._start_tokens)
+    self._input_shape = self._start_inputs.shape[1:]
 
   @property
   def batch_size(self):
     return self._batch_size
 
   @property
+  def input_shape(self):
+    return self._input_shape
+
+  @property
   def sample_ids_shape(self):
     return tensor_shape.TensorShape([])
 
@@ -632,6 +662,8 @@ class InferenceHelper(Helper):
     self._sample_dtype = sample_dtype
     self._next_inputs_fn = next_inputs_fn
     self._batch_size = array_ops.shape(start_inputs)[0]
+    self._input_shape = start_inputs.shape[1:]
+
     self._start_inputs = ops.convert_to_tensor(
         start_inputs, name="start_inputs")
 
@@ -640,6 +672,10 @@
     return self._batch_size
 
   @property
+  def input_shape(self):
+    return self._input_shape
+
+  @property
   def sample_ids_shape(self):
     return self._sample_shape
 
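Each concrete helper records the shape the same way: slice the batch (and time) dimensions off the static shape of the inputs as soon as they are available, at construction time or, for `CustomHelper`, lazily at the first `initialize()` call. A hedged sketch of a user-defined helper built on `CustomHelper` that would feed the lazy path; the factory function and tensor names here are hypothetical, for illustration only:

    import tensorflow as tf
    from tensorflow.contrib.seq2seq import CustomHelper

    def make_constant_helper(step_inputs):
      # step_inputs: a [batch, depth] float tensor fed at every decode step
      # (hypothetical setup for illustration).
      batch_size = tf.shape(step_inputs)[0]

      def initialize_fn():
        finished = tf.tile([False], [batch_size])
        # The helper records step_inputs.shape[1:] as its input shape here.
        return finished, step_inputs

      def sample_fn(time, outputs, state):
        return tf.cast(tf.argmax(outputs, axis=-1), tf.int32)

      def next_inputs_fn(time, outputs, state, sample_ids):
        finished = tf.tile([False], [batch_size])
        return finished, step_inputs, state

      return CustomHelper(initialize_fn, sample_fn, next_inputs_fn)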