aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-08-17 19:19:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-17 19:23:38 -0700
commitb7c242475c3d6e38ae864ae06f937c1b29c0a494 (patch)
treef6ded73c971ab178bde3f1c319e6ce867e95ae29 /tensorflow/contrib/eager
parentfd1957d8f6ed223bdc424f0bfbe6bab01a43c828 (diff)
Support nested defuns on TPU
PiperOrigin-RevId: 209239670
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/densenet/densenet_test.py11
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py11
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py8
3 files changed, 13 insertions, 17 deletions
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
index 0736ed02b7..e5058bfd94 100644
--- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
@@ -218,7 +218,7 @@ class DensenetBenchmark(tf.test.Benchmark):
tf.constant(1.).cpu()
def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks,
@@ -228,7 +228,7 @@ class DensenetBenchmark(tf.test.Benchmark):
weight_decay=1e-4, dropout_rate=0,
pool_initial=True, include_top=True)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -264,8 +264,7 @@ class DensenetBenchmark(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
@@ -279,8 +278,8 @@ class DensenetBenchmark(tf.test.Benchmark):
optimizer = tf.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
- apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ model.call = tfe.defun(model.call)
+ apply_grads = tfe.defun(apply_gradients)
num_burn = 3
num_iters = 10
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 07d8788882..d265169b5e 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -216,12 +216,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
tf.constant(1.).cpu()
def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -257,8 +257,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
@@ -267,8 +266,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
optimizer = tf.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
- apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ model.call = tfe.defun(model.call)
+ apply_grads = tfe.defun(apply_gradients)
num_burn = 3
num_iters = 10
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 84b2ddf0de..6a921e1997 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -226,14 +226,13 @@ class RevNetBenchmark(tf.test.Benchmark):
label,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
config = config_.get_hparams_imagenet_56()
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = revnet.RevNet(config=config)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 10
@@ -271,8 +270,7 @@ class RevNetBenchmark(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
config = config_.get_hparams_imagenet_56()
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format