diff options
-rw-r--r-- | tensorflow/python/saved_model/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/python/saved_model/utils.py | 16 | ||||
-rw-r--r-- | tensorflow/python/saved_model/utils_test.py | 69 |
3 files changed, 91 insertions, 8 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 4b07451374..c58f87badb 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -82,8 +82,20 @@ py_library( name = "utils", srcs = ["utils.py"], srcs_version = "PY2AND3", + deps = ["//tensorflow/core:protos_all_py"], +) + +py_test( + name = "utils_test", + size = "small", + srcs = [ + "utils_test.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], deps = [ - "//tensorflow/core:protos_all_py", + ":utils", + "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py index 6d80e4f1bd..550eed0fcc 100644 --- a/tensorflow/python/saved_model/utils.py +++ b/tensorflow/python/saved_model/utils.py @@ -21,23 +21,25 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.framework import dtypes # TensorInfo helpers. -def build_tensor_info(name=None, dtype=None, shape=None): +def build_tensor_info(tensor): """Utility function to build TensorInfo proto. Args: - name: Name of the tensor to be used in the TensorInfo. - dtype: Datatype to be set in the TensorInfo. - shape: TensorShapeProto to specify the shape of the tensor in the - TensorInfo. + tensor: Tensor whose name, dtype and shape are used to build the TensorInfo. Returns: - A TensorInfo protocol buffer constructed based on the supplied arguments. + A TensorInfo protocol buffer constructed based on the supplied argument. """ - return meta_graph_pb2.TensorInfo(name=name, dtype=dtype, shape=shape) + dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum + return meta_graph_pb2.TensorInfo( + name=tensor.name, + dtype=dtype_enum, + tensor_shape=tensor.get_shape().as_proto()) # SignatureDef helpers. diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py new file mode 100644 index 0000000000..8ce7d1dea1 --- /dev/null +++ b/tensorflow/python/saved_model/utils_test.py @@ -0,0 +1,69 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModel utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.core.framework import types_pb2 +from tensorflow.python.saved_model import utils + + +class UtilsTest(tf.test.TestCase): + + def testBuildTensorInfo(self): + x = tf.placeholder(tf.float32, 1, name="x") + x_tensor_info = utils.build_tensor_info(x) + self.assertEqual("x:0", x_tensor_info.name) + self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info.dtype) + self.assertEqual(1, len(x_tensor_info.tensor_shape.dim)) + self.assertEqual(1, x_tensor_info.tensor_shape.dim[0].size) + + def testBuildSignatureDef(self): + x = tf.placeholder(tf.float32, 1, name="x") + x_tensor_info = utils.build_tensor_info(x) + inputs = dict() + inputs["foo-input"] = x_tensor_info + + y = tf.placeholder(tf.float32, name="y") + y_tensor_info = utils.build_tensor_info(y) + outputs = dict() + outputs["foo-output"] = y_tensor_info + + signature_def = utils.build_signature_def(inputs, outputs, + "foo-method-name") + self.assertEqual("foo-method-name", signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = signature_def.inputs["foo-input"] + self.assertEqual("x:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype) + self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim)) + self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size) + + # Check outputs in signature def. + self.assertEqual(1, len(signature_def.outputs)) + y_tensor_info_actual = signature_def.outputs["foo-output"] + self.assertEqual("y:0", y_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype) + self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) + + +if __name__ == "__main__": + tf.test.main() |