aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-03-06 16:24:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-06 16:50:30 -0800
commit020501c695119f76bba0dd2bb47abfeaa939d669 (patch)
tree9039763fdcfc0336f71355fe344ec987f175e1ea /tensorflow
parent6bac75332a78da71c171e910fe79ab90fc4c272d (diff)
Hoist resource variable reads outside functions.
Change: 149360110
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/framework/function.py7
-rw-r--r--tensorflow/python/framework/function_test.py4
2 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 9e7dba6946..3a483dc41d 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
@@ -324,6 +325,12 @@ class _FuncGraph(ops.Graph):
collections=collections,
use_resource=use_resource)
self.extra_vars.append(var)
+ if isinstance(var, resource_variable_ops.ResourceVariable):
+ # For resource-based variables read the variable outside the function
+ # and pass in the value. This ensures that the function is pure and
+ # differentiable. TODO(apassos) this may have performance problems if
+ # the function will only do embedding lookups on the variable.
+ return var.value()
return var
def create_op(self, op_type, inputs, data_types, **kwargs):
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index a87f5f09bb..317c6fe98d 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1178,9 +1178,7 @@ class VariableHoistingTest(test.TestCase):
self._testSimpleModel(True)
self._testSimpleModel(False)
- # TODO(b/35668241): disabled because resource variable handling inside
- # functions does not work.
- def DISABLED_testBasicResource(self):
+ def testBasicResource(self):
self._testSimpleModel(True, use_resource=True)
self._testSimpleModel(False, use_resource=True)