aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2015-12-15 17:34:21 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-15 17:34:21 -0800
commit881dc225ecb32064681c7bf2229d796565ad7956 (patch)
tree556e1398df01704aa07adf8be0c95cc42584b8ed
parente361930cbcea67f8f1b742cabf5950818b3bc11d (diff)
Implement `tensor_util.ConstantValue()` support for `tf.concat()`.
This should make more shape inference possible, when concat is used to build, e.g., shape vectors. Change: 110303876
-rw-r--r--tensorflow/python/framework/tensor_util.py11
-rw-r--r--tensorflow/python/framework/tensor_util_test.py21
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 56fe4c5561..af6ae4df5a 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -555,5 +555,16 @@ def ConstantValue(tensor):
return None
cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
return pre_cast.astype(cast_dtype.as_numpy_dtype)
+ elif tensor.op.type == "Concat":
+ dim = ConstantValue(tensor.op.inputs[0])
+ if dim is None:
+ return None
+ values = []
+ for x in tensor.op.inputs[1:]:
+ value = ConstantValue(x)
+ if value is None:
+ return None
+ values.append(value)
+ return np.concatenate(values, axis=dim)
else:
return None
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 0e2847faf9..6d3823741d 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -412,5 +412,26 @@ class ConstantValueTest(test_util.TensorFlowTestCase):
c_val = tensor_util.ConstantValue(tf_val)
self.assertAllClose(np_val.astype(np.float64), c_val)
+ def testConcat(self):
+ np_val = np.random.rand(3, 4, 7).astype(np.float32)
+ tf_val = array_ops.concat(
+ 0, [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]])
+ c_val = tensor_util.ConstantValue(tf_val)
+ self.assertAllClose(np_val, c_val)
+
+ tf_val = array_ops.concat(
+ array_ops.placeholder(dtypes.int32),
+ [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]])
+ c_val = tensor_util.ConstantValue(tf_val)
+ self.assertIs(None, c_val)
+
+ tf_val = array_ops.concat(
+ 1,
+ [np_val[0, :, :], array_ops.placeholder(dtypes.float32),
+ np_val[2, :, :]])
+ c_val = tensor_util.ConstantValue(tf_val)
+ self.assertIs(None, c_val)
+
+
if __name__ == "__main__":
googletest.main()