diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-27 13:18:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 13:23:04 -0700 |
commit | 4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch) | |
tree | 56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/python/saved_model | |
parent | c898e63d07fc63315be98f0772736e5d7f2fb44c (diff) |
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/loader_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 56 |
2 files changed, 36 insertions, 34 deletions
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py index b7e217a35b..924b2e7c06 100644 --- a/tensorflow/python/saved_model/loader_test.py +++ b/tensorflow/python/saved_model/loader_test.py @@ -47,8 +47,8 @@ class SavedModelLoaderTest(test.TestCase): def setUp(self): """Write test SavedModels to a temp directory.""" with session.Session(graph=ops.Graph()) as sess: - x = variables.Variable(5, name="x") - y = variables.Variable(11, name="y") + x = variables.VariableV1(5, name="x") + y = variables.VariableV1(11, name="y") z = x + y sess.run(variables.global_variables_initializer()) @@ -134,8 +134,8 @@ class SavedModelLoaderTest(test.TestCase): def test_restore_variables(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) with self.session(graph=ops.Graph()) as sess: - x = variables.Variable(0, name="x") - y = variables.Variable(0, name="y") + x = variables.VariableV1(0, name="x") + y = variables.VariableV1(0, name="y") z = x * y sess.run(variables.global_variables_initializer()) @@ -186,8 +186,10 @@ class SavedModelLoaderTest(test.TestCase): """ path = _get_export_dir("no_variable_saved_model") with session.Session(graph=ops.Graph()) as sess: - x = variables.Variable(5, name="x", collections=["not_global_variable"]) - y = variables.Variable(11, name="y", collections=["not_global_variable"]) + x = variables.VariableV1( + 5, name="x", collections=["not_global_variable"]) + y = variables.VariableV1( + 11, name="y", collections=["not_global_variable"]) self.assertFalse(variables._all_saveable_objects()) z = x + y sess.run(variables.variables_initializer([x, y])) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 49d52d3bee..80b75b7ee6 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -60,7 +60,7 @@ class SavedModelTest(test.TestCase): return os.path.join(test.get_temp_dir(), label) def _init_and_validate_variable(self, sess, variable_name, variable_value): - v = variables.Variable(variable_value, name=variable_name) + v = variables.VariableV1(variable_value, name=variable_name) sess.run(variables.global_variables_initializer()) self.assertEqual(variable_value, v.eval()) @@ -458,7 +458,7 @@ class SavedModelTest(test.TestCase): # Graph with a single variable added to a collection. SavedModel invoked to: # - add with weights. with self.session(graph=ops.Graph()) as sess: - v = variables.Variable(42, name="v") + v = variables.VariableV1(42, name="v") ops.add_to_collection("foo_vars", v) sess.run(variables.global_variables_initializer()) self.assertEqual(42, v.eval()) @@ -468,7 +468,7 @@ class SavedModelTest(test.TestCase): # SavedModel invoked to: # - simply add the model (weights are not updated). with self.session(graph=ops.Graph()) as sess: - v = variables.Variable(43, name="v") + v = variables.VariableV1(43, name="v") ops.add_to_collection("bar_vars", v) sess.run(variables.global_variables_initializer()) self.assertEqual(43, v.eval()) @@ -780,13 +780,13 @@ class SavedModelTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") ops.add_to_collection("v", v1) - v2 = variables.Variable(2, name="v2") + v2 = variables.VariableV1(2, name="v2") ops.add_to_collection("v", v2) # Initialize another variable `v3` to 42. - v3 = variables.Variable(42, name="v3") + v3 = variables.VariableV1(42, name="v3") ops.add_to_collection("v", v3) # Set up an assignment op to be run as part of the main_op. @@ -815,13 +815,13 @@ class SavedModelTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") ops.add_to_collection("v", v1) - v2 = variables.Variable(2, name="v2") + v2 = variables.VariableV1(2, name="v2") ops.add_to_collection("v", v2) # Initialize another variable `v3` to 42. - v3 = variables.Variable(42, name="v3", trainable=False, collections=[]) + v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[]) ops.add_to_collection("v", v3) # Set up an assignment op to be run as part of the legacy_init_op. @@ -860,11 +860,11 @@ class SavedModelTest(test.TestCase): g = ops.Graph() with self.session(graph=g) as sess: # Initialize variable `v1` to 1. - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") ops.add_to_collection("v", v1) # Initialize another variable `v2` to 42. - v2 = variables.Variable(42, name="v2", trainable=False, collections=[]) + v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[]) ops.add_to_collection("v", v2) # Set up an assignment op to be run as part of the init op. @@ -889,9 +889,9 @@ class SavedModelTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") ops.add_to_collection("v", v1) - v2 = variables.Variable(2, name="v2") + v2 = variables.VariableV1(2, name="v2") ops.add_to_collection("v", v2) sess.run(variables.global_variables_initializer()) @@ -918,9 +918,9 @@ class SavedModelTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") ops.add_to_collection("v", v1) - v2 = variables.Variable(2, name="v2") + v2 = variables.VariableV1(2, name="v2") ops.add_to_collection("v", v2) sess.run(variables.global_variables_initializer()) @@ -947,9 +947,9 @@ class SavedModelTest(test.TestCase): with self.session(graph=ops.Graph()) as sess: # Add `v1` and `v2` variables to the graph. - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") ops.add_to_collection("v", v1) - v2 = variables.Variable(2, name="v2") + v2 = variables.VariableV1(2, name="v2") ops.add_to_collection("v", v2) sess.run(variables.global_variables_initializer()) @@ -1071,13 +1071,13 @@ class SavedModelTest(test.TestCase): graph=ops.Graph(), config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: with sess.graph.device("/cpu:0"): - v1 = variables.Variable(1, name="v1") + v1 = variables.VariableV1(1, name="v1") with sess.graph.device("/cpu:1"): - v2 = variables.Variable(2, name="v2") + v2 = variables.VariableV1(2, name="v2") # v3 is an unsaved variable derived from v1 and v2. It is used to # exercise the ability to run an init op when restoring a graph. - v3 = variables.Variable(1, name="v3", trainable=False, collections=[]) + v3 = variables.VariableV1(1, name="v3", trainable=False, collections=[]) assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2)) init_op = control_flow_ops.group(assign_v3, name="init_op") @@ -1140,7 +1140,7 @@ class SavedModelTest(test.TestCase): builder = saved_model_builder.SavedModelBuilder(export_dir) with self.session(graph=ops.Graph()) as sess: - variables.Variable(1, name="v1") + variables.VariableV1(1, name="v1") sess.run(variables.global_variables_initializer()) custom_saver = training.Saver(name="my_saver") builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver) @@ -1162,7 +1162,7 @@ class SavedModelTest(test.TestCase): builder = saved_model_builder.SavedModelBuilder(export_dir) with self.session(graph=ops.Graph()) as sess: - variables.Variable(1, name="v1") + variables.VariableV1(1, name="v1") sess.run(variables.global_variables_initializer()) training.Saver(name="my_saver") builder.add_meta_graph_and_variables(sess, ["tag"]) @@ -1184,7 +1184,7 @@ class SavedModelTest(test.TestCase): builder = saved_model_builder.SavedModelBuilder(export_dir) with self.session(graph=ops.Graph()) as sess: - variables.Variable(1, name="v1") + variables.VariableV1(1, name="v1") sess.run(variables.global_variables_initializer()) builder.add_meta_graph_and_variables(sess, ["tag_0"]) @@ -1293,8 +1293,8 @@ class SavedModelTest(test.TestCase): # Add a graph with two float32 variables and a Complex Op composing them # with strip_default_attrs enabled. with session.Session(graph=ops.Graph()) as sess: - real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") - imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag") math_ops.complex(real_num, imag_num, name="complex") sess.run(variables.global_variables_initializer()) builder.add_meta_graph_and_variables( @@ -1303,8 +1303,8 @@ class SavedModelTest(test.TestCase): # Add a graph with the same float32 variables and a Complex Op composing # them with strip_default_attrs disabled. with session.Session(graph=ops.Graph()) as sess: - real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") - imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag") math_ops.complex(real_num, imag_num, name="complex") sess.run(variables.global_variables_initializer()) builder.add_meta_graph(["bar"], strip_default_attrs=False) @@ -1366,7 +1366,7 @@ class SavedModelTest(test.TestCase): # Add a graph with a single variable and a test op with a defaultless # float32 attr, "test_attr". with session.Session(graph=ops.Graph()) as sess: - variables.Variable(1.0, dtype=dtypes.float64, name="var") + variables.VariableV1(1.0, dtype=dtypes.float64, name="var") test_ops.test_attr(T=dtypes.float32, name="test_attr") sess.run(variables.global_variables_initializer()) builder.add_meta_graph_and_variables(sess, ["foo"]) |