aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Zhenyu Tan <tanzheny@google.com>2018-10-01 15:52:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 16:02:01 -0700
commit6509437545f8fc973b39489c285811ea8cc8b15a (patch)
tree6059953902a30f69e9dc4ec9c8d17a697d2f2e9b /tensorflow/python/estimator
parent28a5ce4cf8702a6605e13a99c861ec6f2cd75929 (diff)
If keras_model_path is google storage url, provide util to download model
remotely. PiperOrigin-RevId: 215295504
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/keras.py48
-rw-r--r--tensorflow/python/estimator/keras_test.py6
2 files changed, 42 insertions, 12 deletions
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 7546771ed3..5d5ed81fbb 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -368,6 +368,44 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
return latest_path
+def _get_file_from_google_storage(keras_model_path, model_dir):
+ """Get file from google storage and download to local file.
+
+ Args:
+ keras_model_path: a google storage path for compiled keras model.
+ model_dir: the directory from estimator config.
+
+ Returns:
+ The path where keras model is saved.
+
+ Raises:
+ ValueError: if storage object name does not end with .h5.
+ """
+ try:
+ from google.cloud import storage # pylint:disable=g-import-not-at-top
+ except ImportError:
+ raise TypeError('Could not save model to Google cloud storage; please '
+ 'install `google-cloud-storage` via '
+ '`pip install google-cloud-storage`.')
+ storage_client = storage.Client()
+ path, blob_name = os.path.split(keras_model_path)
+ _, bucket_name = os.path.split(path)
+ keras_model_dir = os.path.join(model_dir, 'keras')
+ if not gfile.Exists(keras_model_dir):
+ gfile.MakeDirs(keras_model_dir)
+ file_name = os.path.join(keras_model_dir, 'keras_model.h5')
+ try:
+ blob = storage_client.get_bucket(bucket_name).blob(blob_name)
+ blob.download_to_filename(file_name)
+ except:
+ raise ValueError('Failed to download keras model, please check '
+ 'environment variable GOOGLE_APPLICATION_CREDENTIALS '
+ 'and model path storage.googleapis.com/{bucket}/{object}.')
+ logging.info('Saving model to {}'.format(file_name))
+ del storage_client
+ return file_name
+
+
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
@@ -407,12 +445,13 @@ def model_to_estimator(keras_model=None,
'Please specity either `keras_model` or `keras_model_path`, '
'but not both.')
+ config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
+ config, model_dir)
if not keras_model:
if keras_model_path.startswith(
'gs://') or 'storage.googleapis.com' in keras_model_path:
- raise ValueError(
- '%s is not a local path. Please copy the model locally first.' %
- keras_model_path)
+ keras_model_path = _get_file_from_google_storage(keras_model_path,
+ config.model_dir)
logging.info('Loading models from %s', keras_model_path)
keras_model = models.load_model(keras_model_path)
else:
@@ -425,9 +464,6 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
- config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
- model_dir)
-
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
if _any_weight_initialized(keras_model):
# Warn if config passed to estimator tries to update GPUOptions. If a
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 288f9b8906..4e285fa25a 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -581,12 +581,6 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
- with self.cached_session():
- keras_model = simple_sequential_model()
- with self.assertRaisesRegexp(ValueError, 'not a local path'):
- keras_lib.model_to_estimator(
- keras_model_path='gs://bucket/object')
-
def test_invalid_ionames_error(self):
(x_train, y_train), (_, _) = testing_utils.get_test_data(
train_samples=_TRAIN_SIZE,