aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 04:49:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 04:53:03 -0700
commit8955c28d591983d47fb08ff9049efdf4830b9aed (patch)
tree2525fed986d4ad84972c92e93c271c652cf449ef /tensorflow/contrib/gan
parent47dea684efa41981e10299c2737317c504ce41af (diff)
Add build rules that were accidentally removed.
PiperOrigin-RevId: 203926475
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD35
-rw-r--r--tensorflow/contrib/gan/python/estimator/__init__.py5
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py6
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py2
4 files changed, 43 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index d38d770bc5..10a8796bcb 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -99,6 +99,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":gan_estimator",
+ ":head",
"//tensorflow/python:util",
],
)
@@ -434,6 +435,40 @@ py_test(
)
py_library(
+ name = "head",
+ srcs = [
+ "python/estimator/python/head.py",
+ "python/estimator/python/head_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":namedtuples",
+ ":train",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_test(
+ name = "head_test",
+ srcs = ["python/estimator/python/head_test.py"],
+ shard_count = 1,
+ srcs_version = "PY2AND3",
+ deps = [
+ ":head",
+ ":namedtuples",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_library(
name = "gan_estimator",
srcs = [
"python/estimator/python/gan_estimator.py",
diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py
index 04dddb4b55..c9f7bc61b2 100644
--- a/tensorflow/contrib/gan/python/estimator/__init__.py
+++ b/tensorflow/contrib/gan/python/estimator/__init__.py
@@ -25,13 +25,16 @@ from __future__ import print_function
# Collapse `estimator` into a single namespace.
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
+from tensorflow.contrib.gan.python.estimator.python import head
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
+from tensorflow.contrib.gan.python.estimator.python.head import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'gan_estimator',
-] + gan_estimator.__all__
+ 'head',
+] + gan_estimator.__all__ + head.__all__
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index 3cca6993ee..1a0ee6dfc4 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -79,12 +79,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
name=name)
-@deprecation.deprecated(
- None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a '
- 'GANHead.')
class GANHead(head._Head): # pylint: disable=protected-access
"""`Head` for a GAN."""
+ @deprecation.deprecated(
+ None, 'Please use tf.contrib.gan.GANEstimator without explicitly making '
+ 'a GANHead.')
def __init__(self, generator_loss_fn, discriminator_loss_fn,
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index 5309d87765..8205bc889d 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -67,7 +67,7 @@ class GANHeadTest(test.TestCase):
generator_optimizer=training.GradientDescentOptimizer(1.0),
discriminator_optimizer=training.GradientDescentOptimizer(1.0),
get_eval_metric_ops_fn=self.get_metrics)
- self.assertTrue(isinstance(self.gan_head, head.GANHead))
+ self.assertIsInstance(self.gan_head, head.GANHead)
def get_metrics(self, gan_model):
self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))