diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-08-20 16:05:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 16:13:36 -0700 |
commit | e687764a94abc17866213d505d1dbe5e4873e1b9 (patch) | |
tree | 35d13b0c28af30134c455251fb6d9f43b7b4362f /tensorflow/contrib/eager/python | |
parent | 53e178616cfd0d06196216c1c1d708f1a3fbd3ef (diff) |
Remove graph_callable.py and all references to it.
Our APIs for creating functions are tfe.defun() and tfe.make_template();
graph_callable is no longer needed.
Additionally, change GraphModeFunction to accept a FuncGraph instead of
a CapturingGraph --- graph_callable was preventing us from doing this.
This is part of ongoing hygiene work to help us think more clearly about the function APIs we'll want to provide. GraphModeFunction (which is renamed to GraphCallable here) previously had way too many attributes.
PiperOrigin-RevId: 209502276
Diffstat (limited to 'tensorflow/contrib/eager/python')
-rw-r--r-- | tensorflow/contrib/eager/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/saver_test.py | 51 |
2 files changed, 0 insertions, 52 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index f7933639a0..fa3f1bb7ad 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -104,7 +104,6 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python/eager:graph_callable", "//tensorflow/python/eager:test", "//tensorflow/python:variables", ], diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 90a3711475..91bc75213c 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -21,15 +21,11 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context -from tensorflow.python.eager import graph_callable from tensorflow.python.eager import test -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum @@ -142,53 +138,6 @@ class SaverTest(test.TestCase): with _saver.restore_variables_on_create(ckpt_prefix): _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) - def testSaveRestoreGraphCallable(self): - with ops.device(self._dev()): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - # Default 2 + 0 = 2 - self.assertEqual( - 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # Save the variable value 0. - ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') - _saver.Saver(model.variables).save(ckpt_prefix) - - # update variable to 1, so that 2 + 1 = 3 - model.variables[0].assign(1.) - self.assertEqual( - 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # load the variable value 0, so that 2 + 0 = 2 - _saver.Saver(model.variables).restore(ckpt_prefix) - self.assertEqual( - 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # update checkpoint variable to 1 and memory value to 2. - model.variables[0].assign(1.) - _saver.Saver(model.variables).save(ckpt_prefix) - model.variables[0].assign(2.) - self.assertEqual( - 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - - # reset the graph and reload on create, so that 1 + 2 = 3 - ops.reset_default_graph() - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) - class GetOptimizerTests(test.TestCase): |