aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-04-25 14:49:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-25 16:08:07 -0700
commit58fe576e207705929794968535735cbb9ac65db0 (patch)
tree4e67a24b9e96f8b4786ab62fc20642cd784b2c23 /tensorflow/python
parent546befc40895749bc81b8d59c24fdf375a9a5034 (diff)
Added a python API to the meta graph optimizer
Change: 154232702
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD29
-rw-r--r--tensorflow/python/grappler/tf_optimizer.i91
-rw-r--r--tensorflow/python/grappler/tf_optimizer.py35
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py53
-rw-r--r--tensorflow/python/tensorflow.i2
5 files changed, 210 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c367d20f81..ec0a34008a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2576,6 +2576,7 @@ tf_py_wrap_cc(
"client/tf_session.i",
"framework/cpp_shape_inference.i",
"framework/python_op_gen.i",
+ "grappler/tf_optimizer.i",
"lib/core/py_func.i",
"lib/core/strings.i",
"lib/io/file_io.i",
@@ -2604,6 +2605,9 @@ tf_py_wrap_cc(
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:grappler_item_builder",
+ "//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core:lib",
"//tensorflow/core:reader_base",
"//tensorflow/core/debug",
@@ -3539,3 +3543,28 @@ cuda_py_test(
],
main = "client/session_benchmark.py",
)
+
+py_library(
+ name = "tf_optimizer",
+ srcs = [
+ "grappler/tf_optimizer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [":pywrap_tensorflow_internal"],
+)
+
+py_test(
+ name = "tf_optimizer_test",
+ size = "small",
+ srcs = ["grappler/tf_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"], # tf_optimizer is not available in pip.
+ deps = [
+ ":client_testlib",
+ ":framework_for_generated_wrappers",
+ ":math_ops",
+ ":tf_optimizer",
+ "//tensorflow/core:protos_all_py",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
new file mode 100644
index 0000000000..ab887e63e5
--- /dev/null
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+%include "tensorflow/python/platform/base.i"
+
+%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+
+ if (!temp.ParseFromString(string(c_string, py_size))) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "The MetaGraphDef could not be parsed as a valid protocol buffer");
+ SWIG_fail;
+ }
+ $1 = &temp;
+}
+
+%typemap(in) const tensorflow::RewriterConfig& (
+ tensorflow::RewriterConfig temp) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+
+ if (!temp.ParseFromString(string(c_string, py_size))) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "The RewriterConfig could not be parsed as a valid protocol buffer");
+ SWIG_fail;
+ }
+ $1 = &temp;
+}
+
+%{
+ #include <memory>
+ #include "tensorflow/c/tf_status_helper.h"
+ #include "tensorflow/core/lib/core/status.h"
+ #include "tensorflow/core/framework/graph.pb.h"
+ #include "tensorflow/core/grappler/grappler_item.h"
+ #include "tensorflow/core/grappler/grappler_item_builder.h"
+ #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+ #include "tensorflow/core/protobuf/meta_graph.pb.h"
+ #include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+PyObject* TF_OptimizeGraph(
+ const tensorflow::RewriterConfig& rewriter_config,
+ const tensorflow::MetaGraphDef& metagraph,
+ const string& graph_id, TF_Status* out_status) {
+ const tensorflow::grappler::ItemConfig item_config;
+ std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
+ tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
+ tensorflow::GraphDef out_graph;
+ tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
+ *grappler_item, rewriter_config, &out_graph);
+ tensorflow::Set_TF_Status_from_Status(out_status, status);
+ string out_graph_str = out_graph.SerializeAsString();
+ PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
+ out_graph_str.size());
+ return ret;
+ }
+%}
+
+
+// Wrap this function
+PyObject* TF_OptimizeGraph(
+ const tensorflow::RewriterConfig& rewriter_config,
+ const tensorflow::MetaGraphDef& metagraph,
+ const string& graph_id, TF_Status* out_status);
+
+
+
diff --git a/tensorflow/python/grappler/tf_optimizer.py b/tensorflow/python/grappler/tf_optimizer.py
new file mode 100644
index 0000000000..d0464c6054
--- /dev/null
+++ b/tensorflow/python/grappler/tf_optimizer.py
@@ -0,0 +1,35 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Provides a proper python API for the symbols exported through swig."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python import pywrap_tensorflow as tf_opt
+from tensorflow.python.framework import errors
+
+
+def OptimizeGraph(rewriter_config, metagraph, graph_id=b'graph_to_optimize'):
+ """Optimize the provided metagraph."""
+ with errors.raise_exception_on_not_ok_status() as status:
+ ret_from_swig = tf_opt.TF_OptimizeGraph(rewriter_config.SerializeToString(),
+ metagraph.SerializeToString(),
+ graph_id, status)
+ if ret_from_swig is None:
+ return None
+ out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
+ return out_graph
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
new file mode 100644
index 0000000000..b1efc2dbfb
--- /dev/null
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the swig wrapper tf_optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class PyWrapOptimizeGraphTest(test.TestCase):
+
+ def testBasic(self):
+ """Make sure arguments can be passed correctly."""
+ a = constant_op.constant(10, name='a')
+ b = constant_op.constant(20, name='b')
+ c = math_ops.add_n([a, b], name='c')
+ d = math_ops.add_n([b, c], name='d')
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(d)
+ mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.optimizers.append('constfold')
+
+ graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
+ self.assertEqual(len(graph.node), 5)
+ self.assertItemsEqual([node.name for node in graph.node],
+ ['a', 'b', 'c', 'd', 'ConstantFolding/c'])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index a0009031ac..5c2ad417e2 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -40,3 +40,5 @@ limitations under the License.
%include "tensorflow/python/util/kernel_registry.i"
%include "tensorflow/python/util/transform_graph.i"
+
+%include "tensorflow/python/grappler/tf_optimizer.i"