aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/keras_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/keras_test.py')
-rw-r--r--tensorflow/python/estimator/keras_test.py174
1 files changed, 157 insertions, 17 deletions
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 5e094ae92b..cf4ec7f4da 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -32,13 +32,14 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.optimizers import SGD
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import rmsprop
+from tensorflow.python.training import session_run_hook
try:
@@ -51,6 +52,8 @@ _TRAIN_SIZE = 200
_INPUT_SIZE = (10,)
_NUM_CLASS = 2
+_TMP_DIR = '/tmp'
+
def simple_sequential_model():
model = keras.models.Sequential()
@@ -60,9 +63,9 @@ def simple_sequential_model():
return model
-def simple_functional_model():
+def simple_functional_model(activation='relu'):
a = keras.layers.Input(shape=_INPUT_SIZE)
- b = keras.layers.Dense(16, activation='relu')(a)
+ b = keras.layers.Dense(16, activation=activation)(a)
b = keras.layers.Dropout(0.1)(b)
b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)
model = keras.models.Model(inputs=[a], outputs=[b])
@@ -168,6 +171,12 @@ def multi_inputs_multi_outputs_model():
return model
+class MyHook(session_run_hook.SessionRunHook):
+
+ def begin(self):
+ _ = variable_scope.get_variable('temp', [1])
+
+
class TestKerasEstimator(test_util.TensorFlowTestCase):
def setUp(self):
@@ -204,6 +213,55 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ # see b/109935364
+ @test_util.run_in_graph_and_eager_modes
+ def test_train_with_hooks(self):
+ for model_type in ['sequential', 'functional']:
+ keras_model, (_, _), (
+ _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type=model_type, is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ my_hook = MyHook()
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ before_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, hooks=[my_hook],
+ steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ writer_cache.FileWriterCache.clear()
+ gfile.DeleteRecursively(self._config.model_dir)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_train_with_model_fit_and_hooks(self):
+ keras_model, (x_train, y_train), _, \
+ train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ my_hook = MyHook()
+ with self.test_session():
+ keras_model.fit(x_train, y_train, epochs=1)
+
+ keras_est = keras_lib.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ before_eval_results = keras_est.evaluate(input_fn=eval_input_fn)
+ keras_est.train(input_fn=train_input_fn, hooks=[my_hook],
+ steps=_TRAIN_SIZE / 16)
+ after_eval_results = keras_est.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ @test_util.run_in_graph_and_eager_modes
def test_train_with_tf_optimizer(self):
for model_type in ['sequential', 'functional']:
keras_model, (_, _), (
@@ -231,6 +289,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ @test_util.run_in_graph_and_eager_modes
def test_train_with_subclassed_model(self):
keras_model, (_, _), (
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
@@ -472,23 +531,43 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
def test_custom_objects(self):
- keras_mobile = mobilenet.MobileNet(weights=None)
- keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ def relu6(x):
+ return keras.backend.relu(x, max_value=6)
+
+ keras_model = simple_functional_model(activation=relu6)
+ keras_model.compile(loss='categorical_crossentropy', optimizer='adam')
custom_objects = {
- 'relu6': mobilenet.relu6,
- 'DepthwiseConv2D': mobilenet.DepthwiseConv2D
+ 'relu6': relu6
}
+
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(10,),
+ num_classes=2)
+ y_train = keras.utils.to_categorical(y_train, 2)
+ input_name = keras_model.input_names[0]
+ output_name = keras_model.output_names[0]
+ train_input_fn = numpy_io.numpy_input_fn(
+ x=randomize_io_type(x_train, input_name),
+ y=randomize_io_type(y_train, output_name),
+ shuffle=False,
+ num_epochs=None,
+ batch_size=16)
with self.assertRaisesRegexp(ValueError, 'relu6'):
with self.test_session():
- keras_lib.model_to_estimator(
- keras_model=keras_mobile,
+ est = keras_lib.model_to_estimator(
+ keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
+ est.train(input_fn=train_input_fn, steps=1)
with self.test_session():
- keras_lib.model_to_estimator(
- keras_model=keras_mobile,
+ est = keras_lib.model_to_estimator(
+ keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
custom_objects=custom_objects)
+ est.train(input_fn=train_input_fn, steps=1)
def test_tf_config(self):
keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
@@ -525,12 +604,73 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
self._config._session_config = sess_config
- keras_lib.model_to_estimator(
- keras_model=keras_model, config=self._config)
- self.assertEqual(
- keras.backend.get_session()
- ._config.gpu_options.per_process_gpu_memory_fraction,
- gpu_options.per_process_gpu_memory_fraction)
+ with self.test_session():
+ keras_lib.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ self.assertEqual(
+ keras.backend.get_session()
+ ._config.gpu_options.per_process_gpu_memory_fraction,
+ gpu_options.per_process_gpu_memory_fraction)
+
+ def test_with_empty_config(self):
+ keras_model, _, _, _, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, model_dir=self._base_dir,
+ config=run_config_lib.RunConfig())
+ self.assertEqual(run_config_lib.get_default_session_config(),
+ est_keras._session_config)
+ self.assertEqual(est_keras._session_config,
+ est_keras._config.session_config)
+ self.assertEqual(self._base_dir, est_keras._config.model_dir)
+ self.assertEqual(self._base_dir, est_keras._model_dir)
+
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, model_dir=self._base_dir,
+ config=None)
+ self.assertEqual(run_config_lib.get_default_session_config(),
+ est_keras._session_config)
+ self.assertEqual(est_keras._session_config,
+ est_keras._config.session_config)
+ self.assertEqual(self._base_dir, est_keras._config.model_dir)
+ self.assertEqual(self._base_dir, est_keras._model_dir)
+
+ def test_with_empty_config_and_empty_model_dir(self):
+ keras_model, _, _, _, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model,
+ config=run_config_lib.RunConfig())
+ self.assertEqual(est_keras._model_dir, _TMP_DIR)
+
+ def test_with_conflicting_model_dir_and_config(self):
+ keras_model, _, _, _, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
+ 'constructor and `RunConfig`'):
+ keras_lib.model_to_estimator(
+ keras_model=keras_model, model_dir=self._base_dir,
+ config=run_config_lib.RunConfig(model_dir=_TMP_DIR))
def test_pretrained_weights(self):
keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()