diff options
author | Adam Roberts <adarob@google.com> | 2018-01-25 17:58:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-25 18:02:44 -0800 |
commit | 642454bd3296959f0025e1fb1730cdd95c36713f (patch) | |
tree | 3b06d0c5f609a761b793e02c967a829e5fca71cb | |
parent | c8f1a34ba6f9bbde7e3fbb106158661ded98f0a0 (diff) |
Add input_shape to seq2seq helpers.
PiperOrigin-RevId: 183321394
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/helper.py | 36 |
1 file changed, 36 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index ef3722ee41..6d8f786223 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -73,6 +73,14 @@ class Helper(object): raise NotImplementedError("batch_size has not been implemented") @abc.abstractproperty + def input_shape(self): + """Shape of each input element in batch. + + Returns a `TensorShape`. + """ + raise NotImplementedError("input_shape has not been implemented") + + @abc.abstractproperty def sample_ids_shape(self): """Shape of tensor returned by `sample`, excluding the batch dimension. @@ -127,6 +135,7 @@ class CustomHelper(Helper): self._sample_fn = sample_fn self._next_inputs_fn = next_inputs_fn self._batch_size = None + self._input_shape = None self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 @@ -149,6 +158,8 @@ class CustomHelper(Helper): (finished, next_inputs) = self._initialize_fn() if self._batch_size is None: self._batch_size = array_ops.size(finished) + if self._input_shape is None: + self._input_shape = next_inputs.shape[1:] return (finished, next_inputs) def sample(self, time, outputs, state, name=None): @@ -184,6 +195,7 @@ class TrainingHelper(Helper): """ with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): inputs = ops.convert_to_tensor(inputs, name="inputs") + self._inputs = inputs if not time_major: inputs = nest.map_structure(_transpose_batch_time, inputs) @@ -199,12 +211,17 @@ class TrainingHelper(Helper): lambda inp: array_ops.zeros_like(inp[0, :]), inputs) self._batch_size = array_ops.size(sequence_length) + self._input_shape = inputs.shape[2:] @property def batch_size(self): return self._batch_size @property + def input_shape(self): + return self._input_shape + + @property def sample_ids_shape(self): return tensor_shape.TensorShape([]) @@ -212,6 +229,14 @@ 
class TrainingHelper(Helper): def sample_ids_dtype(self): return dtypes.int32 + @property + def inputs(self): + return self._inputs + + @property + def sequence_length(self): + return self._sequence_length + def initialize(self, name=None): with ops.name_scope(name, "TrainingHelperInitialize"): finished = math_ops.equal(0, self._sequence_length) @@ -516,12 +541,17 @@ class GreedyEmbeddingHelper(Helper): if self._end_token.get_shape().ndims != 0: raise ValueError("end_token must be a scalar") self._start_inputs = self._embedding_fn(self._start_tokens) + self._input_shape = self._start_inputs.shape[1:] @property def batch_size(self): return self._batch_size @property + def input_shape(self): + return self._input_shape + + @property def sample_ids_shape(self): return tensor_shape.TensorShape([]) @@ -632,6 +662,8 @@ class InferenceHelper(Helper): self._sample_dtype = sample_dtype self._next_inputs_fn = next_inputs_fn self._batch_size = array_ops.shape(start_inputs)[0] + self._input_shape = start_inputs.shape[1:] + self._start_inputs = ops.convert_to_tensor( start_inputs, name="start_inputs") @@ -640,6 +672,10 @@ class InferenceHelper(Helper): return self._batch_size @property + def input_shape(self): + return self._input_shape + + @property def sample_ids_shape(self): return self._sample_shape |