aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/saved_model/BUILD14
-rw-r--r--tensorflow/python/saved_model/utils.py16
-rw-r--r--tensorflow/python/saved_model/utils_test.py69
3 files changed, 91 insertions, 8 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 4b07451374..c58f87badb 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -82,8 +82,20 @@ py_library(
name = "utils",
srcs = ["utils.py"],
srcs_version = "PY2AND3",
+ deps = ["//tensorflow/core:protos_all_py"],
+)
+
+py_test(
+ name = "utils_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
deps = [
- "//tensorflow/core:protos_all_py",
+ ":utils",
+ "//tensorflow:tensorflow_py",
],
)
diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py
index 6d80e4f1bd..550eed0fcc 100644
--- a/tensorflow/python/saved_model/utils.py
+++ b/tensorflow/python/saved_model/utils.py
@@ -21,23 +21,25 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.framework import dtypes
# TensorInfo helpers.
-def build_tensor_info(name=None, dtype=None, shape=None):
+def build_tensor_info(tensor):
"""Utility function to build TensorInfo proto.
Args:
- name: Name of the tensor to be used in the TensorInfo.
- dtype: Datatype to be set in the TensorInfo.
- shape: TensorShapeProto to specify the shape of the tensor in the
- TensorInfo.
+ tensor: Tensor whose name, dtype and shape are used to build the TensorInfo.
Returns:
- A TensorInfo protocol buffer constructed based on the supplied arguments.
+ A TensorInfo protocol buffer constructed based on the supplied argument.
"""
- return meta_graph_pb2.TensorInfo(name=name, dtype=dtype, shape=shape)
+ dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
+ return meta_graph_pb2.TensorInfo(
+ name=tensor.name,
+ dtype=dtype_enum,
+ tensor_shape=tensor.get_shape().as_proto())
# SignatureDef helpers.
diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py
new file mode 100644
index 0000000000..8ce7d1dea1
--- /dev/null
+++ b/tensorflow/python/saved_model/utils_test.py
@@ -0,0 +1,69 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SavedModel utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.saved_model import utils
+
+
+class UtilsTest(tf.test.TestCase):
+
+ def testBuildTensorInfo(self):
+ x = tf.placeholder(tf.float32, 1, name="x")
+ x_tensor_info = utils.build_tensor_info(x)
+ self.assertEqual("x:0", x_tensor_info.name)
+ self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info.dtype)
+ self.assertEqual(1, len(x_tensor_info.tensor_shape.dim))
+ self.assertEqual(1, x_tensor_info.tensor_shape.dim[0].size)
+
+ def testBuildSignatureDef(self):
+ x = tf.placeholder(tf.float32, 1, name="x")
+ x_tensor_info = utils.build_tensor_info(x)
+ inputs = dict()
+ inputs["foo-input"] = x_tensor_info
+
+ y = tf.placeholder(tf.float32, name="y")
+ y_tensor_info = utils.build_tensor_info(y)
+ outputs = dict()
+ outputs["foo-output"] = y_tensor_info
+
+ signature_def = utils.build_signature_def(inputs, outputs,
+ "foo-method-name")
+ self.assertEqual("foo-method-name", signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = signature_def.inputs["foo-input"]
+ self.assertEqual("x:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
+ self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
+ self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(signature_def.outputs))
+ y_tensor_info_actual = signature_def.outputs["foo-output"]
+ self.assertEqual("y:0", y_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
+ self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
+
+
+if __name__ == "__main__":
+ tf.test.main()