diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/layer_utils.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/layer_utils.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py index fdcf963d32..d65b631fe9 100644 --- a/tensorflow/python/training/checkpointable/layer_utils.py +++ b/tensorflow/python/training/checkpointable/layer_utils.py @@ -30,6 +30,19 @@ def is_layer(obj): and hasattr(obj, "variables")) +def filter_empty_layer_containers(layer_list): + """Filter out empty Layer-like containers.""" + filtered = [] + for obj in layer_list: + if is_layer(obj): + filtered.append(obj) + else: + # Checkpointable data structures will not show up in ".layers" lists, but + # the layers they contain will. + filtered.extend(obj.layers) + return filtered + + def gather_trainable_weights(trainable, sub_layers, extra_variables): """Lists the trainable weights for an object with sub-layers. |