diff options
-rw-r--r-- | tensorflow/contrib/specs/python/specs_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/scalar_strict_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 19 |
3 files changed, 4 insertions, 21 deletions
diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py index a25532ab41..d0e650bc5f 100644 --- a/tensorflow/contrib/specs/python/specs_test.py +++ b/tensorflow/contrib/specs/python/specs_test.py @@ -118,8 +118,8 @@ class SpecsTest(tf.test.TestCase): result = outputs.eval() self.assertEqual(tuple(result.shape), (10, 30)) self.assertEqual(summaries.tf_spec_structure(spec, inputs), - "_ _ var dot var biasadd sig " - "<> var dot var biasadd sig concat") + "_ var dot var biasadd sig " + "<> var dot var biasadd sig _ concatv2") def testImport(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/scalar_strict_test.py b/tensorflow/python/kernel_tests/scalar_strict_test.py index 1ad13c1c68..eec6ec3f94 100644 --- a/tensorflow/python/kernel_tests/scalar_strict_test.py +++ b/tensorflow/python/kernel_tests/scalar_strict_test.py @@ -65,7 +65,7 @@ class ScalarStrictTest(tf.test.TestCase): def testConcat(self): self.check(tf.concat, ([0], ([2], [3], [7])), - 'concat_dim tensor should be a scalar integer', [2, 3, 7]) + 'axis tensor should be a scalar integer', [2, 3, 7]) for data in (2, 3, 7), (2, [3], 7), (2, 3, [7]): self.check(tf.concat, (0, data), r'Expected \w+ dimensions in the range \[0, 0\)', [2, 3, 7]) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 788d689ee9..724479c2e7 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1111,24 +1111,7 @@ def concat(concat_dim, values, name="concat"): Returns: A `Tensor` resulting from concatenation of the input tensors. """ - # TODO(annarev): switch to call concat_v2 instead. - if not isinstance(values, (list, tuple)): - values = [values] - # TODO(mrry): Change to return values? - if len(values) == 1: # Degenerate case of one tensor. - # Make a throwaway call to convert_to_tensor to make sure - # that axis is of the correct type, and make sure that - # the returned tensor is a scalar. - # TODO(keveman): Implement a standalone type and shape checker. - with ops.name_scope(name) as scope: - ops.convert_to_tensor(concat_dim, - name="concat_dim", - dtype=dtypes.int32).get_shape( - ).assert_is_compatible_with(tensor_shape.scalar()) - return identity(values[0], name=scope) - return gen_array_ops._concat(concat_dim=concat_dim, - values=values, - name=name) + return concat_v2(values, concat_dim, name) def boolean_mask(tensor, mask, name="boolean_mask"): |