aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/layer_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/layer_utils.py')
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index fdcf963d32..d65b631fe9 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -30,6 +30,19 @@ def is_layer(obj):
and hasattr(obj, "variables"))
+def filter_empty_layer_containers(layer_list):
+ """Filter out empty Layer-like containers."""
+ filtered = []
+ for obj in layer_list:
+ if is_layer(obj):
+ filtered.append(obj)
+ else:
+ # Checkpointable data structures will not show up in ".layers" lists, but
+ # the layers they contain will.
+ filtered.extend(obj.layers)
+ return filtered
+
+
def gather_trainable_weights(trainable, sub_layers, extra_variables):
"""Lists the trainable weights for an object with sub-layers.