aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/remote_fused_graph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-16 11:37:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-16 11:41:16 -0700
commit8f2bf53b4bc2a43d00f157f06545e04749ff35da (patch)
tree147152f5e783389f13923f55f03e68b553a58a5b /tensorflow/contrib/remote_fused_graph
parenta66de1eca225bc95e7972974a7089d84df8a8055 (diff)
Add python support of remote fused graph ops to contrib
PiperOrigin-RevId: 159254265
Diffstat (limited to 'tensorflow/contrib/remote_fused_graph')
-rw-r--r--tensorflow/contrib/remote_fused_graph/README.md8
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/BUILD56
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/__init__.py33
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py19
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py19
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py66
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py66
7 files changed, 267 insertions, 0 deletions
diff --git a/tensorflow/contrib/remote_fused_graph/README.md b/tensorflow/contrib/remote_fused_graph/README.md
new file mode 100644
index 0000000000..267cfa1019
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/README.md
@@ -0,0 +1,8 @@
+# Remote Fused Graph
+
+## Description
+
+This module contains libraries for remote fused graph utilities
+
+Maintainers:
+- Satoshi Kataoka (satok@google.com, github.com/satok16)
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD
new file mode 100644
index 0000000000..c7ed663131
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD
@@ -0,0 +1,56 @@
+# Description:
+# Contains ops for remote fused graph
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
+tf_gen_op_wrapper_py(
+ name = "gen_remote_fused_graph_ops",
+ out = "python/ops/gen_remote_fused_graph_ops.py",
+ deps = [
+ "//tensorflow/core:remote_fused_graph_ops_op_lib",
+ ],
+)
+
+py_library(
+ name = "remote_fused_graph_ops_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_remote_fused_graph_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+py_test(
+ name = "remote_fused_graph_ops_test",
+ size = "small",
+ srcs = ["python/ops/remote_fused_graph_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":remote_fused_graph_ops_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:nn_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/__init__.py
new file mode 100644
index 0000000000..4d23c38932
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/pylib/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Remote fused graph ops python library.
+
+## This package provides classes for remote fused graph ops.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.contrib.remote_fused_graph.pylib.python.ops.remote_fused_graph_ops import *
+# pylint: enable=unused-import,wildcard-import,line-too-long
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = ['remote_fused_graph_execute']
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py
new file mode 100644
index 0000000000..b66091f875
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Remote fused graph ops python library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py
new file mode 100644
index 0000000000..b66091f875
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Remote fused graph ops python library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py
new file mode 100644
index 0000000000..2054367f0d
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py
@@ -0,0 +1,66 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operations to execute a subgraph on a remote processor."""
+
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.contrib.remote_fused_graph.pylib.python.ops import gen_remote_fused_graph_ops
+from tensorflow.core.framework import remote_fused_graph_execute_info_pb2 as info_pb2
+# pylint: enable=unused-import,wildcard-import,line-too-long
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+
+# RemoteFusedGraphExecute is not differenciable op.
+ops.NotDifferentiable("RemoteFusedGraphExecute")
+
+
+def remote_fused_graph_execute(inputs,
+ output_types,
+ graph_def,
+ graph_input_node_names,
+ graph_output_node_names,
+ executor_name,
+ serialized_executor_parameters,
+ default_graph_input_tensor_type_shapes=None,
+ default_graph_output_tensor_type_shapes=None):
+ """A wrapper for remote_fused_graph_execute."""
+ info_proto = info_pb2.RemoteFusedGraphExecuteInfo()
+ info_proto.remote_graph.CopyFrom(graph_def)
+ info_proto.graph_input_node_name.extend(graph_input_node_names)
+ info_proto.graph_output_node_name.extend(graph_output_node_names)
+ info_proto.executor_name = executor_name
+ info_proto.serialized_executor_parameters = serialized_executor_parameters
+ if default_graph_input_tensor_type_shapes:
+ for type_shape in default_graph_input_tensor_type_shapes:
+ type_shape_proto = info_proto.default_graph_input_tensor_shape.add()
+ type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0]))
+ for dim in type_shape[1]:
+ type_shape_proto.shape.dim.add().size = dim
+ if default_graph_output_tensor_type_shapes:
+ for type_shape in default_graph_output_tensor_type_shapes:
+ type_shape_proto = info_proto.default_graph_output_tensor_shape.add()
+ type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0]))
+ for dim in type_shape[1]:
+ type_shape_proto.shape.dim.add().size = dim
+
+ serialized_info = info_proto.SerializeToString()
+
+ return gen_remote_fused_graph_ops.remote_fused_graph_execute(
+ inputs, output_types, serialized_info)
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py
new file mode 100644
index 0000000000..45df909148
--- /dev/null
+++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py
@@ -0,0 +1,66 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.remote_fused_graph_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.contrib.remote_fused_graph.pylib.python.ops import remote_fused_graph_ops
+# pylint: enable=unused-import,wildcard-import,line-too-long
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class RemoteFusedGraphExecuteTest(test_util.TensorFlowTestCase):
+ """Tests for RemoteFusedGraphExecute op."""
+
+ def testBuild(self):
+ graph = graph_pb2.GraphDef()
+ node = graph.node.add()
+ node.name = "a"
+ node.op = "op0"
+ node = graph.node.add()
+ node.name = "b"
+ node.op = "op1"
+ inputs = [ops.convert_n_to_tensor([1], dtypes.int64)]
+ output_types = [np.int64, np.int64]
+ graph_input_node_names = ["a"]
+ graph_output_node_names = ["a", "b"]
+ executor_name = ""
+ serialized_executor_parameters = b""
+ default_graph_input_tensor_type_shapes = [[dtypes.int64, [1]]]
+ default_graph_output_tensor_type_shapes = [[dtypes.int64, [1]],
+ [dtypes.int64, [1]]]
+
+ output_nodes = remote_fused_graph_ops.remote_fused_graph_execute(
+ inputs, output_types, graph, graph_input_node_names,
+ graph_output_node_names, executor_name, serialized_executor_parameters,
+ default_graph_input_tensor_type_shapes,
+ default_graph_output_tensor_type_shapes)
+ self.assertEqual(2, len(output_nodes))
+ for output_node in output_nodes:
+ with self.test_session(use_gpu=False):
+ output_node.eval()
+
+
+if __name__ == "__main__":
+ googletest.main()