aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/saved_model_cli_test.py
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2017-12-28 13:57:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-28 14:01:06 -0800
commit2e2715baa84720f786b38d1f9cb6887399020d6f (patch)
tree2b1b5b3f367595a39575bc0a9defa2219d2f4bb2 /tensorflow/python/tools/saved_model_cli_test.py
parent711b10c280534c0ab73351bb4fd3e7ec32585236 (diff)
Fix saved_model_cli _print_tensor_info for REF types.
Fix #15611. PiperOrigin-RevId: 180292752
Diffstat (limited to 'tensorflow/python/tools/saved_model_cli_test.py')
-rw-r--r--tensorflow/python/tools/saved_model_cli_test.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index a55cf168b2..0789e1e107 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -28,6 +28,8 @@ import sys
import numpy as np
from six import StringIO
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.platform import test
from tensorflow.python.tools import saved_model_cli
@@ -200,6 +202,14 @@ Method name is: tensorflow/serving/predict"""
self.assertEqual(output, expected_output)
self.assertEqual(err.getvalue().strip(), '')
+ def testPrintREFTypeTensor(self):
+ ref_tensor_info = meta_graph_pb2.TensorInfo()
+ ref_tensor_info.dtype = types_pb2.DT_FLOAT_REF
+ with captured_output() as (out, err):
+ saved_model_cli._print_tensor_info(ref_tensor_info)
+ self.assertTrue('DT_FLOAT_REF' in out.getvalue().strip())
+ self.assertEqual(err.getvalue().strip(), '')
+
def testInputPreProcessFormats(self):
input_str = 'input1=/path/file.txt[ab3];input2=file2'
input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]'
@@ -217,7 +227,6 @@ Method name is: tensorflow/serving/predict"""
input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
r'input:0=c:\PROGRA~1\data.npy')
input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
- print(input_dict)
self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz',
'v:0'))
self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))