diff options
author | 2017-04-25 14:49:17 -0800 | |
---|---|---|
committer | 2017-04-25 16:08:07 -0700 | |
commit | 58fe576e207705929794968535735cbb9ac65db0 (patch) | |
tree | 4e67a24b9e96f8b4786ab62fc20642cd784b2c23 /tensorflow/python | |
parent | 546befc40895749bc81b8d59c24fdf375a9a5034 (diff) |
Added a python API to the meta graph optimizer
Change: 154232702
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer.i | 91 | ||||
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer.py | 35 | ||||
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer_test.py | 53 | ||||
-rw-r--r-- | tensorflow/python/tensorflow.i | 2 |
5 files changed, 210 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c367d20f81..ec0a34008a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2576,6 +2576,7 @@ tf_py_wrap_cc( "client/tf_session.i", "framework/cpp_shape_inference.i", "framework/python_op_gen.i", + "grappler/tf_optimizer.i", "lib/core/py_func.i", "lib/core/strings.i", "lib/io/file_io.i", @@ -2604,6 +2605,9 @@ tf_py_wrap_cc( "//tensorflow/c:tf_status_helper", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler/optimizers:meta_optimizer", "//tensorflow/core:lib", "//tensorflow/core:reader_base", "//tensorflow/core/debug", @@ -3539,3 +3543,28 @@ cuda_py_test( ], main = "client/session_benchmark.py", ) + +py_library( + name = "tf_optimizer", + srcs = [ + "grappler/tf_optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [":pywrap_tensorflow_internal"], +) + +py_test( + name = "tf_optimizer_test", + size = "small", + srcs = ["grappler/tf_optimizer_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], # tf_optimizer is not available in pip. + deps = [ + ":client_testlib", + ":framework_for_generated_wrappers", + ":math_ops", + ":tf_optimizer", + "//tensorflow/core:protos_all_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i new file mode 100644 index 0000000000..ab887e63e5 --- /dev/null +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +%include "tensorflow/python/platform/base.i" + +%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { + char* c_string; + Py_ssize_t py_size; + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + + if (!temp.ParseFromString(string(c_string, py_size))) { + PyErr_SetString( + PyExc_TypeError, + "The MetaGraphDef could not be parsed as a valid protocol buffer"); + SWIG_fail; + } + $1 = &temp; +} + +%typemap(in) const tensorflow::RewriterConfig& ( + tensorflow::RewriterConfig temp) { + char* c_string; + Py_ssize_t py_size; + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + + if (!temp.ParseFromString(string(c_string, py_size))) { + PyErr_SetString( + PyExc_TypeError, + "The RewriterConfig could not be parsed as a valid protocol buffer"); + SWIG_fail; + } + $1 = &temp; +} + +%{ + #include <memory> + #include "tensorflow/c/tf_status_helper.h" + #include "tensorflow/core/lib/core/status.h" + #include "tensorflow/core/framework/graph.pb.h" + #include "tensorflow/core/grappler/grappler_item.h" + #include "tensorflow/core/grappler/grappler_item_builder.h" + #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" + #include "tensorflow/core/protobuf/meta_graph.pb.h" + #include "tensorflow/core/protobuf/rewriter_config.pb.h" + +PyObject* TF_OptimizeGraph( + const tensorflow::RewriterConfig& rewriter_config, + const tensorflow::MetaGraphDef& metagraph, + const string& graph_id, TF_Status* out_status) { + const tensorflow::grappler::ItemConfig item_config; + std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item = + tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); + tensorflow::GraphDef out_graph; + tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer( + *grappler_item, rewriter_config, &out_graph); + tensorflow::Set_TF_Status_from_Status(out_status, status); + string out_graph_str = out_graph.SerializeAsString(); + PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(), + out_graph_str.size()); + return ret; + } +%} + + +// Wrap this function +PyObject* TF_OptimizeGraph( + const tensorflow::RewriterConfig& rewriter_config, + const tensorflow::MetaGraphDef& metagraph, + const string& graph_id, TF_Status* out_status); + + + diff --git a/tensorflow/python/grappler/tf_optimizer.py b/tensorflow/python/grappler/tf_optimizer.py new file mode 100644 index 0000000000..d0464c6054 --- /dev/null +++ b/tensorflow/python/grappler/tf_optimizer.py @@ -0,0 +1,35 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Provides a proper python API for the symbols exported through swig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python import pywrap_tensorflow as tf_opt +from tensorflow.python.framework import errors + + +def OptimizeGraph(rewriter_config, metagraph, graph_id=b'graph_to_optimize'): + """Optimize the provided metagraph.""" + with errors.raise_exception_on_not_ok_status() as status: + ret_from_swig = tf_opt.TF_OptimizeGraph(rewriter_config.SerializeToString(), + metagraph.SerializeToString(), + graph_id, status) + if ret_from_swig is None: + return None + out_graph = graph_pb2.GraphDef().FromString(ret_from_swig) + return out_graph diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py new file mode 100644 index 0000000000..b1efc2dbfb --- /dev/null +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the swig wrapper tf_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class PyWrapOptimizeGraphTest(test.TestCase): + + def testBasic(self): + """Make sure arguments can be passed correctly.""" + a = constant_op.constant(10, name='a') + b = constant_op.constant(20, name='b') + c = math_ops.add_n([a, b], name='c') + d = math_ops.add_n([b, c], name='d') + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) + + rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.optimizers.append('constfold') + + graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) + + self.assertEqual(len(graph.node), 5) + self.assertItemsEqual([node.name for node in graph.node], + ['a', 'b', 'c', 'd', 'ConstantFolding/c']) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index a0009031ac..5c2ad417e2 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -40,3 +40,5 @@ limitations under the License. %include "tensorflow/python/util/kernel_registry.i" %include "tensorflow/python/util/transform_graph.i" + +%include "tensorflow/python/grappler/tf_optimizer.i" |