diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-10 04:49:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 04:53:03 -0700 |
commit | 8955c28d591983d47fb08ff9049efdf4830b9aed (patch) | |
tree | 2525fed986d4ad84972c92e93c271c652cf449ef /tensorflow/contrib/gan | |
parent | 47dea684efa41981e10299c2737317c504ce41af (diff) |
Add build rules that were accidentally removed.
PiperOrigin-RevId: 203926475
Diffstat (limited to 'tensorflow/contrib/gan')
4 files changed, 43 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index d38d770bc5..10a8796bcb 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -99,6 +99,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gan_estimator", + ":head", "//tensorflow/python:util", ], ) @@ -434,6 +435,40 @@ py_test( ) py_library( + name = "head", + srcs = [ + "python/estimator/python/head.py", + "python/estimator/python/head_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":namedtuples", + ":train", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "head_test", + srcs = ["python/estimator/python/head_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( name = "gan_estimator", srcs = [ "python/estimator/python/gan_estimator.py", diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 04dddb4b55..c9f7bc61b2 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -25,13 +25,16 @@ from __future__ import print_function # Collapse `estimator` into a single namespace. # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator +from tensorflow.contrib.gan.python.estimator.python import head from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.head import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'gan_estimator', -] + gan_estimator.__all__ + 'head', +] + gan_estimator.__all__ + head.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 3cca6993ee..1a0ee6dfc4 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -79,12 +79,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, name=name) -@deprecation.deprecated( - None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a ' - 'GANHead.') class GANHead(head._Head): # pylint: disable=protected-access """`Head` for a GAN.""" + @deprecation.deprecated( + None, 'Please use tf.contrib.gan.GANEstimator without explicitly making ' + 'a GANHead.') def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 5309d87765..8205bc889d 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -67,7 +67,7 @@ class GANHeadTest(test.TestCase): generator_optimizer=training.GradientDescentOptimizer(1.0), discriminator_optimizer=training.GradientDescentOptimizer(1.0), get_eval_metric_ops_fn=self.get_metrics) - self.assertTrue(isinstance(self.gan_head, head.GANHead)) + self.assertIsInstance(self.gan_head, head.GANHead) def get_metrics(self, gan_model): self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) |