diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-05-02 19:24:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-02 20:49:02 -0700 |
commit | 485a24eda09965b83af1b2218bc12c529cc35c91 (patch) | |
tree | 1d265462f8a06b76a09d83db5f07e4afa0eebb5f /tensorflow/contrib/keras | |
parent | 3af03be757b63ea6fbd28cc351d5d2323c526354 (diff) |
[tf layers] Delay marking a layer as built until the end of its first apply().
This allows the layer's call() method to call add_variable, making it much
easier to create variables while building the layer's logic.
Change: 154916035
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r-- | tensorflow/contrib/keras/python/keras/layers/core.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/layers/merge.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/layers/wrappers.py | 1 |
3 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py index 7a9e0d1736..0b6cdc65a4 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core.py +++ b/tensorflow/contrib/keras/python/keras/layers/core.py @@ -741,6 +741,7 @@ class Dense(tf_core_layers.Dense, Layer): self.constraints[self.kernel] = self.kernel_constraint if self.use_bias and self.bias_constraint: self.constraints[self.bias] = self.bias_constraint + self.built = True def get_config(self): config = { diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py index 25921979bd..b4bb9935fd 100644 --- a/tensorflow/contrib/keras/python/keras/layers/merge.py +++ b/tensorflow/contrib/keras/python/keras/layers/merge.py @@ -111,6 +111,7 @@ class _Merge(Layer): self._reshape_required = False else: self._reshape_required = True + self.built = True def call(self, inputs): if self._reshape_required: @@ -302,6 +303,7 @@ class Concatenate(_Merge): 'inputs with matching shapes ' 'except for the concat axis. ' 'Got inputs shapes: %s' % (input_shape)) + self.built = True def call(self, inputs): if not isinstance(inputs, list): @@ -414,6 +416,7 @@ class Dot(_Merge): raise ValueError('Dimension incompatibility ' '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) + 'Layer shapes: %s, %s' % (shape1, shape2)) + self.built = True def call(self, inputs): x1 = inputs[0] diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py index ce6458fd0c..092501cb11 100644 --- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py +++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py @@ -166,6 +166,7 @@ class TimeDistributed(Wrapper): self.layer.build(child_input_shape) self.layer.built = True super(TimeDistributed, self).build() + self.built = True def _compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() |