author     Francois Chollet <fchollet@google.com>          2018-03-05 18:49:53 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-03-05 18:53:40 -0800
commit     73999dc944b3516d485081fe060d6916c089e412 (patch)
tree       a77350a24ccc5f5e95d0b2469bce37588b5c2130 /tensorflow/python
parent     b5f943201afc06525818f45da28f82559fceced2 (diff)
Fixes a number of usability issues with model_to_estimator, in particular:
- make it possible to use a model that was compiled with a TF optimizer (a Keras optimizer is no longer required)
- do not require input to be a dict (input_fn supports plain arrays)
- do not require `config` to be a RunConfig instance; it can now be a dict (better UX)
- make it possible to use a subclassed model (caveat: weights are not preserved, yet)
- clear error message when the model isn't compiled; improve various error messages

PiperOrigin-RevId: 187959927
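An illustrative usage sketch of what the points above enable (a sketch, not the commit's test code; the data shapes, model_dir path, and 3-class problem are made up for the example):

    import numpy as np
    import tensorflow as tf

    # A small Keras model compiled with a native TF optimizer -- previously
    # model_to_estimator required a Keras optimizer.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy',
                  optimizer=tf.train.RMSPropOptimizer(1e-3),
                  metrics=['accuracy'])

    # `config` can now be a plain dict instead of a RunConfig instance.
    est = tf.keras.estimator.model_to_estimator(
        keras_model=model,
        config={'tf_random_seed': 42, 'model_dir': '/tmp/keras_est_example'})

    # input_fn features/labels may now be plain arrays; a dict keyed by the
    # model's input/output names is no longer required.
    x = np.random.random((100, 10)).astype(np.float32)
    y = tf.keras.utils.to_categorical(np.random.randint(3, size=(100,)), 3)
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x=x, y=y, shuffle=True, num_epochs=None, batch_size=16)
    est.train(input_fn=train_input_fn, steps=10)
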
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/keras/_impl/keras/estimator.py      | 291
-rw-r--r--  tensorflow/python/keras/_impl/keras/estimator_test.py | 146
-rw-r--r--  tensorflow/python/layers/base.py                      |   5
3 files changed, 374 insertions, 68 deletions
diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py
index 5697771a79..081f25e914 100644
--- a/tensorflow/python/keras/_impl/keras/estimator.py
+++ b/tensorflow/python/keras/_impl/keras/estimator.py
@@ -25,11 +25,15 @@ from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import models
+from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
+from tensorflow.python.keras._impl.keras.engine.network import Network
from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
@@ -50,36 +54,174 @@ def _cast_tensor_to_floatx(x):
return math_ops.cast(x, K.floatx())
-def _create_ordered_io(keras_model, estimator_io_dict, is_input=True):
+def _create_ordered_io(keras_model, estimator_io, is_input=True):
"""Create a list of tensors from IO dictionary based on Keras IO order.
Args:
- keras_model: an instance of compiled keras model.
- estimator_io_dict: features or labels dictionary from model_fn.
+ keras_model: An instance of compiled keras model.
+ estimator_io: The features or labels (dict or plain array) from model_fn.
is_input: True if dictionary is for inputs.
Returns:
- a list of tensors based on Keras IO order.
+ A list of tensors based on Keras IO order.
Raises:
ValueError: if dictionary keys cannot be found in Keras model input_names
or output_names.
"""
- if is_input:
- keras_io_names = keras_model.input_names
+ if isinstance(estimator_io, (list, tuple)):
+ # Case currently not supported by most built-in input_fn,
+ # but it's good to have for sanity
+ return [_cast_tensor_to_floatx(x) for x in estimator_io]
+ elif isinstance(estimator_io, dict):
+ if is_input:
+ if keras_model._is_graph_network:
+ keras_io_names = keras_model.input_names
+ else:
+ keras_io_names = [
+ 'input_%d' % i for i in range(1, len(estimator_io) + 1)]
+ else:
+ if keras_model._is_graph_network:
+ keras_io_names = keras_model.output_names
+ else:
+ keras_io_names = [
+ 'output_%d' % i for i in range(1, len(estimator_io) + 1)]
+
+ for key in estimator_io:
+ if key not in keras_io_names:
+ raise ValueError(
+ 'Cannot find %s with name "%s" in Keras Model. '
+ 'It needs to match one '
+ 'of the following: %s' % ('input' if is_input else 'output', key,
+ ', '.join(keras_io_names)))
+ tensors = [_cast_tensor_to_floatx(estimator_io[io_name])
+ for io_name in keras_io_names]
+ return tensors
else:
- keras_io_names = keras_model.output_names
+ # Plain array.
+ return _cast_tensor_to_floatx(estimator_io)
- for key in estimator_io_dict:
- if key not in keras_io_names:
- raise ValueError(
- 'Cannot find %s with name "%s" in Keras Model. It needs to match '
- 'one of the following: %s' % ('input' if is_input else 'output', key,
- ', '.join(keras_io_names)))
- tensors = []
- for io_name in keras_io_names:
- tensors.append(_cast_tensor_to_floatx(estimator_io_dict[io_name]))
- return tensors
+
+def _in_place_subclassed_model_reset(model):
+ """Substitute for model cloning that works for subclassed models.
+
+ Subclassed models cannot be cloned because their topology is not serializable.
+ To "instantiate" an identical model in a new TF graph, we reuse the original
+ model object, but we clear its state.
+
+ After calling this function on a model instance, you can use the model instance
+ as if it were a model clone (in particular you can use it in a new graph).
+
+ This method clears the state of the input model. It is thus destructive.
+ However the original state can be restored fully by calling
+ `_in_place_subclassed_model_state_restoration`.
+
+ Args:
+ model: Instance of a Keras model created via subclassing.
+
+ Raises:
+ ValueError: In case the model uses a subclassed model as inner layer.
+ """
+ assert not model._is_graph_network # Only makes sense for subclassed networks
+ # Retrieve all layers tracked by the model as well as their attribute names
+ attributes_cache = {}
+ for name in dir(model):
+ try:
+ value = getattr(model, name)
+ except (AttributeError, ValueError, TypeError):
+ continue
+ if isinstance(value, Layer):
+ attributes_cache[name] = value
+ assert value in model._layers
+ elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ # Handle case: list/tuple of layers (also tracked by the Network API).
+ if value and all(isinstance(val, Layer) for val in value):
+ raise ValueError('We do not support the use of list-of-layers '
+ 'attributes in subclassed models used with '
+ '`model_to_estimator` at this time. Found list '
+ 'model: %s' % name)
+
+ # Replace layers on the model with fresh layers
+ layers_to_names = {value: key for key, value in attributes_cache.items()}
+ original_layers = model._layers[:]
+ model._layers = []
+ for layer in original_layers: # We preserve layer order.
+ config = layer.get_config()
+ # This will not work for nested subclassed models used as layers.
+ # This would be theoretically possible to support, but would add complexity.
+ # Only do it if users complain.
+ if isinstance(layer, Network) and not layer._is_graph_network:
+ raise ValueError('We do not support the use of nested subclassed models '
+ 'in `model_to_estimator` at this time. Found nested '
+ 'model: %s' % layer)
+ fresh_layer = layer.__class__.from_config(config)
+ name = layers_to_names[layer]
+ setattr(model, name, fresh_layer)
+
+ # Cache original model build attributes (in addition to layers)
+ if (not hasattr(model, '_original_attributes_cache') or
+ model._original_attributes_cache is None):
+ if model.built:
+ attributes_to_cache = [
+ 'inputs',
+ 'outputs',
+ '_feed_outputs',
+ '_feed_output_names',
+ '_feed_output_shapes',
+ '_feed_loss_fns',
+ 'loss_weights_list',
+ 'targets',
+ '_feed_targets',
+ 'sample_weight_modes',
+ 'weighted_metrics',
+ 'metrics_names',
+ 'metrics_tensors',
+ 'metrics_updates',
+ 'stateful_metric_names',
+ 'total_loss',
+ 'sample_weights',
+ '_feed_sample_weights',
+ 'train_function',
+ 'test_function',
+ 'predict_function',
+ '_collected_trainable_weights',
+ '_feed_inputs',
+ '_feed_input_names',
+ '_feed_input_shapes',
+ 'optimizer',
+ ]
+ for name in attributes_to_cache:
+ attributes_cache[name] = getattr(model, name)
+ model._original_attributes_cache = attributes_cache
+
+ # Reset built state
+ model.built = False
+ model.inputs = None
+ model.outputs = None
+
+
+def _in_place_subclassed_model_state_restoration(model):
+ """Restores the original state of a model after it was "reset".
+
+ This undoes the action of `_in_place_subclassed_model_reset`.
+
+ Args:
+ model: Instance of a Keras model created via subclassing, on which
+ `_in_place_subclassed_model_reset` was previously called.
+ """
+ assert not model._is_graph_network
+ # Restore layers and build attributes
+ if (hasattr(model, '_original_attributes_cache') and
+ model._original_attributes_cache is not None):
+ model._layers = []
+ for name, value in model._original_attributes_cache.items():
+ setattr(model, name, value)
+ model._original_attributes_cache = None
+ else:
+ # Restore to the state of a never-called model.
+ model.built = False
+ model.inputs = None
+ model.outputs = None
def _clone_and_build_model(mode,
@@ -93,8 +235,8 @@ def _clone_and_build_model(mode,
mode: training mode.
keras_model: an instance of compiled keras model.
custom_objects: Dictionary for custom objects.
- features:
- labels:
+ features: Dict of tensors.
+ labels: Dict of tensors, or single tensor instance.
Returns:
The newly built model.
@@ -102,33 +244,49 @@ def _clone_and_build_model(mode,
# Set to True during training, False for inference.
K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
- # Clone keras model.
- input_tensors = None if features is None else _create_ordered_io(
- keras_model, features)
- if custom_objects:
- with CustomObjectScope(custom_objects):
+ # Get list of inputs.
+ if features is None:
+ input_tensors = None
+ else:
+ input_tensors = _create_ordered_io(keras_model,
+ estimator_io=features,
+ is_input=True)
+ # Get list of outputs.
+ if labels is None:
+ target_tensors = None
+ elif isinstance(labels, dict):
+ target_tensors = _create_ordered_io(keras_model,
+ estimator_io=labels,
+ is_input=False)
+ else:
+ target_tensors = [
+ _cast_tensor_to_floatx(
+ sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
+ ]
+
+ if keras_model._is_graph_network:
+ if custom_objects:
+ with CustomObjectScope(custom_objects):
+ model = models.clone_model(keras_model, input_tensors=input_tensors)
+ else:
model = models.clone_model(keras_model, input_tensors=input_tensors)
else:
- model = models.clone_model(keras_model, input_tensors=input_tensors)
+ model = keras_model
+ _in_place_subclassed_model_reset(model)
+ if input_tensors is not None:
+ model._set_inputs(input_tensors)
# Compile/Build model
- if mode is model_fn_lib.ModeKeys.PREDICT and not model.built:
- model.build()
+ if mode is model_fn_lib.ModeKeys.PREDICT:
+ if isinstance(model, models.Sequential):
+ model.build()
else:
- optimizer_config = keras_model.optimizer.get_config()
- optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
- optimizer.iterations = training_util.get_or_create_global_step()
-
- # Get list of outputs.
- if labels is None:
- target_tensors = None
- elif isinstance(labels, dict):
- target_tensors = _create_ordered_io(keras_model, labels, is_input=False)
+ if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
+ optimizer = keras_model.optimizer
else:
- target_tensors = [
- _cast_tensor_to_floatx(
- sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
- ]
+ optimizer_config = keras_model.optimizer.get_config()
+ optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
+ optimizer.iterations = training_util.get_or_create_global_step()
model.compile(
optimizer,
@@ -168,10 +326,14 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
# Set loss and metric only during train and evaluate.
if mode is not model_fn_lib.ModeKeys.PREDICT:
- model._make_train_function() # pylint: disable=protected-access
+ if mode is model_fn_lib.ModeKeys.TRAIN:
+ model._make_train_function() # pylint: disable=protected-access
+ else:
+ model._make_test_function() # pylint: disable=protected-access
loss = model.total_loss
if model.metrics:
+ # TODO(fchollet): support stateful metrics
eval_metric_ops = {}
# When each metric maps to an output
if isinstance(model.metrics, dict):
@@ -195,6 +357,10 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
if mode is model_fn_lib.ModeKeys.TRAIN:
train_op = model.train_function.updates_op
+ if not model._is_graph_network:
+ # Reset model state to original state,
+ # to avoid `model_fn` being destructive for the initial model argument.
+ _in_place_subclassed_model_state_restoration(keras_model)
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=predictions,
@@ -274,10 +440,11 @@ def model_to_estimator(keras_model=None,
"""
if (not keras_model) and (not keras_model_path):
raise ValueError(
- 'Either keras_model or keras_model_path needs to be provided.')
+ 'Either `keras_model` or `keras_model_path` needs to be provided.')
if keras_model and keras_model_path:
raise ValueError(
- 'Please specity either keras_model or keras_model_path but not both.')
+ 'Please specify either `keras_model` or `keras_model_path`, '
+ 'but not both.')
if not keras_model:
if keras_model_path.startswith(
@@ -288,22 +455,42 @@ def model_to_estimator(keras_model=None,
logging.info('Loading models from %s', keras_model_path)
keras_model = models.load_model(keras_model_path)
else:
- logging.info('Using the Keras model from memory.')
+ logging.info('Using the Keras model provided.')
keras_model = keras_model
- if not hasattr(keras_model, 'optimizer'):
+ if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
raise ValueError(
- 'Given keras model has not been compiled yet. Please compile first '
- 'before creating the estimator.')
+ 'The given keras model has not been compiled yet. Please compile first '
+ 'before calling `model_to_estimator`.')
+
+ if isinstance(config, dict):
+ config = run_config_lib.RunConfig(**config)
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
- est = estimator_lib.Estimator(
+ estimator = estimator_lib.Estimator(
keras_model_fn, model_dir=model_dir, config=config)
+
# Pass the config into keras backend's default session.
- with session.Session(config=est._session_config) as sess:
+ with session.Session(config=estimator._session_config) as sess:
K.set_session(sess)
keras_weights = keras_model.get_weights()
- # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
- _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
- return est
+ if keras_model._is_graph_network:
+ # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
+ _save_first_checkpoint(keras_model,
+ estimator,
+ custom_objects,
+ keras_weights)
+ elif keras_model.built:
+ logging.warning('You are creating an Estimator from a Keras model '
+ 'manually subclassed from `Model`, that was '
+ 'already called on some inputs (and thus already had '
+ 'weights). We are currently unable to preserve '
+ 'the model\'s state (its weights) '
+ 'as part of the estimator '
+ 'in this case. Be warned that the estimator '
+ 'has been created using '
+ 'a freshly initialized version of your model.\n'
+ 'Note that this doesn\'t affect the state of the '
+ 'model instance you passed as `keras_model` argument.')
+ return estimator
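
For reference, a minimal sketch of the newly supported subclassed-model path (illustrative only; the class name and layer sizes are hypothetical). Per the warning added above, if the model has already been built, its weights are not carried over into the estimator:

    import tensorflow as tf

    class TinyModel(tf.keras.Model):
      """A minimal subclassed model used only for illustration."""

      def __init__(self):
        super(TinyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(16, activation='relu')
        self.dense2 = tf.keras.layers.Dense(3, activation='softmax')

      def call(self, inputs):
        return self.dense2(self.dense1(inputs))

    model = TinyModel()
    model.compile(loss='categorical_crossentropy',
                  optimizer=tf.train.RMSPropOptimizer(1e-3),
                  metrics=['accuracy'])

    # Cloning is impossible for subclassed models (their topology is not
    # serializable), so model_to_estimator temporarily resets and later
    # restores the model's state instead of cloning it.
    est = tf.keras.estimator.model_to_estimator(keras_model=model)
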
diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py
index a9de5dd076..e076dc25b1 100644
--- a/tensorflow/python/keras/_impl/keras/estimator_test.py
+++ b/tensorflow/python/keras/_impl/keras/estimator_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.keras._impl.keras.applications import mobilenet
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import rmsprop
try:
@@ -64,12 +65,42 @@ def simple_functional_model():
return model
-def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
- model = simple_sequential_model(
- ) if is_sequential else simple_functional_model()
- if is_sequential:
+def simple_subclassed_model():
+
+ class SimpleModel(keras.Model):
+
+ def __init__(self):
+ super(SimpleModel, self).__init__()
+ self.dense1 = keras.layers.Dense(16, activation='relu')
+ self.dp = keras.layers.Dropout(0.1)
+ self.dense2 = keras.layers.Dense(_NUM_CLASS, activation='softmax')
+
+ def call(self, inputs):
+ x = self.dense1(inputs)
+ x = self.dp(x)
+ return self.dense2(x)
+
+ return SimpleModel()
+
+
+def get_resource_for_simple_model(model_type='sequential',
+ is_evaluate=False,):
+ if model_type == 'sequential':
+ model = simple_sequential_model()
model.build()
- input_name = model.input_names[0]
+ elif model_type == 'subclass':
+ model = simple_subclassed_model()
+ else:
+ assert model_type == 'functional'
+ model = simple_functional_model()
+
+ if model_type == 'subclass':
+ input_name = 'input_1'
+ output_name = 'output_1'
+ else:
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+
np.random.seed(_RANDOM_SEED)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=_TRAIN_SIZE,
@@ -80,17 +111,19 @@ def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
y_test = keras.utils.to_categorical(y_test)
train_input_fn = numpy_io.numpy_input_fn(
- x={input_name: x_train},
- y=y_train,
+ x=randomize_io_type(x_train, input_name),
+ y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
evaluate_input_fn = numpy_io.numpy_input_fn(
- x={input_name: x_test}, y=y_test, num_epochs=1, shuffle=False)
+ x=randomize_io_type(x_test, input_name),
+ y=randomize_io_type(y_test, output_name),
+ num_epochs=1, shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
- x={input_name: x_test}, num_epochs=1, shuffle=False)
+ x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)
inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn
@@ -98,6 +131,14 @@ def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
y_test), train_input_fn, inference_input_fn
+def randomize_io_type(array, name):
+ switch = np.random.random()
+ if switch > 0.5:
+ return array
+ else:
+ return {name: array}
+
+
def multi_inputs_multi_outputs_model():
# test multi-input layer
a = keras.layers.Input(shape=(16,), name='input_a')
@@ -134,10 +175,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gfile.DeleteRecursively(self._base_dir)
def test_train(self):
- for is_sequential in [True, False]:
+ for model_type in ['sequential', 'functional']:
keras_model, (_, _), (
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
- is_sequential=is_sequential, is_evaluate=True)
+ model_type=model_type, is_evaluate=True)
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -155,10 +196,87 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ def test_train_with_tf_optimizer(self):
+ for model_type in ['sequential', 'functional']:
+ keras_model, (_, _), (
+ _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type=model_type, is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ est_keras = keras.estimator.model_to_estimator(
+ keras_model=keras_model,
+ # Also use dict config argument to get test coverage for that line.
+ config={
+ 'tf_random_seed': _RANDOM_SEED,
+ 'model_dir': self._base_dir,
+ })
+ before_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ writer_cache.FileWriterCache.clear()
+ gfile.DeleteRecursively(self._config.model_dir)
+
+ def test_train_with_subclassed_model(self):
+ keras_model, (_, _), (
+ _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type='subclass', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ est_keras = keras.estimator.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ before_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ def test_train_with_subclassed_model_with_existing_state(self):
+ keras_model, (_, _), (
+ _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type='subclass', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ # Create state
+ keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
+ np.random.random((10, _NUM_CLASS)))
+ original_preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))
+
+ est_keras = keras.estimator.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ before_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ # Check that original model state was not altered
+ preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))
+ self.assertAllClose(original_preds, preds, atol=1e-5)
+ # Check that the original model compilation did not break
+ keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
+ np.random.random((10, _NUM_CLASS)))
+
def test_evaluate(self):
keras_model, (x_train, y_train), (
x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
- is_sequential=False, is_evaluate=True)
+ model_type='functional', is_evaluate=True)
with self.test_session():
metrics = [
@@ -200,7 +318,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
# Check that predict on a pretrained model yield the same result.
keras_model, (x_train, y_train), (
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
- is_sequential=True, is_evaluate=False)
+ model_type='sequential', is_evaluate=False)
with self.test_session():
keras_model.compile(
@@ -262,7 +380,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model, (x_train, y_train), (
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
- is_sequential=False, is_evaluate=False)
+ model_type='functional', is_evaluate=False)
with self.test_session():
keras_model.compile(
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 2ec9971b88..c6d16a3bc0 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -127,7 +127,7 @@ class Layer(checkpointable.CheckpointableBase):
# return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._reuse = kwargs.get('_reuse')
- self._graph = ops.get_default_graph()
+ self._graph = None # Will be set at build time.
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
call_fn_args = estimator_util.fn_args(self.call)
self._compute_previous_mask = ('mask' in call_fn_args or
@@ -630,7 +630,8 @@ class Layer(checkpointable.CheckpointableBase):
# the same graph as where it was created.
if in_graph_mode:
try:
- ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access
+ # Set layer's "graph" at build time
+ self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
if in_graph_mode or in_deferred_mode:
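
The base.py change above defers a layer's graph binding from construction time to build time, so a layer (or subclassed model) instantiated under one default graph can later be built inside the Estimator's own graph. A rough sketch of the scenario this enables (illustrative, not taken from the commit):

    import tensorflow as tf

    layer = tf.layers.Dense(4)  # constructed under the outer default graph

    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, shape=(None, 8))
      # With the fix, the layer binds to `g` when it is built; previously its
      # graph was fixed in __init__ and this call could raise
      # "Input graph and Layer graph are not the same".
      y = layer(x)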