aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler/tf_optimizer_test.py
blob: eca0f679829507212608e75f2c792b4bddf9b1da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper tf_optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class PyWrapOptimizeGraphTest(test.TestCase):
  """Tests for tf_optimizer.OptimizeGraph on small hand-built graphs.

  Each test builds a graph and exports it as a MetaGraphDef. It then runs
  OptimizeGraph over it with `rewriter_config.min_graph_nodes = -1` and
  inspects the returned GraphDef. The -1 presumably stops grappler from
  skipping these tiny graphs; confirm against the RewriterConfig proto docs.
  """

  def testBasic(self):
    """Make sure arguments can be passed correctly."""
    a = constant_op.constant(10, name='a')
    b = constant_op.constant(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    # Registering 'd' as a train op makes it a fetch node, so it has to be
    # kept by the optimizer.
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    rewriter_config = rewriter_config_pb2.RewriterConfig()
    # Only constant folding is requested. Every node here is derived from
    # constants, so the graph is expected to fold down to the fetch node 'd'.
    rewriter_config.optimizers.append('constfold')
    rewriter_config.min_graph_nodes = -1

    graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

    self.assertEqual(len(graph.node), 1)
    self.assertItemsEqual([node.name for node in graph.node], ['d'])

  def testKeepNodes(self):
    """Nodes referenced from graph collections must survive optimization."""
    g = ops.Graph()
    with g.as_default():
      a1 = variables.VariableV1(
          1.0)  # Must be preserved since it's in the collection 'variables'.
      a2 = constant_op.constant(0, shape=[50, 50], name='keep')
      ops.add_to_collection('a2', a2)  # Explicitly add to collection.
      b = constant_op.constant(1, shape=[100, 10])
      c = constant_op.constant(0, shape=[10, 30])
      d = math_ops.matmul(b, c)
      ops.add_to_collection(ops.GraphKeys.TRAIN_OP, d)  # d is the fetch node.

    # Optimize the graph.
    mg = meta_graph.create_meta_graph_def(graph=g)
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

    # Check that the nodes referenced in various collections have been
    # preserved, in exactly this order. b and c are used only as inputs of
    # d, so they are not expected to remain. A single list comparison
    # reports the full node list on failure.
    expected_names = [
        d.op.name,
        a1.op.name,
        'Variable/initial_value',
        a2.op.name,
        'Variable/Assign',
    ]
    self.assertEqual(expected_names,
                     [node.name for node in optimized_graph.node])

  def testLoops(self):
    """A while-loop output and a value derived from it get equal properties.

    The loop appends one element to `buf` per iteration, so its shape
    invariant is relaxed to a vector of unknown length. `f` is computed
    elementwise from the loop output. After optimization, the op properties
    reported by grappler's Item for `buf` and for `f` must be equal.
    """
    g = ops.Graph()
    with g.as_default():

      def _Cond(_, counter):
        # `end` is a closure variable. It is looked up when while_loop calls
        # this function below, by which point `end` has been defined.
        return counter < end

      def _Body(buf, counter):
        buf = array_ops.concat([buf, [counter]], 0)
        counter += 1
        return [buf, counter]

      start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
      loop_vars = [init_buf, start]
      # The buffer grows every iteration, so its invariant is a vector of
      # unknown length. The counter stays a scalar.
      shape_inv = [
          tensor_shape.TensorShape([None]),
          tensor_shape.TensorShape([])
      ]
      buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv)

      f = -array_ops.ones_like(buf, optimize=False)
      buf_shape = array_ops.shape(buf)
      f_shape = array_ops.shape(f)
      # Registering both shapes as train ops makes them fetch nodes. This
      # keeps `buf` and `f` in the optimized graph, so the property lookups
      # by op name below can succeed.
      ops.add_to_collection(ops.GraphKeys.TRAIN_OP, buf_shape)
      ops.add_to_collection(ops.GraphKeys.TRAIN_OP, f_shape)

    # Optimize the graph, then swap the optimized GraphDef into the
    # MetaGraphDef so that the grappler Item below is built from it.
    mg = meta_graph.create_meta_graph_def(graph=g)
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
    mg.graph_def.CopyFrom(optimized_graph)

    # Check that the properties reported for the loop output and for the
    # value derived from it agree in the optimized graph.
    item = gitem.Item(mg)
    props = item.GetOpProperties()
    buf_prop = props[buf.op.name]
    f_prop = props[f.op.name]
    self.assertEqual(buf_prop, f_prop)


if __name__ == '__main__':
  # When run as a script, hand control to TensorFlow's test entry point
  # (tensorflow.python.platform.test.main) to execute this module's tests.
  test.main()