diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-16 11:37:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-16 11:41:16 -0700 |
commit | 8f2bf53b4bc2a43d00f157f06545e04749ff35da (patch) | |
tree | 147152f5e783389f13923f55f03e68b553a58a5b /tensorflow/contrib/remote_fused_graph | |
parent | a66de1eca225bc95e7972974a7089d84df8a8055 (diff) |
Add python support of remote fused graph ops to contrib
PiperOrigin-RevId: 159254265
Diffstat (limited to 'tensorflow/contrib/remote_fused_graph')
7 files changed, 267 insertions, 0 deletions
diff --git a/tensorflow/contrib/remote_fused_graph/README.md b/tensorflow/contrib/remote_fused_graph/README.md new file mode 100644 index 0000000000..267cfa1019 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/README.md @@ -0,0 +1,8 @@ +# Remote Fused Graph + +## Description + +This module contains libraries for remote fused graph utilities + +Maintainers: +- Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD new file mode 100644 index 0000000000..c7ed663131 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD @@ -0,0 +1,56 @@ +# Description: +# Contains ops for remote fused graph + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +tf_gen_op_wrapper_py( + name = "gen_remote_fused_graph_ops", + out = "python/ops/gen_remote_fused_graph_ops.py", + deps = [ + "//tensorflow/core:remote_fused_graph_ops_op_lib", + ], +) + +py_library( + name = "remote_fused_graph_ops_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs_version = "PY2AND3", + deps = [ + ":gen_remote_fused_graph_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +py_test( + name = "remote_fused_graph_ops_test", + size = "small", + srcs = ["python/ops/remote_fused_graph_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":remote_fused_graph_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn_ops", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/remote_fused_graph/pylib/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/__init__.py new file mode 100644 index 0000000000..4d23c38932 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Remote fused graph ops python library. + +## This package provides classes for remote fused graph ops. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.remote_fused_graph.pylib.python.ops.remote_fused_graph_ops import * +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['remote_fused_graph_execute'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py new file mode 100644 index 0000000000..b66091f875 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Remote fused graph ops python library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py new file mode 100644 index 0000000000..b66091f875 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Remote fused graph ops python library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py new file mode 100644 index 0000000000..2054367f0d --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operations to execute a subgraph on a remote processor.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.remote_fused_graph.pylib.python.ops import gen_remote_fused_graph_ops +from tensorflow.core.framework import remote_fused_graph_execute_info_pb2 as info_pb2 +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + +# RemoteFusedGraphExecute is not differenciable op. +ops.NotDifferentiable("RemoteFusedGraphExecute") + + +def remote_fused_graph_execute(inputs, + output_types, + graph_def, + graph_input_node_names, + graph_output_node_names, + executor_name, + serialized_executor_parameters, + default_graph_input_tensor_type_shapes=None, + default_graph_output_tensor_type_shapes=None): + """A wrapper for remote_fused_graph_execute.""" + info_proto = info_pb2.RemoteFusedGraphExecuteInfo() + info_proto.remote_graph.CopyFrom(graph_def) + info_proto.graph_input_node_name.extend(graph_input_node_names) + info_proto.graph_output_node_name.extend(graph_output_node_names) + info_proto.executor_name = executor_name + info_proto.serialized_executor_parameters = serialized_executor_parameters + if default_graph_input_tensor_type_shapes: + for type_shape in default_graph_input_tensor_type_shapes: + type_shape_proto = info_proto.default_graph_input_tensor_shape.add() + type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + for dim in type_shape[1]: + type_shape_proto.shape.dim.add().size = dim + if default_graph_output_tensor_type_shapes: + for type_shape in default_graph_output_tensor_type_shapes: + type_shape_proto = info_proto.default_graph_output_tensor_shape.add() + type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + for dim in type_shape[1]: + type_shape_proto.shape.dim.add().size = dim + + serialized_info = info_proto.SerializeToString() + + return gen_remote_fused_graph_ops.remote_fused_graph_execute( + inputs, output_types, serialized_info) diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py new file mode 100644 index 0000000000..45df909148 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.remote_fused_graph_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.remote_fused_graph.pylib.python.ops import remote_fused_graph_ops +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class RemoteFusedGraphExecuteTest(test_util.TensorFlowTestCase): + """Tests for RemoteFusedGraphExecute op.""" + + def testBuild(self): + graph = graph_pb2.GraphDef() + node = graph.node.add() + node.name = "a" + node.op = "op0" + node = graph.node.add() + node.name = "b" + node.op = "op1" + inputs = [ops.convert_n_to_tensor([1], dtypes.int64)] + output_types = [np.int64, np.int64] + graph_input_node_names = ["a"] + graph_output_node_names = ["a", "b"] + executor_name = "" + serialized_executor_parameters = b"" + default_graph_input_tensor_type_shapes = [[dtypes.int64, [1]]] + default_graph_output_tensor_type_shapes = [[dtypes.int64, [1]], + [dtypes.int64, [1]]] + + output_nodes = remote_fused_graph_ops.remote_fused_graph_execute( + inputs, output_types, graph, graph_input_node_names, + graph_output_node_names, executor_name, serialized_executor_parameters, + default_graph_input_tensor_type_shapes, + default_graph_output_tensor_type_shapes) + self.assertEqual(2, len(output_nodes)) + for output_node in output_nodes: + with self.test_session(use_gpu=False): + output_node.eval() + + +if __name__ == "__main__": + googletest.main() |