From 6509437545f8fc973b39489c285811ea8cc8b15a Mon Sep 17 00:00:00 2001
From: Zhenyu Tan
Date: Mon, 1 Oct 2018 15:52:16 -0700
Subject: If keras_model_path is google storage url, provide util to download model remotely.

PiperOrigin-RevId: 215295504
---
 tensorflow/python/estimator/keras.py      | 57 +++++++++++++++++++++++++++----
 tensorflow/python/estimator/keras_test.py |  6 ----
 2 files changed, 51 insertions(+), 12 deletions(-)

diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 7546771ed3..5d5ed81fbb 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -368,6 +368,53 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
   return latest_path
 
 
+def _get_file_from_google_storage(keras_model_path, model_dir):
+  """Download a compiled keras model from google storage to a local file.
+
+  Args:
+    keras_model_path: a google storage path for compiled keras model, either
+      `gs://{bucket}/{object}` or `storage.googleapis.com/{bucket}/{object}`.
+    model_dir: the directory from estimator config.
+
+  Returns:
+    The local path where keras model is saved.
+
+  Raises:
+    TypeError: if `google-cloud-storage` is not installed.
+    ValueError: if the model could not be downloaded.
+  """
+  try:
+    from google.cloud import storage  # pylint:disable=g-import-not-at-top
+  except ImportError:
+    raise TypeError('Could not load model from Google cloud storage; please '
+                    'install `google-cloud-storage` via '
+                    '`pip install google-cloud-storage`.')
+  # Reduce the path to `{bucket}/{object}`. The object name may itself contain
+  # slashes, so only the first one separates it from the bucket.
+  if keras_model_path.startswith('gs://'):
+    bucket_and_blob = keras_model_path[len('gs://'):]
+  else:
+    bucket_and_blob = keras_model_path.split('storage.googleapis.com/', 1)[-1]
+  bucket_name, _, blob_name = bucket_and_blob.partition('/')
+  storage_client = storage.Client()
+  keras_model_dir = os.path.join(model_dir, 'keras')
+  if not gfile.Exists(keras_model_dir):
+    gfile.MakeDirs(keras_model_dir)
+  file_name = os.path.join(keras_model_dir, 'keras_model.h5')
+  try:
+    blob = storage_client.get_bucket(bucket_name).blob(blob_name)
+    blob.download_to_filename(file_name)
+  except Exception as e:  # pylint: disable=broad-except
+    # Surface the underlying error, but let KeyboardInterrupt/SystemExit
+    # propagate instead of being reported as a download failure.
+    raise ValueError('Failed to download keras model, please check '
+                     'environment variable GOOGLE_APPLICATION_CREDENTIALS '
+                     'and model path storage.googleapis.com/{bucket}/{object}.'
+                     ' Original error: %s' % e)
+  logging.info('Saving model to %s', file_name)
+  return file_name
+
+
 def model_to_estimator(keras_model=None,
                        keras_model_path=None,
                        custom_objects=None,
@@ -407,12 +454,13 @@ def model_to_estimator(keras_model=None,
         'Please specity either `keras_model` or `keras_model_path`, '
         'but not both.')
 
+  config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
+      config, model_dir)
   if not keras_model:
     if keras_model_path.startswith(
         'gs://') or 'storage.googleapis.com' in keras_model_path:
-      raise ValueError(
-          '%s is not a local path. Please copy the model locally first.' %
-          keras_model_path)
+      keras_model_path = _get_file_from_google_storage(keras_model_path,
+                                                       config.model_dir)
     logging.info('Loading models from %s', keras_model_path)
     keras_model = models.load_model(keras_model_path)
   else:
@@ -425,9 +473,6 @@ def model_to_estimator(keras_model=None,
         'Please compile the model with `model.compile()` '
         'before calling `model_to_estimator()`.')
 
-  config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
-                                                                      model_dir)
-
   keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
   if _any_weight_initialized(keras_model):
     # Warn if config passed to estimator tries to update GPUOptions. If a
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 288f9b8906..4e285fa25a 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -581,12 +581,6 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
       with self.assertRaisesRegexp(ValueError, 'compiled'):
         keras_lib.model_to_estimator(keras_model=keras_model)
 
-    with self.cached_session():
-      keras_model = simple_sequential_model()
-      with self.assertRaisesRegexp(ValueError, 'not a local path'):
-        keras_lib.model_to_estimator(
-            keras_model_path='gs://bucket/object')
-
   def test_invalid_ionames_error(self):
     (x_train, y_train), (_, _) = testing_utils.get_test_data(
         train_samples=_TRAIN_SIZE,
-- 
cgit v1.2.3