From 94c6e1b3e13b1456e4578eaa50e2066b1d26b40a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Jun 2018 15:56:44 -0700 Subject: ConfigureGcsHooks: Fixed a couple of typos. - _configure_op was spelled with a trailing 's' - _block_cache_op was only conditionally set but unconditionally read. Added a fake test that triggered the bugs before and passes after. PiperOrigin-RevId: 201256874 --- tensorflow/contrib/cloud/python/ops/gcs_config_ops.py | 7 ++++++- tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) (limited to 'tensorflow/contrib/cloud') diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index 8c8c5acb31..95e7e744d3 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -120,13 +120,18 @@ class ConfigureGcsHook(training.SessionRunHook): def begin(self): if self._credentials: self._credentials_placeholder = array_ops.placeholder(dtypes.string) - self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_op = gen_gcs_config_ops.gcs_configure_credentials( self._credentials_placeholder) + else: + self._credentials_op = None + if self._block_cache: self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( max_cache_size=self._block_cache.max_bytes, block_size=self._block_cache.block_size, max_staleness=self._block_cache.max_staleness) + else: + self._block_cache_op = None def after_create_session(self, session, coord): del coord diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py index fc0c994812..9b6c056d6c 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -29,6 +29,16 @@ class GcsConfigOpsTest(test.TestCase): with self.test_session() as sess: gcs_config_ops.configure_gcs(sess, block_cache=cfg) + def testConfigureGcsHook(self): + creds = {'client_id': 'fake_client', + 'refresh_token': 'fake_token', + 'client_secret': 'fake_secret', + 'type': 'authorized_user'} + hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) + hook.begin() + with self.test_session() as sess: + sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None + hook.after_create_session(sess, None) if __name__ == '__main__': test.main() -- cgit v1.2.3