aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/sequential.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/sequential.py')
-rw-r--r--tensorflow/python/keras/engine/sequential.py30
1 files changed, 26 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index cd76f08a32..41cdfda660 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -29,6 +29,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.training import Model
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -108,6 +109,7 @@ class Sequential(Model):
return self._layers[1:]
return self._layers
+ @checkpointable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
@@ -191,6 +193,7 @@ class Sequential(Model):
else:
self._layers.append(layer)
+ @checkpointable.no_automatic_dependency_tracking
def pop(self):
"""Removes the last layer in the model.
@@ -211,11 +214,30 @@ class Sequential(Model):
self.build()
def build(self, input_shape=None):
- if input_shape and not self.inputs:
- batch_shape = tuple(input_shape)
+ self._set_inputs_and_outputs(input_shape=input_shape)
+
+ def symbolic_set_inputs(self, inputs):
+ self._set_inputs_and_outputs(tensor=inputs)
+
+ @checkpointable.no_automatic_dependency_tracking
+ def _set_inputs_and_outputs(self, input_shape=None, tensor=None):
+ """Set model's input and output specs based on the input received.
+
+ If `tensor` is provided, `input_shape` is not required.
+
+ Args:
+ input_shape: Optional shape of input.
+ tensor: Optional existing tensor to wrap into the `Input` layer.
+ """
+ if not self.inputs:
dtype = K.floatx()
- x = Input(
- batch_shape=batch_shape, dtype=dtype, name=self.name + '_input')
+ if tensor is not None:
+ batch_shape = (None,) + tuple(tensor.get_shape().as_list()[1:])
+ x = Input(dtype=dtype, name=self.name + '_input', tensor=tensor)
+ elif input_shape is not None:
+ batch_shape = tuple(input_shape)
+ x = Input(
+ batch_shape=batch_shape, dtype=dtype, name=self.name + '_input')
self.inputs = [x]
for layer in self._layers:
x = layer(x)