diff options
author | 2015-12-15 17:34:21 -0800 | |
---|---|---|
committer | 2015-12-15 17:34:21 -0800 | |
commit | 881dc225ecb32064681c7bf2229d796565ad7956 (patch) | |
tree | 556e1398df01704aa07adf8be0c95cc42584b8ed | |
parent | e361930cbcea67f8f1b742cabf5950818b3bc11d (diff) |
Implement `tensor_util.ConstantValue()` support for `tf.concat()`.
This should make more shape inference possible, when concat is used to
build, e.g., shape vectors.
Change: 110303876
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 11 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_util_test.py | 21 |
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 56fe4c5561..af6ae4df5a 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -555,5 +555,16 @@ def ConstantValue(tensor): return None cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) return pre_cast.astype(cast_dtype.as_numpy_dtype) + elif tensor.op.type == "Concat": + dim = ConstantValue(tensor.op.inputs[0]) + if dim is None: + return None + values = [] + for x in tensor.op.inputs[1:]: + value = ConstantValue(x) + if value is None: + return None + values.append(value) + return np.concatenate(values, axis=dim) else: return None diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 0e2847faf9..6d3823741d 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -412,5 +412,26 @@ class ConstantValueTest(test_util.TensorFlowTestCase): c_val = tensor_util.ConstantValue(tf_val) self.assertAllClose(np_val.astype(np.float64), c_val) + def testConcat(self): + np_val = np.random.rand(3, 4, 7).astype(np.float32) + tf_val = array_ops.concat( + 0, [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]]) + c_val = tensor_util.ConstantValue(tf_val) + self.assertAllClose(np_val, c_val) + + tf_val = array_ops.concat( + array_ops.placeholder(dtypes.int32), + [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]]) + c_val = tensor_util.ConstantValue(tf_val) + self.assertIs(None, c_val) + + tf_val = array_ops.concat( + 1, + [np_val[0, :, :], array_ops.placeholder(dtypes.float32), + np_val[2, :, :]]) + c_val = tensor_util.ConstantValue(tf_val) + self.assertIs(None, c_val) + + if __name__ == "__main__": googletest.main() |