diff options
-rw-r--r-- | tensorflow/core/framework/op.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/function.py | 72 | ||||
-rw-r--r-- | tensorflow/python/framework/function_test.py | 61 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 10 | ||||
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 19 |
5 files changed, 100 insertions, 64 deletions
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index bbe722e329..6d42353630 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -60,7 +60,7 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name, if (op_def == nullptr) { status->Update( errors::NotFound("Op type not registered '", op_type_name, "'")); - LOG(INFO) << status->ToString(); + VLOG(1) << status->ToString(); static bool first_unregistered = true; if (first_unregistered) { OpList op_list; diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 5ac7e385c6..307b7041a4 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -41,18 +41,6 @@ def _tensor_to_argdef(t): return arg -def _is_array_type_input(op, i): - registered_ops = op_def_registry.get_registered_ops() - if op not in registered_ops: - return False - op_def = registered_ops[op] - if i not in xrange(len(op_def.input_arg)): - raise TypeError("Expected arg index " - "to be in [0, %d)" % len(op_def.input_arg)) - input_arg = op_def.input_arg[i] - return True if input_arg.number_attr else False - - def _get_node_def_attr(op): # pylint: disable=protected-access return op._node_def.attr @@ -93,11 +81,16 @@ def _add_output_array_to_list(op, start, limit, dtype, func): return arg_name -def _add_op_node(op, func): +def _add_op_node(graph, op, func): """Converts an op to a function def node and add it to `func`.""" node = function_pb2.FunctionDef.Node() node.op = op.type - op_def = op_def_registry.get_registered_ops()[op.type] + # pylint: disable=protected-access + if graph._is_function(op.type): + op_def = graph._get_function(op.type).signature + else: + op_def = op_def_registry.get_registered_ops()[op.type] + # pylint: enable=protected-access attrs = _get_node_def_attr(op) out_index = 0 for arg_def in op_def.output_arg: @@ -173,10 +166,11 @@ def graph_to_function_def(graph, name, inputs, outputs): func.signature.output_arg.extend([_tensor_to_argdef(graph.get_tensor_by_name( o.name)) for o in outputs]) func_arg_placeholders = set([i.name for i in inputs]) + g = ops.get_default_graph() for op in graph.get_operations(): tensor_name = op.values()[0].name if tensor_name not in func_arg_placeholders: - _add_op_node(op, func) + _add_op_node(g, op, func) return func @@ -220,9 +214,8 @@ def call_function(func_def, *inputs, **kwargs): raise ValueError("Expected number of arguments: %d" % len(func_def.signature.input_arg)) output_types = [dtypes.DType(x.type) for x in func_def.signature.output_arg] - g = ops.get_default_graph() - g._add_function(func_def) # pylint: disable=protected-access # TODO(touts): Pass compute_shapes as "try if function exists" + g = ops.get_default_graph() op = g.create_op(func_name, list(inputs), output_types, @@ -248,16 +241,13 @@ def define_function(func, input_types): names arguments to `func`. The value indicate the type of tensor expected by the function. - The returned `FunctionDef` protocol buffer can be passed to - `tf.add_function()` to add the function to the default graph library. After - it has been added you can add calls to the function by passing it to - `tf.call_function()`, together with a list of tensors to use as inputs for - the function. + The returned `FunctionDef` protocol buffer is also added to the + default graph library. After it has been added you can add calls to + the function by passing it to `tf.call_function()`, together with a + list of tensors to use as inputs for the function. Notes: - * The default graph is not changed in any way: no ops are added to it, the - returned `FunctionDef` is not added to its function library. * `func` is called once, with `placeholder` tensors of the types specified in `input_types` as arguments. * Values returned by `func` must be tensors and they are recorded as being @@ -294,23 +284,38 @@ def define_function(func, input_types): Raises: ValueError: if the arguments are invalid. + """ # TODO(touts): Lift the limitation that func can only receive Tensor args. - if not inspect.isfunction(func): + if inspect.isfunction(func): + func_name = func.__name__ + elif inspect.ismethod(func): + func_name = func.im_self.__name__ + "." + func.__name__ + else: raise ValueError("Argument must be a function") argspec = inspect.getargspec(func) if argspec.varargs or argspec.keywords or argspec.defaults: raise ValueError("Only functions with plain arglists are supported.") - if len(argspec.args) != len(input_types): - raise ValueError("The function must have the same number of arguments " - "as the number of specified input types.") + if inspect.isfunction(func): + if len(argspec.args) != len(input_types): + raise ValueError("The function must have the same number of arguments " + "as the number of specified input types.") + args = argspec.args + elif inspect.ismethod(func): + if len(argspec.args) != 1 + len(input_types): + raise ValueError( + "The class function must have the same number of arguments " + "as the number of specified input types.") + args = argspec.args[1:] # 1st argument is the "class" type. + # Create the func_def object. - with ops.Graph().as_default() as temp_graph: + temp_graph = ops.Graph() + with temp_graph.as_default(): # List of placeholders for the function_def. inputs = [] # Arglist to call 'func' kwargs = {} - for argname in argspec.args: + for argname in args: if argname not in input_types: raise ValueError("Missing type for argument: " + argname) argholder = array_ops.placeholder(input_types[argname], name=argname) @@ -323,8 +328,11 @@ def define_function(func, input_types): # Convenience: if func only returned one value, make it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) - # Build and return the FunctionDef - return graph_to_function_def(temp_graph, func.__name__, inputs, outputs) + # Build the FunctionDef + func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs) + g = ops.get_default_graph() + g._add_function(func_def) # pylint: disable=protected-access + return func_def class Defun(object): diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 29e85e1253..a6e19b825a 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -121,15 +121,13 @@ class FunctionTest(tf.test.TestCase): g = tf.Graph() with g.as_default(): - f_def = function.define_function(XSquarePlusOne, {"x": tf.float32}) - g._add_function(f_def) - g_def = function.define_function(XSquarePlusOneGrad, {"x": tf.float32, - "dy": tf.float32}) - g._add_function(g_def) + f = function.define_function(XSquarePlusOne, {"x": tf.float32}) + g = function.define_function(XSquarePlusOneGrad, {"x": tf.float32, + "dy": tf.float32}) epsilon = tf.constant([0.1]) two = tf.constant([2.0]) - call_f = function.call_function(f_def, two) - call_g = function.call_function(g_def, two, epsilon) + call_f = function.call_function(f, two) + call_g = function.call_function(g, two, epsilon) with tf.Session() as sess: self.assertAllClose([5.0], sess.run(call_f)) @@ -262,11 +260,12 @@ class FunctionTest(tf.test.TestCase): def testFunctionDecorator(self): - @function.Defun(b=tf.int32) - def Minus1(b): - return b - 1 - with tf.Graph().as_default(): + + @function.Defun(b=tf.int32) + def Minus1(b): + return b - 1 + two = tf.constant([2]) call1 = Minus1(two) self.assertEquals("Minus1", call1.op.name) @@ -278,6 +277,18 @@ class FunctionTest(tf.test.TestCase): self.assertAllEqual([1], sess.run(call1)) self.assertAllEqual([0], sess.run(call2)) + def testNestedFunction(self): + with tf.Graph().as_default(): + @function.Defun(x=tf.float32) + def Cube(x): + return x * x * x + @function.Defun(x=tf.float32, y=tf.float32) + def CubeXPlusY(x, y): + return Cube(x) + y + z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0)) + with self.test_session(): + self.assertAllEqual(z.eval(), 25.0) + def testUnrollLSTM(self): # Helper to construct a LSTM cell graph. @@ -289,23 +300,29 @@ class FunctionTest(tf.test.TestCase): new_m = tf.sigmoid(o_g) * tf.tanh(new_c) return new_m, new_c - # Helper to construct a LSTM function. - @function.Defun(x=tf.float32, - mprev=tf.float32, - cprev=tf.float32, - weights=tf.float32) - def LSTMCellFunc(x, mprev, cprev, weights): - return LSTMCell(x, mprev, cprev, weights) - batch_size = 16 lstm_dims = 32 num_unroll = 100 # Run one step of the unrolled lstm graph. - def RunStep(cell): + def RunStep(use_func): g = tf.Graph() start = time.time() with g.as_default(): + # Helper to construct a LSTM function. + if use_func: + + @function.Defun(x=tf.float32, + mprev=tf.float32, + cprev=tf.float32, + weights=tf.float32) + def LSTMCellFunc(x, mprev, cprev, weights): + return LSTMCell(x, mprev, cprev, weights) + + cell = LSTMCellFunc + else: + cell = LSTMCell + m = tf.zeros(shape=[batch_size, lstm_dims]) c = tf.zeros(shape=[batch_size, lstm_dims]) weights = tf.random_uniform( @@ -327,8 +344,8 @@ class FunctionTest(tf.test.TestCase): mv, cv = sess.run([m, c]) return mv, cv - mv0, cv0 = RunStep(LSTMCell) - mv1, cv1 = RunStep(LSTMCellFunc) + mv0, cv0 = RunStep(use_func=False) + mv1, cv1 = RunStep(use_func=True) self.assertAllClose(mv0, mv1) self.assertAllClose(cv0, cv1) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e7d41a5918..263bda991c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1791,6 +1791,16 @@ class Graph(object): """ return name in self._functions + def _get_function(self, name): + """Returns the function definition for 'name'. + + Args: + name: string function name. + Returns: + The function def proto. + """ + return self._functions[name] + def _add_function(self, function_def): """Adds a function to the graph. diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index a9f994c78c..b6985d8af2 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -265,27 +265,28 @@ class GradientsTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(grads[0], ops.Tensor)) -@function.Defun(x=tf.float32) -def XSquarePlusOne(x): - return x * x + 1.0 - - class FunctionGradientsTest(test_util.TensorFlowTestCase): + @classmethod + def XSquarePlusOne(cls, x): + return x * x + 1.0 + def testFunctionGradientsBasic(self): - with ops.Graph().as_default(): + g = ops.Graph() + with g.as_default(): + f = function.Defun(x=tf.float32)(self.XSquarePlusOne) two = tf.constant([2.0], name="two") - y = XSquarePlusOne(two) + y = f(two) # Build gradient graph (should add SymbolicGradient node for function). grads = gradients.gradients(y, two) - with self.test_session() as sess: self.assertAllEqual([4.0], sess.run(grads)[0]) def testFunctionGradientsComposition(self): with ops.Graph().as_default(): + f = function.Defun(x=tf.float32)(self.XSquarePlusOne) two = tf.constant([2.0], name="two") - y = XSquarePlusOne(XSquarePlusOne(two)) + y = f(f(two)) # Build gradient graph (should add SymbolicGradient node for function). grads = gradients.gradients(y, two) |