diff options
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model_test.py | 9 |
2 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index b952a72aab..5dad49f1ed 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes): """ if shapes: for tensor in tensors: - shape = shapes.get(tensor.name) + shape = shapes.get(tensor_name(tensor)) if shape is not None: - tensor.set_shape(shapes[tensor.name]) + tensor.set_shape(shape) def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index 80e5dc6e46..1e570d2c89 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -73,10 +73,15 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) - convert_saved_model.set_tensor_shapes([tensor], - {"Placeholder:0": [5, 3, 5]}) + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) self.assertEqual([5, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeInvalid(self): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) |