path: root/tensorflow/python/estimator/keras.py
Diffstat (limited to 'tensorflow/python/estimator/keras.py')
-rw-r--r--  tensorflow/python/estimator/keras.py  153
1 file changed, 106 insertions(+), 47 deletions(-)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 5769f5739c..70517ae278 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,11 +21,14 @@ from __future__ import print_function
import os
import re
+import tempfile
+
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -39,12 +42,14 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
-from tensorflow.python.ops import variables as variables_module
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -69,16 +74,22 @@ def _convert_tensor(x):
return x
-def _any_variable_initialized():
- """Check if any variable has been initialized in the Keras model.
+def _any_weight_initialized(keras_model):
+ """Check if any weights has been initialized in the Keras model.
+
+ Args:
+ keras_model: An instance of a compiled keras model.
Returns:
- boolean, True if at least one variable has been initialized, else False.
+ boolean, True if at least one weight has been initialized, else False.
+ Currently keras initializes all weights at get_session().
"""
- variables = variables_module.global_variables()
- for v in variables:
- if getattr(v, '_keras_initialized', False):
- return True
+ if keras_model is None:
+ return False
+ for layer in keras_model.layers:
+ for weight in layer.weights:
+ if hasattr(weight, '_keras_initialized'):
+ return True
return False
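
The new check only looks for the `_keras_initialized` marker that Keras sets on each weight once `get_session()` has run. A minimal self-contained sketch of that behavior, using stand-in classes rather than a real Keras model (all names below are illustrative only):

class _FakeWeight(object):
  """Stand-in for a Keras variable; gains `_keras_initialized` once touched."""


class _FakeLayer(object):

  def __init__(self, weights):
    self.weights = weights


class _FakeModel(object):

  def __init__(self, layers):
    self.layers = layers


def any_weight_initialized(keras_model):
  # Same scan as the helper above: look for the marker on every layer weight.
  if keras_model is None:
    return False
  return any(hasattr(w, '_keras_initialized')
             for layer in keras_model.layers for w in layer.weights)


w = _FakeWeight()
model = _FakeModel([_FakeLayer([w])])
assert not any_weight_initialized(model)  # no session created yet
w._keras_initialized = True               # what Keras sets in get_session()
assert any_weight_initialized(model)
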
@@ -173,7 +184,7 @@ def _in_place_subclassed_model_reset(model):
# Replace layers on the model with fresh layers
layers_to_names = {value: key for key, value in attributes_cache.items()}
original_layers = model._layers[:]
- model._layers = []
+ model._layers = data_structures.NoDependency([])
for layer in original_layers: # We preserve layer order.
config = layer.get_config()
# This will not work for nested subclassed models used as layers.
@@ -221,7 +232,8 @@ def _in_place_subclassed_model_reset(model):
]
for name in attributes_to_cache:
attributes_cache[name] = getattr(model, name)
- model._original_attributes_cache = attributes_cache
+ model._original_attributes_cache = data_structures.NoDependency(
+ attributes_cache)
# Reset built state
model.built = False
model.inputs = None
@@ -241,8 +253,17 @@ def _in_place_subclassed_model_state_restoration(model):
# Restore layers and build attributes
if (hasattr(model, '_original_attributes_cache') and
model._original_attributes_cache is not None):
- model._layers = []
+ # Models have sticky attribute assignment, so we want to be careful to add
+ # back the previous attributes and track Layers by their original names
+ # without adding dependencies on "utility" attributes which Models exempt
+ # when they're constructed.
+ model._layers = data_structures.NoDependency([])
for name, value in model._original_attributes_cache.items():
+ if not isinstance(value, checkpointable.CheckpointableBase):
+ # If this value is not already checkpointable, it's probably that way
+ # for a reason; we don't want to start tracking data structures that the
+ # original Model didn't.
+ value = data_structures.NoDependency(value)
setattr(model, name, value)
model._original_attributes_cache = None
else:
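
Both hunks above wrap temporary bookkeeping in `data_structures.NoDependency` before assigning it onto the model. A brief sketch of why, assuming the internal TF module already imported at the top of this file: a Model's attribute assignment normally starts tracking lists and dicts for checkpointing, and the wrapper marks a value as exempt from that tracking.

from tensorflow.python.training.checkpointable import data_structures

# A plain list assigned to a Model attribute would normally be converted into
# a tracked, checkpointable list. Wrapping it first opts the value out of that
# tracking, so this reset/restore bookkeeping never becomes a checkpoint
# dependency of the model.
untracked_layers = data_structures.NoDependency([])
# e.g. model._layers = untracked_layers  -> stored as a plain list, with no
# dependency recorded on it.
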
@@ -410,29 +431,34 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
return model_fn
-def _save_first_checkpoint(keras_model, estimator, custom_objects,
- keras_weights):
+def _save_first_checkpoint(keras_model, custom_objects, config):
"""Save first checkpoint for the keras Estimator.
Args:
keras_model: an instance of compiled keras model.
- estimator: keras estimator.
custom_objects: Dictionary for custom objects.
- keras_weights: A flat list of Numpy arrays for weights of given keras_model.
+ config: Estimator config.
Returns:
- The model_fn for a keras Estimator.
+ The path where the keras model checkpoint is saved.
"""
+ # save checkpoint into subdirectory to allow warm start
+ keras_model_dir = os.path.join(config.model_dir, 'keras')
# Load weights and save to checkpoint if there is no checkpoint
- latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
+ latest_path = saver_lib.latest_checkpoint(keras_model_dir)
if not latest_path:
+ keras_weights = None
+ if _any_weight_initialized(keras_model):
+ keras_weights = keras_model.get_weights()
+ if not gfile.IsDirectory(keras_model_dir):
+ gfile.MakeDirs(keras_model_dir)
with ops.Graph().as_default():
- random_seed.set_random_seed(estimator.config.tf_random_seed)
+ random_seed.set_random_seed(config.tf_random_seed)
training_util.create_global_step()
model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
custom_objects)
# save to checkpoint
- with session.Session(config=estimator._session_config) as sess:
+ with session.Session(config=config.session_config) as sess:
if keras_weights:
model.set_weights(keras_weights)
# Make update ops and initialize all variables.
@@ -442,7 +468,46 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
K._initialize_variables(sess)
# pylint: enable=protected-access
saver = saver_lib.Saver()
- saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt'))
+ latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
+ saver.save(sess, latest_path)
+ return latest_path
+
+
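
The net effect of `_save_first_checkpoint` is a checkpoint written under a `keras/` subdirectory of the estimator's `model_dir`, whose path is returned for warm starting. A small sketch of the path construction (the `model_dir` value here is hypothetical):

import os

model_dir = '/tmp/keras_estimator'  # hypothetical estimator model_dir
keras_model_dir = os.path.join(model_dir, 'keras')
latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
print(latest_path)  # /tmp/keras_estimator/keras/keras_model.ckpt
# The returned path is later handed to the Estimator as `warm_start_from`,
# keeping the warm-start source separate from the Estimator's own checkpoints
# directly under model_dir.
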
+def _maybe_overwrite_model_dir_and_session_config(config, model_dir):
+ """Overwrite estimator config by `model_dir` and `session_config` if needed.
+
+ Args:
+ config: Original estimator config.
+ model_dir: Estimator model checkpoint directory.
+
+ Returns:
+ Overwritten estimator config.
+
+ Raises:
+ ValueError: Model directory inconsistent between `model_dir` and `config`.
+ """
+
+ default_session_config = run_config_lib.get_default_session_config()
+ if isinstance(config, dict):
+ config = RunConfig(**config)
+ elif config is None:
+ config = RunConfig(session_config=default_session_config)
+ if config.session_config is None:
+ config = RunConfig.replace(config, session_config=default_session_config)
+
+ if model_dir is not None:
+ if (getattr(config, 'model_dir', None) is not None and
+ config.model_dir != model_dir):
+ raise ValueError(
+ "`model_dir` are set both in constructor and `RunConfig`, but with "
+ "different values. In constructor: '{}', in `RunConfig`: "
+ "'{}' ".format(model_dir, config.model_dir))
+ config = RunConfig.replace(config, model_dir=model_dir)
+ elif getattr(config, 'model_dir', None) is None:
+ model_dir = tempfile.mkdtemp()
+ config = RunConfig.replace(config, model_dir=model_dir)
+
+ return config
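
The precedence rules in `_maybe_overwrite_model_dir_and_session_config` can be hard to read from the diff alone. A pure-Python mimic of the `model_dir` resolution only (the function and names below are illustrative, not the TF implementation):

import tempfile


def resolve_model_dir(config_model_dir, model_dir):
  # The model_dir passed to model_to_estimator() wins, but conflicting
  # explicit values are an error; with neither, fall back to a temp directory.
  if model_dir is not None:
    if config_model_dir is not None and config_model_dir != model_dir:
      raise ValueError('`model_dir` is set both in constructor and '
                       '`RunConfig`, but with different values.')
    return model_dir
  if config_model_dir is None:
    return tempfile.mkdtemp()
  return config_model_dir


assert resolve_model_dir(None, '/tmp/a') == '/tmp/a'
assert resolve_model_dir('/tmp/a', None) == '/tmp/a'
# resolve_model_dir('/tmp/a', '/tmp/b') raises ValueError
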
def model_to_estimator(keras_model=None,
@@ -501,45 +566,39 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
- if isinstance(config, dict):
- config = run_config_lib.RunConfig(**config)
+ config = _maybe_overwrite_model_dir_and_session_config(config, model_dir)
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
- estimator = estimator_lib.Estimator(
- keras_model_fn, model_dir=model_dir, config=config)
-
- # Check if we need to call get_weights:
- if _any_variable_initialized():
- keras_weights = keras_model.get_weights()
+ if _any_weight_initialized(keras_model):
# Warn if config passed to estimator tries to update GPUOptions. If a
# session has already been created, the GPUOptions passed to the first
# session sticks.
- if estimator._session_config.HasField('gpu_options'):
+ if config.session_config.HasField('gpu_options'):
logging.warning(
'The Keras backend session has already been set. '
'The _session_config passed to model_to_estimator will not be used.')
else:
# Pass the config into keras backend's default session.
- sess = session.Session(config=estimator._session_config)
+ sess = session.Session(config=config.session_config)
K.set_session(sess)
- keras_weights = None
+ warm_start_path = None
if keras_model._is_graph_network:
- # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
- _save_first_checkpoint(keras_model,
- estimator,
- custom_objects,
- keras_weights)
+ warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
+ config)
elif keras_model.built:
- logging.warning('You are creating an Estimator from a Keras model '
- 'manually subclassed from `Model`, that was '
- 'already called on some inputs (and thus already had '
- 'weights). We are currently unable to preserve '
- 'the model\'s state (its weights) '
- 'as part of the estimator '
- 'in this case. Be warned that the estimator '
- 'has been created using '
- 'a freshly initialized version of your model.\n'
- 'Note that this doesn\'t affect the state of the '
- 'model instance you passed as `keras_model` argument.')
+ logging.warning('You are creating an Estimator from a Keras model manually '
+ 'subclassed from `Model`, that was already called on some '
+ 'inputs (and thus already had weights). We are currently '
+ 'unable to preserve the model\'s state (its weights) as '
+ 'part of the estimator in this case. Be warned that the '
+ 'estimator has been created using a freshly initialized '
+ 'version of your model.\n'
+ 'Note that this doesn\'t affect the state of the model '
+ 'instance you passed as `keras_model` argument.')
+
+ estimator = estimator_lib.Estimator(keras_model_fn,
+ config=config,
+ warm_start_from=warm_start_path)
+
return estimator
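
For context, a hedged end-to-end usage sketch with the public TF 1.x API; after this change the returned Estimator warm-starts from the checkpoint saved under `<model_dir>/keras` instead of having the Keras weights copied into it directly (the directory below is a placeholder):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy')

estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir='/tmp/keras_estimator')  # placeholder dir
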