aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py4
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model_test.py9
2 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index b952a72aab..5dad49f1ed 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes):
"""
if shapes:
for tensor in tensors:
- shape = shapes.get(tensor.name)
+ shape = shapes.get(tensor_name(tensor))
if shape is not None:
- tensor.set_shape(shapes[tensor.name])
+ tensor.set_shape(shape)
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index 80e5dc6e46..1e570d2c89 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -73,10 +73,15 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
self.assertEqual([None, 3, 5], tensor.shape.as_list())
- convert_saved_model.set_tensor_shapes([tensor],
- {"Placeholder:0": [5, 3, 5]})
+ convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
self.assertEqual([5, 3, 5], tensor.shape.as_list())
+ def testSetTensorShapeNoneValid(self):
+ tensor = array_ops.placeholder(dtype=dtypes.float32)
+
+ convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
+ self.assertEqual([1, 3, 5], tensor.shape.as_list())
+
def testSetTensorShapeInvalid(self):
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
self.assertEqual([None, 3, 5], tensor.shape.as_list())