aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-08-20 16:05:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 16:13:36 -0700
commite687764a94abc17866213d505d1dbe5e4873e1b9 (patch)
tree35d13b0c28af30134c455251fb6d9f43b7b4362f /tensorflow/contrib/eager/python
parent53e178616cfd0d06196216c1c1d708f1a3fbd3ef (diff)
Remove graph_callable.py and all references to it.
Our APIs for creating functions are tfe.defun() and tfe.make_template(); graph_callable is no longer needed. Additionally, change GraphModeFunction to accept a FuncGraph instead of a CapturingGraph --- graph_callable was preventing us from doing this. This is part of ongoing hygiene work to help us think more clearly about the function APIs we'll want to provide. GraphModeFunction (which is renamed to GraphCallable here) previously had way too many attributes. PiperOrigin-RevId: 209502276
Diffstat (limited to 'tensorflow/contrib/eager/python')
-rw-r--r--tensorflow/contrib/eager/python/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/saver_test.py51
2 files changed, 0 insertions, 52 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index f7933639a0..fa3f1bb7ad 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -104,7 +104,6 @@ cuda_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
- "//tensorflow/python/eager:graph_callable",
"//tensorflow/python/eager:test",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py
index 90a3711475..91bc75213c 100644
--- a/tensorflow/contrib/eager/python/saver_test.py
+++ b/tensorflow/contrib/eager/python/saver_test.py
@@ -21,15 +21,11 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
-from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import momentum
@@ -142,53 +138,6 @@ class SaverTest(test.TestCase):
with _saver.restore_variables_on_create(ckpt_prefix):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
- def testSaveRestoreGraphCallable(self):
- with ops.device(self._dev()):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- # Default 2 + 0 = 2
- self.assertEqual(
- 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # Save the variable value 0.
- ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
- _saver.Saver(model.variables).save(ckpt_prefix)
-
- # update variable to 1, so that 2 + 1 = 3
- model.variables[0].assign(1.)
- self.assertEqual(
- 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # load the variable value 0, so that 2 + 0 = 2
- _saver.Saver(model.variables).restore(ckpt_prefix)
- self.assertEqual(
- 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # update checkpoint variable to 1 and memory value to 2.
- model.variables[0].assign(1.)
- _saver.Saver(model.variables).save(ckpt_prefix)
- model.variables[0].assign(2.)
- self.assertEqual(
- 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # reset the graph and reload on create, so that 1 + 2 = 3
- ops.reset_default_graph()
- with _saver.restore_variables_on_create(ckpt_prefix):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model2(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
class GetOptimizerTests(test.TestCase):