aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-04-04 15:14:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-04 15:17:25 -0700
commitbf8ad8277258bdf352ddd1df5200e61ba625f7a2 (patch)
treedac2707245c7a900f337237283a72251081b460f /tensorflow/python/layers
parent91bf5524560c5bc0783b43717156c7dbb6f798f5 (diff)
Creates a LinearModel (inherits from keras.training.Model) that creates a linear
model. Had to modify the __call__ method in the base layer class so that it could work with feature style inputs in which case we lazily convert the inputs to tensors instead of providing tensors as inputs upfront. PiperOrigin-RevId: 191655445
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/base.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 242cdff6f3..ec741d3265 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -694,7 +694,8 @@ class Layer(checkpointable.CheckpointableBase):
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+ if all(hasattr(x, 'get_shape') for x in input_list):
+ input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
self.build(input_shapes)
try:
# Note: not all sub-classes of Layer call Layer.__init__ (especially