diff options
Diffstat (limited to 'tensorflow/python/keras/engine/sequential.py')
-rw-r--r-- | tensorflow/python/keras/engine/sequential.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 371504a503..41cdfda660 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -213,13 +213,31 @@ class Sequential(Model): self.outputs = [self.layers[-1].output] self.build() - @checkpointable.no_automatic_dependency_tracking def build(self, input_shape=None): - if input_shape and not self.inputs: - batch_shape = tuple(input_shape) + self._set_inputs_and_outputs(input_shape=input_shape) + + def symbolic_set_inputs(self, inputs): + self._set_inputs_and_outputs(tensor=inputs) + + @checkpointable.no_automatic_dependency_tracking + def _set_inputs_and_outputs(self, input_shape=None, tensor=None): + """Set model's input and output specs based on the input received. + + If `tensor` is provided, `input_shape` is not required. + + Args: + input_shape: Optional shape of input. + tensor: Optional existing tensor to wrap into the `Input` layer. + """ + if not self.inputs: dtype = K.floatx() - x = Input( - batch_shape=batch_shape, dtype=dtype, name=self.name + '_input') + if tensor is not None: + batch_shape = (None,) + tuple(tensor.get_shape().as_list()[1:]) + x = Input(dtype=dtype, name=self.name + '_input', tensor=tensor) + elif input_shape is not None: + batch_shape = tuple(input_shape) + x = Input( + batch_shape=batch_shape, dtype=dtype, name=self.name + '_input') self.inputs = [x] for layer in self._layers: x = layer(x) |