aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/layer_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/layer_utils.py')
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index 978fcb2252..d65b631fe9 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -32,10 +32,15 @@ def is_layer(obj):
def filter_empty_layer_containers(layer_list):
"""Filter out empty Layer-like containers."""
- return [layer for layer in layer_list
- # Filter out only empty Checkpointable data structures. Empty Networks
- # will still show up in Model.layers.
- if is_layer(layer) or getattr(layer, "layers", True)]
+ filtered = []
+ for obj in layer_list:
+ if is_layer(obj):
+ filtered.append(obj)
+ else:
+ # Checkpointable data structures will not show up in ".layers" lists, but
+ # the layers they contain will.
+ filtered.extend(obj.layers)
+ return filtered
def gather_trainable_weights(trainable, sub_layers, extra_variables):