aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-05-02 19:24:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-02 20:49:02 -0700
commit485a24eda09965b83af1b2218bc12c529cc35c91 (patch)
tree1d265462f8a06b76a09d83db5f07e4afa0eebb5f /tensorflow/contrib/keras
parent3af03be757b63ea6fbd28cc351d5d2323c526354 (diff)
[tf layers] Delay marking a layer as built until the end of its first apply().
This allows the layer's call() method to call add_variable, making it much easier to create variables while building the layer's logic. Change: 154916035
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core.py1
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/merge.py3
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/wrappers.py1
3 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py
index 7a9e0d1736..0b6cdc65a4 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core.py
@@ -741,6 +741,7 @@ class Dense(tf_core_layers.Dense, Layer):
self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
+ self.built = True
def get_config(self):
config = {
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py
index 25921979bd..b4bb9935fd 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge.py
+++ b/tensorflow/contrib/keras/python/keras/layers/merge.py
@@ -111,6 +111,7 @@ class _Merge(Layer):
self._reshape_required = False
else:
self._reshape_required = True
+ self.built = True
def call(self, inputs):
if self._reshape_required:
@@ -302,6 +303,7 @@ class Concatenate(_Merge):
'inputs with matching shapes '
'except for the concat axis. '
'Got inputs shapes: %s' % (input_shape))
+ self.built = True
def call(self, inputs):
if not isinstance(inputs, list):
@@ -414,6 +416,7 @@ class Dot(_Merge):
raise ValueError('Dimension incompatibility '
'%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
'Layer shapes: %s, %s' % (shape1, shape2))
+ self.built = True
def call(self, inputs):
x1 = inputs[0]
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
index ce6458fd0c..092501cb11 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -166,6 +166,7 @@ class TimeDistributed(Wrapper):
self.layer.build(child_input_shape)
self.layer.built = True
super(TimeDistributed, self).build()
+ self.built = True
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()