diff options
author | Alexandre Passos <apassos@google.com> | 2017-03-06 16:24:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-06 16:50:30 -0800 |
commit | 020501c695119f76bba0dd2bb47abfeaa939d669 (patch) | |
tree | 9039763fdcfc0336f71355fe344ec987f175e1ea /tensorflow | |
parent | 6bac75332a78da71c171e910fe79ab90fc4c272d (diff) |
Hoist resource variable reads outside functions.
Change: 149360110
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/framework/function.py | 7 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 4 |
2 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 9e7dba6946..3a483dc41d 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat @@ -324,6 +325,12 @@ class _FuncGraph(ops.Graph): collections=collections, use_resource=use_resource) self.extra_vars.append(var) + if isinstance(var, resource_variable_ops.ResourceVariable): + # For resource-based variables read the variable outside the function + # and pass in the value. This ensures that the function is pure and + # differentiable. TODO(apassos) this may have performance problems if + # the function will only do embedding lookups on the variable. + return var.value() return var def create_op(self, op_type, inputs, data_types, **kwargs): diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index a87f5f09bb..317c6fe98d 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1178,9 +1178,7 @@ class VariableHoistingTest(test.TestCase): self._testSimpleModel(True) self._testSimpleModel(False) - # TODO(b/35668241): disabled because resource variable handling inside - # functions does not work. - def DISABLED_testBasicResource(self): + def testBasicResource(self): self._testSimpleModel(True, use_resource=True) self._testSimpleModel(False, use_resource=True) |