aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/sequential.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/sequential.py')
-rw-r--r--tensorflow/python/keras/engine/sequential.py28
1 files changed, 23 insertions, 5 deletions
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 371504a503..41cdfda660 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -213,13 +213,31 @@ class Sequential(Model):
self.outputs = [self.layers[-1].output]
self.build()
- @checkpointable.no_automatic_dependency_tracking
def build(self, input_shape=None):
- if input_shape and not self.inputs:
- batch_shape = tuple(input_shape)
+ self._set_inputs_and_outputs(input_shape=input_shape)
+
+ def symbolic_set_inputs(self, inputs):
+ self._set_inputs_and_outputs(tensor=inputs)
+
+ @checkpointable.no_automatic_dependency_tracking
+ def _set_inputs_and_outputs(self, input_shape=None, tensor=None):
+ """Set model's input and output specs based on the input received.
+
+ If `tensor` is provided, `input_shape` is not required.
+
+ Args:
+ input_shape: Optional shape of input.
+ tensor: Optional existing tensor to wrap into the `Input` layer.
+ """
+ if not self.inputs:
dtype = K.floatx()
- x = Input(
- batch_shape=batch_shape, dtype=dtype, name=self.name + '_input')
+ if tensor is not None:
+ batch_shape = (None,) + tuple(tensor.get_shape().as_list()[1:])
+ x = Input(dtype=dtype, name=self.name + '_input', tensor=tensor)
+ elif input_shape is not None:
+ batch_shape = tuple(input_shape)
+ x = Input(
+ batch_shape=batch_shape, dtype=dtype, name=self.name + '_input')
self.inputs = [x]
for layer in self._layers:
x = layer(x)