aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/copy_graph
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-27 13:18:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 13:23:04 -0700
commit4cedc8b6e738b7a188c9c091cf667bacafae44b7 (patch)
tree56de35940e5f9daedd5f39a82d2cd90cf374e4e4 /tensorflow/contrib/copy_graph
parentc898e63d07fc63315be98f0772736e5d7f2fb44c (diff)
Updating the V2 variables API.
PiperOrigin-RevId: 214824023
Diffstat (limited to 'tensorflow/contrib/copy_graph')
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py6
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_test.py4
2 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 6c9ab6aeb8..9c5871da34 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -31,7 +31,7 @@ from __future__ import division
from __future__ import print_function
from copy import deepcopy
-from tensorflow.python.ops.variables import Variable
+from tensorflow.python.ops.variables import VariableV1
from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
@@ -55,7 +55,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
TypeError: If `org_instance` is not a `Variable`.
"""
- if not isinstance(org_instance, Variable):
+ if not isinstance(org_instance, VariableV1):
raise TypeError(str(org_instance) + ' is not a Variable')
#The name of the new variable
@@ -88,7 +88,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
#Initialize the new variable
with to_graph.as_default():
- new_var = Variable(
+ new_var = VariableV1(
init_value,
trainable,
name=new_name,
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py
index 05744bec4e..ba97c78456 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_test.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py
@@ -36,7 +36,7 @@ class CopyVariablesTest(test.TestCase):
with graph1.as_default():
#Define a Variable in graph1
- some_var = variables.Variable(2)
+ some_var = variables.VariableV1(2)
#Initialize session
sess1 = session_lib.Session()
#Initialize the Variable
@@ -72,7 +72,7 @@ class CopyOpsTest(test.TestCase):
with graph1.as_default():
#Initialize a basic expression y = ax + b
x = array_ops.placeholder("float")
- a = variables.Variable(3.0)
+ a = variables.VariableV1(3.0)
b = constant_op.constant(4.0)
ax = math_ops.multiply(x, a)
y = math_ops.add(ax, b)