author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-01-17 23:31:46 -0800
committer  Manjunath Kudlur <keveman@gmail.com>  2016-01-18 12:22:23 -0800
commit     a84a81a7379507f8fcdd0d6118afc2d5044d159e (patch)
tree       4059e7f6c5b8893ef6d637293ebc55bcbbfcf548
parent     6f62e435ab6c36dfdfdef1acd580b5f278f6723c (diff)
* Supports nested function calls.
* Supports using a Python class method as a tf function prototype; derives the generated function name from the class name and the method name.
* Changes one LOG(INFO), which was too verbose, to VLOG(1).

Change: 112384148
-rw-r--r--  tensorflow/core/framework/op.cc                2
-rw-r--r--  tensorflow/python/framework/function.py       72
-rw-r--r--  tensorflow/python/framework/function_test.py  61
-rw-r--r--  tensorflow/python/framework/ops.py            10
-rw-r--r--  tensorflow/python/ops/gradients_test.py       19
5 files changed, 100 insertions, 64 deletions
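
For illustration, a minimal sketch of the nested-call support, adapted from the testNestedFunction case added below (assumes the 2016-era tensorflow.python.framework.function API):

    import tensorflow as tf
    from tensorflow.python.framework import function

    with tf.Graph().as_default():
      # Inner function; its FunctionDef goes into this graph's library.
      @function.Defun(x=tf.float32)
      def Cube(x):
        return x * x * x

      # The outer function calls the inner one. _add_op_node now resolves
      # Cube's signature via graph._get_function() rather than the global
      # op registry, which is what makes the nesting work.
      @function.Defun(x=tf.float32, y=tf.float32)
      def CubeXPlusY(x, y):
        return Cube(x) + y

      z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
      with tf.Session() as sess:
        print(sess.run(z))  # 25.0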
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index bbe722e329..6d42353630 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -60,7 +60,7 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name,
if (op_def == nullptr) {
status->Update(
errors::NotFound("Op type not registered '", op_type_name, "'"));
- LOG(INFO) << status->ToString();
+ VLOG(1) << status->ToString();
static bool first_unregistered = true;
if (first_unregistered) {
OpList op_list;
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 5ac7e385c6..307b7041a4 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -41,18 +41,6 @@ def _tensor_to_argdef(t):
return arg
-def _is_array_type_input(op, i):
- registered_ops = op_def_registry.get_registered_ops()
- if op not in registered_ops:
- return False
- op_def = registered_ops[op]
- if i not in xrange(len(op_def.input_arg)):
- raise TypeError("Expected arg index "
- "to be in [0, %d)" % len(op_def.input_arg))
- input_arg = op_def.input_arg[i]
- return True if input_arg.number_attr else False
-
-
def _get_node_def_attr(op):
# pylint: disable=protected-access
return op._node_def.attr
@@ -93,11 +81,16 @@ def _add_output_array_to_list(op, start, limit, dtype, func):
return arg_name
-def _add_op_node(op, func):
+def _add_op_node(graph, op, func):
"""Converts an op to a function def node and add it to `func`."""
node = function_pb2.FunctionDef.Node()
node.op = op.type
- op_def = op_def_registry.get_registered_ops()[op.type]
+ # pylint: disable=protected-access
+ if graph._is_function(op.type):
+ op_def = graph._get_function(op.type).signature
+ else:
+ op_def = op_def_registry.get_registered_ops()[op.type]
+ # pylint: enable=protected-access
attrs = _get_node_def_attr(op)
out_index = 0
for arg_def in op_def.output_arg:
@@ -173,10 +166,11 @@ def graph_to_function_def(graph, name, inputs, outputs):
func.signature.output_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name(
o.name)) for o in outputs])
func_arg_placeholders = set([i.name for i in inputs])
+ g = ops.get_default_graph()
for op in graph.get_operations():
tensor_name = op.values()[0].name
if tensor_name not in func_arg_placeholders:
- _add_op_node(op, func)
+ _add_op_node(g, op, func)
return func
@@ -220,9 +214,8 @@ def call_function(func_def, *inputs, **kwargs):
raise ValueError("Expected number of arguments: %d" %
len(func_def.signature.input_arg))
output_types = [dtypes.DType(x.type) for x in func_def.signature.output_arg]
- g = ops.get_default_graph()
- g._add_function(func_def) # pylint: disable=protected-access
# TODO(touts): Pass compute_shapes as "try if function exists"
+ g = ops.get_default_graph()
op = g.create_op(func_name,
list(inputs),
output_types,
@@ -248,16 +241,13 @@ def define_function(func, input_types):
names arguments to `func`. The value indicate the type of tensor expected
by the function.
- The returned `FunctionDef` protocol buffer can be passed to
- `tf.add_function()` to add the function to the default graph library. After
- it has been added you can add calls to the function by passing it to
- `tf.call_function()`, together with a list of tensors to use as inputs for
- the function.
+ The returned `FunctionDef` protocol buffer is also added to the
+ default graph library. After it has been added you can add calls to
+ the function by passing it to `tf.call_function()`, together with a
+ list of tensors to use as inputs for the function.
Notes:
- * The default graph is not changed in any way: no ops are added to it, the
- returned `FunctionDef` is not added to its function library.
* `func` is called once, with `placeholder` tensors of the types specified in
`input_types` as arguments.
* Values returned by `func` must be tensors and they are recorded as being
@@ -294,23 +284,38 @@ def define_function(func, input_types):
Raises:
ValueError: if the arguments are invalid.
+
"""
# TODO(touts): Lift the limitation that func can only receive Tensor args.
- if not inspect.isfunction(func):
+ if inspect.isfunction(func):
+ func_name = func.__name__
+ elif inspect.ismethod(func):
+ func_name = func.im_self.__name__ + "." + func.__name__
+ else:
raise ValueError("Argument must be a function")
argspec = inspect.getargspec(func)
if argspec.varargs or argspec.keywords or argspec.defaults:
raise ValueError("Only functions with plain arglists are supported.")
- if len(argspec.args) != len(input_types):
- raise ValueError("The function must have the same number of arguments "
- "as the number of specified input types.")
+ if inspect.isfunction(func):
+ if len(argspec.args) != len(input_types):
+ raise ValueError("The function must have the same number of arguments "
+ "as the number of specified input types.")
+ args = argspec.args
+ elif inspect.ismethod(func):
+ if len(argspec.args) != 1 + len(input_types):
+ raise ValueError(
+ "The class function must have the same number of arguments "
+ "as the number of specified input types.")
+ args = argspec.args[1:] # 1st argument is the "class" type.
+
# Create the func_def object.
- with ops.Graph().as_default() as temp_graph:
+ temp_graph = ops.Graph()
+ with temp_graph.as_default():
# List of placeholders for the function_def.
inputs = []
# Arglist to call 'func'
kwargs = {}
- for argname in argspec.args:
+ for argname in args:
if argname not in input_types:
raise ValueError("Missing type for argument: " + argname)
argholder = array_ops.placeholder(input_types[argname], name=argname)
@@ -323,8 +328,11 @@ def define_function(func, input_types):
# Convenience: if func only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
- # Build and return the FunctionDef
- return graph_to_function_def(temp_graph, func.__name__, inputs, outputs)
+ # Build the FunctionDef
+ func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs)
+ g = ops.get_default_graph()
+ g._add_function(func_def) # pylint: disable=protected-access
+ return func_def
class Defun(object):
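
As the revised docstring above notes, define_function now adds the FunctionDef to the default graph's library as a side effect, so the explicit g._add_function() call before call_function is no longer needed. A minimal sketch of the new flow (XPlusOne is a hypothetical example function):

    import tensorflow as tf
    from tensorflow.python.framework import function

    def XPlusOne(x):
      return x + 1.0

    with tf.Graph().as_default():
      # func is traced in a temporary graph; the resulting FunctionDef is
      # registered in the default graph's function library on return.
      func_def = function.define_function(XPlusOne, {"x": tf.float32})
      # call_function now only has to create the call op.
      y = function.call_function(func_def, tf.constant(1.0))
      with tf.Session() as sess:
        print(sess.run(y))  # 2.0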
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 29e85e1253..a6e19b825a 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -121,15 +121,13 @@ class FunctionTest(tf.test.TestCase):
g = tf.Graph()
with g.as_default():
- f_def = function.define_function(XSquarePlusOne, {"x": tf.float32})
- g._add_function(f_def)
- g_def = function.define_function(XSquarePlusOneGrad, {"x": tf.float32,
- "dy": tf.float32})
- g._add_function(g_def)
+ f = function.define_function(XSquarePlusOne, {"x": tf.float32})
+ g = function.define_function(XSquarePlusOneGrad, {"x": tf.float32,
+ "dy": tf.float32})
epsilon = tf.constant([0.1])
two = tf.constant([2.0])
- call_f = function.call_function(f_def, two)
- call_g = function.call_function(g_def, two, epsilon)
+ call_f = function.call_function(f, two)
+ call_g = function.call_function(g, two, epsilon)
with tf.Session() as sess:
self.assertAllClose([5.0], sess.run(call_f))
@@ -262,11 +260,12 @@ class FunctionTest(tf.test.TestCase):
def testFunctionDecorator(self):
- @function.Defun(b=tf.int32)
- def Minus1(b):
- return b - 1
-
with tf.Graph().as_default():
+
+ @function.Defun(b=tf.int32)
+ def Minus1(b):
+ return b - 1
+
two = tf.constant([2])
call1 = Minus1(two)
self.assertEquals("Minus1", call1.op.name)
@@ -278,6 +277,18 @@ class FunctionTest(tf.test.TestCase):
self.assertAllEqual([1], sess.run(call1))
self.assertAllEqual([0], sess.run(call2))
+ def testNestedFunction(self):
+ with tf.Graph().as_default():
+ @function.Defun(x=tf.float32)
+ def Cube(x):
+ return x * x * x
+ @function.Defun(x=tf.float32, y=tf.float32)
+ def CubeXPlusY(x, y):
+ return Cube(x) + y
+ z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
+ with self.test_session():
+ self.assertAllEqual(z.eval(), 25.0)
+
def testUnrollLSTM(self):
# Helper to construct a LSTM cell graph.
@@ -289,23 +300,29 @@ class FunctionTest(tf.test.TestCase):
new_m = tf.sigmoid(o_g) * tf.tanh(new_c)
return new_m, new_c
- # Helper to construct a LSTM function.
- @function.Defun(x=tf.float32,
- mprev=tf.float32,
- cprev=tf.float32,
- weights=tf.float32)
- def LSTMCellFunc(x, mprev, cprev, weights):
- return LSTMCell(x, mprev, cprev, weights)
-
batch_size = 16
lstm_dims = 32
num_unroll = 100
# Run one step of the unrolled lstm graph.
- def RunStep(cell):
+ def RunStep(use_func):
g = tf.Graph()
start = time.time()
with g.as_default():
+ # Helper to construct a LSTM function.
+ if use_func:
+
+ @function.Defun(x=tf.float32,
+ mprev=tf.float32,
+ cprev=tf.float32,
+ weights=tf.float32)
+ def LSTMCellFunc(x, mprev, cprev, weights):
+ return LSTMCell(x, mprev, cprev, weights)
+
+ cell = LSTMCellFunc
+ else:
+ cell = LSTMCell
+
m = tf.zeros(shape=[batch_size, lstm_dims])
c = tf.zeros(shape=[batch_size, lstm_dims])
weights = tf.random_uniform(
@@ -327,8 +344,8 @@ class FunctionTest(tf.test.TestCase):
mv, cv = sess.run([m, c])
return mv, cv
- mv0, cv0 = RunStep(LSTMCell)
- mv1, cv1 = RunStep(LSTMCellFunc)
+ mv0, cv0 = RunStep(use_func=False)
+ mv1, cv1 = RunStep(use_func=True)
self.assertAllClose(mv0, mv1)
self.assertAllClose(cv0, cv1)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index e7d41a5918..263bda991c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1791,6 +1791,16 @@ class Graph(object):
"""
return name in self._functions
+ def _get_function(self, name):
+ """Returns the function definition for 'name'.
+
+ Args:
+ name: string function name.
+ Returns:
+ The function def proto.
+ """
+ return self._functions[name]
+
def _add_function(self, function_def):
"""Adds a function to the graph.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index a9f994c78c..b6985d8af2 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -265,27 +265,28 @@ class GradientsTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(grads[0], ops.Tensor))
-@function.Defun(x=tf.float32)
-def XSquarePlusOne(x):
- return x * x + 1.0
-
-
class FunctionGradientsTest(test_util.TensorFlowTestCase):
+ @classmethod
+ def XSquarePlusOne(cls, x):
+ return x * x + 1.0
+
def testFunctionGradientsBasic(self):
- with ops.Graph().as_default():
+ g = ops.Graph()
+ with g.as_default():
+ f = function.Defun(x=tf.float32)(self.XSquarePlusOne)
two = tf.constant([2.0], name="two")
- y = XSquarePlusOne(two)
+ y = f(two)
# Build gradient graph (should add SymbolicGradient node for function).
grads = gradients.gradients(y, two)
-
with self.test_session() as sess:
self.assertAllEqual([4.0], sess.run(grads)[0])
def testFunctionGradientsComposition(self):
with ops.Graph().as_default():
+ f = function.Defun(x=tf.float32)(self.XSquarePlusOne)
two = tf.constant([2.0], name="two")
- y = XSquarePlusOne(XSquarePlusOne(two))
+ y = f(f(two))
# Build gradient graph (should add SymbolicGradient node for function).
grads = gradients.gradients(y, two)
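
Finally, a sketch of the class-method prototype feature exercised above: define_function skips the leading cls/self argument when matching input_types and derives the generated function name as "<ClassName>.<method>". MyOps is a hypothetical class:

    import tensorflow as tf
    from tensorflow.python.framework import function

    class MyOps(object):

      @classmethod
      def XSquarePlusOne(cls, x):
        return x * x + 1.0

    with tf.Graph().as_default():
      # The generated function is named "MyOps.XSquarePlusOne".
      f = function.Defun(x=tf.float32)(MyOps.XSquarePlusOne)
      two = tf.constant([2.0], name="two")
      y = f(two)
      grads = tf.gradients(y, two)  # adds a SymbolicGradient node
      with tf.Session() as sess:
        print(sess.run(grads)[0])  # [4.0]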