aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-09-06 16:08:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 16:12:42 -0700
commit33d2a0e7064cd14540121e38457d4a81aa57a650 (patch)
tree9fd2f1fa98d215926f12c39dd2a1c5cc01c73dec /tensorflow/python/estimator
parent75bc3006b890bfc9c58a05097a7bce10bb30c17e (diff)
Fix bug that prevented iterations variable from updating when training an Estimator that is created from a Keras model.
PiperOrigin-RevId: 211886643
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/keras_test.py102
1 files changed, 70 insertions, 32 deletions
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 290c4604ce..7e5a0c80a7 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -26,20 +26,23 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import keras as keras_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
-from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import rmsprop
from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
try:
@@ -90,6 +93,15 @@ def simple_subclassed_model():
return SimpleModel()
+def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):
+ def input_fn():
+ ds = dataset_ops.Dataset.from_tensor_slices((x, y) if y is not None else x)
+ if shuffle:
+ ds = ds.shuffle(1000)
+ return ds.repeat(num_epochs).batch(batch_size)
+ return input_fn
+
+
def get_resource_for_simple_model(model_type='sequential',
is_evaluate=False,):
if model_type == 'sequential':
@@ -117,19 +129,19 @@ def get_resource_for_simple_model(model_type='sequential',
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
- train_input_fn = numpy_io.numpy_input_fn(
+ train_input_fn = gen_input_fn(
x=randomize_io_type(x_train, input_name),
y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
- evaluate_input_fn = numpy_io.numpy_input_fn(
+ evaluate_input_fn = gen_input_fn(
x=randomize_io_type(x_test, input_name),
y=randomize_io_type(y_test, output_name),
num_epochs=1, shuffle=False)
- predict_input_fn = numpy_io.numpy_input_fn(
+ predict_input_fn = gen_input_fn(
x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)
inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn
@@ -203,7 +215,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
@@ -228,7 +240,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
metrics=['mse', keras.metrics.categorical_accuracy])
my_hook = MyHook()
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
@@ -252,7 +264,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
my_hook = MyHook()
- with self.test_session():
+ with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
keras_est = keras_lib.model_to_estimator(
@@ -274,7 +286,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
config=self._config)
@@ -297,7 +309,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
@@ -316,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
# Create state
keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
np.random.random((10, _NUM_CLASS)))
@@ -343,7 +355,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
model_type='functional', is_evaluate=True)
- with self.test_session():
+ with self.cached_session():
metrics = [
'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy',
'categorical_crossentropy', 'cosine_proximity', 'hinge',
@@ -357,7 +369,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.fit(x_train, y_train, epochs=1)
keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_eval = keras_est.evaluate(input_fn=eval_input_fn)
@@ -385,7 +397,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
model_type='sequential', is_evaluate=False)
- with self.test_session():
+ with self.cached_session():
keras_model.compile(
loss='categorical_crossentropy',
optimizer='adam',
@@ -393,7 +405,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.fit(x_train, y_train, epochs=1)
keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_pred = [
@@ -439,7 +451,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
output_dict = {'dense_2': c_test, 'dense_3': d_test}
return input_dict, output_dict
- with self.test_session():
+ with self.cached_session():
model = multi_inputs_multi_outputs_model()
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
@@ -456,7 +468,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
model_type='functional', is_evaluate=False)
- with self.test_session():
+ with self.cached_session():
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -466,7 +478,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
fname = os.path.join(self._base_dir, 'keras_model.h5')
keras.models.save_model(keras_model, fname)
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model_path=fname, config=self._config)
est_pred = [
@@ -479,19 +491,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'Either'):
keras_lib.model_to_estimator()
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not both'):
keras_lib.model_to_estimator(
keras_model=keras_model,
keras_model_path=tempfile.mkdtemp(dir=self._base_dir))
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not a local path'):
keras_lib.model_to_estimator(
@@ -516,10 +528,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
model = simple_functional_model()
model.compile(
loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError,
'Difference: .*invalid_input_name'):
est_keras.train(input_fn=invald_input_name_input_fn, steps=100)
@@ -547,20 +559,20 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
y_train = keras.utils.to_categorical(y_train, 2)
input_name = keras_model.input_names[0]
output_name = keras_model.output_names[0]
- train_input_fn = numpy_io.numpy_input_fn(
+ train_input_fn = gen_input_fn(
x=randomize_io_type(x_train, input_name),
y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
with self.assertRaisesRegexp(ValueError, 'relu6'):
- with self.test_session():
+ with self.cached_session():
est = keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
est.train(input_fn=train_input_fn, steps=1)
- with self.test_session():
+ with self.cached_session():
est = keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
@@ -586,7 +598,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
}
})
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
- with self.test_session():
+ with self.cached_session():
keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
@@ -602,7 +614,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
self._config._session_config = sess_config
- with self.test_session():
+ with self.cached_session():
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
self.assertEqual(
@@ -618,7 +630,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, model_dir=self._base_dir,
config=run_config_lib.RunConfig())
@@ -629,7 +641,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
self.assertEqual(self._base_dir, est_keras._config.model_dir)
self.assertEqual(self._base_dir, est_keras._model_dir)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, model_dir=self._base_dir,
config=None)
@@ -648,7 +660,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
@@ -663,7 +675,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
'constructor and `RunConfig`'):
keras_lib.model_to_estimator(
@@ -676,7 +688,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
np.random.random((10, _NUM_CLASS)))
@@ -690,6 +702,32 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
+ def assert_increasing_global_step(self, optimizer):
+ keras_model, _, _, train_input_fn, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ with self.cached_session() as sess:
+ keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
+ global_step = training_util.create_global_step()
+ features, labels = train_input_fn().make_one_shot_iterator().get_next()
+ spec = keras_model_fn(features, labels, mode=model_fn_lib.ModeKeys.TRAIN)
+
+ sess.run(variables.global_variables_initializer())
+ sess.run(variables.local_variables_initializer())
+
+ self.assertEqual(global_step.eval(), 0) # Sanity check
+ sess.run(spec.train_op)
+ self.assertEqual(global_step.eval(), 1)
+
+ def test_model_fn_increments_global_step_tf_optimizer(self):
+ self.assert_increasing_global_step(rmsprop.RMSPropOptimizer(1e-3))
+
+ def test_model_fn_increments_global_step_keras_optimizer(self):
+ self.assert_increasing_global_step('rmsprop')
+
if __name__ == '__main__':
test.main()