aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2016-12-08 11:09:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-08 11:24:51 -0800
commita5eb3d676d7edb62217c9b1e7001540d867f29eb (patch)
treeb6a837ddb3a7f49645fe88264df7137582e5a871
parentb59082b8afa5130d455e97b017cc4548201cb658 (diff)
SparseTensor.shape -> SparseTensor.dense_shape part 2...
Change: 141460319
-rw-r--r--tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py73
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py6
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py2
-rw-r--r--tensorflow/contrib/layers/python/ops/sparse_ops_test.py28
-rw-r--r--tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/svm_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/dataframe/boolean_mask_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/dataframe/example_parser_test.py6
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py22
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py2
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py6
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py4
-rw-r--r--tensorflow/python/client/session_test.py12
-rw-r--r--tensorflow/python/framework/sparse_tensor_test.py7
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py4
-rw-r--r--tensorflow/python/kernel_tests/ctc_decoder_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sets_test.py8
-rw-r--r--tensorflow/python/kernel_tests/sparse_add_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/sparse_concat_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/sparse_reorder_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/sparse_reshape_op_test.py22
-rw-r--r--tensorflow/python/kernel_tests/sparse_serialization_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py2
-rw-r--r--tensorflow/python/ops/array_ops.py3
-rw-r--r--tensorflow/python/ops/confusion_matrix.py2
-rw-r--r--tensorflow/python/ops/control_flow_grad.py5
-rw-r--r--tensorflow/python/ops/control_flow_ops.py6
-rw-r--r--tensorflow/python/ops/math_ops.py20
-rw-r--r--tensorflow/python/ops/metrics.py8
-rw-r--r--tensorflow/python/ops/parsing_ops.py10
-rw-r--r--tensorflow/python/training/input_test.py150
36 files changed, 236 insertions, 233 deletions
diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
index 3bdfc3e81b..c856a952dd 100644
--- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
+++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
@@ -420,12 +420,12 @@ class SparseCrossOpTest(tf.test.TestCase):
self.assertEquals(0, sp.indices.size)
self.assertEquals(0, sp.values.size)
# TODO(zakaria): check if we can ignore the first dim of the shape.
- self.assertEquals(0, sp.shape[1])
+ self.assertEquals(0, sp.dense_shape[1])
def _assert_sparse_tensor_equals(self, sp1, sp2):
self.assertAllEqual(sp1.indices.eval(), sp2.indices)
self.assertAllEqual(sp1.values.eval(), sp2.values)
- self.assertAllEqual(sp1.shape.eval(), sp2.shape)
+ self.assertAllEqual(sp1.dense_shape.eval(), sp2.dense_shape)
def _sparse_tensor(self, data, batch_size=-1):
"""Generates a SparseTensor.
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 8255b69ac8..032ea57cf5 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -91,8 +91,8 @@ class TransformerTest(tf.test.TestCase):
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
self.assertAllEqual(output[hashed_sparse].indices.eval(),
wire_tensor.indices.eval())
- self.assertAllEqual(output[hashed_sparse].shape.eval(),
- wire_tensor.shape.eval())
+ self.assertAllEqual(output[hashed_sparse].dense_shape.eval(),
+ wire_tensor.dense_shape.eval())
def testSparseIntColumnWithHashBucket(self):
"""Tests a sparse column with int values."""
@@ -113,8 +113,8 @@ class TransformerTest(tf.test.TestCase):
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
self.assertAllEqual(output[hashed_sparse].indices.eval(),
wire_tensor.indices.eval())
- self.assertAllEqual(output[hashed_sparse].shape.eval(),
- wire_tensor.shape.eval())
+ self.assertAllEqual(output[hashed_sparse].dense_shape.eval(),
+ wire_tensor.dense_shape.eval())
def testSparseColumnWithHashBucketWithDenseInputTensor(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
@@ -129,7 +129,7 @@ class TransformerTest(tf.test.TestCase):
self.assertTrue(all(x < 10 and x >= 0 for x in output.values.eval()))
self.assertAllEqual(output.indices.eval(),
[[0, 0], [0, 1], [1, 0], [1, 1]])
- self.assertAllEqual(output.shape.eval(), [2, 2])
+ self.assertAllEqual(output.dense_shape.eval(), [2, 2])
def testEmbeddingColumn(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
@@ -144,7 +144,8 @@ class TransformerTest(tf.test.TestCase):
with self.test_session():
self.assertAllEqual(output.values.eval(), expected.values.eval())
self.assertAllEqual(output.indices.eval(), expected.indices.eval())
- self.assertAllEqual(output.shape.eval(), expected.shape.eval())
+ self.assertAllEqual(
+ output.dense_shape.eval(), expected.dense_shape.eval())
# Test transform features.
output = tf.contrib.layers.transform_features(
@@ -170,8 +171,8 @@ class TransformerTest(tf.test.TestCase):
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[keys_sparse].indices.eval(),
wire_tensor.indices.eval())
- self.assertAllEqual(output[keys_sparse].shape.eval(),
- wire_tensor.shape.eval())
+ self.assertAllEqual(output[keys_sparse].dense_shape.eval(),
+ wire_tensor.dense_shape.eval())
def testSparseColumnWithKeysWithDenseInputTensor(self):
keys_sparse = tf.contrib.layers.sparse_column_with_keys(
@@ -189,7 +190,7 @@ class TransformerTest(tf.test.TestCase):
self.assertAllEqual(output.values.eval(), [1, 2, 0, 3])
self.assertAllEqual(output.indices.eval(),
[[0, 0], [0, 1], [1, 0], [1, 1]])
- self.assertAllEqual(output.shape.eval(), [2, 2])
+ self.assertAllEqual(output.dense_shape.eval(), [2, 2])
def testSparseColumnWithHashBucket_IsIntegerized(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_integerized_feature(
@@ -209,15 +210,15 @@ class TransformerTest(tf.test.TestCase):
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
self.assertAllEqual(output[hashed_sparse].indices.eval(),
wire_tensor.indices.eval())
- self.assertAllEqual(output[hashed_sparse].shape.eval(),
- wire_tensor.shape.eval())
+ self.assertAllEqual(output[hashed_sparse].dense_shape.eval(),
+ wire_tensor.dense_shape.eval())
def testSparseColumnWithHashBucketWithDenseInputTensor_IsIntegerized(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_integerized_feature(
"wire", 10)
# wire_tensor = tf.SparseTensor(values=[100, 1, 25],
# indices=[[0, 0], [1, 0], [1, 1]],
- # shape=[2, 2])
+ # dense_shape=[2, 2])
wire_tensor = tf.constant([[100, 0], [1, 25]])
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
@@ -228,7 +229,7 @@ class TransformerTest(tf.test.TestCase):
self.assertTrue(all(x < 10 and x >= 0 for x in output.values.eval()))
self.assertAllEqual(output.indices.eval(),
[[0, 0], [0, 1], [1, 0], [1, 1]])
- self.assertAllEqual(output.shape.eval(), [2, 2])
+ self.assertAllEqual(output.dense_shape.eval(), [2, 2])
def testWeightedSparseColumn(self):
ids = tf.contrib.layers.sparse_column_with_keys(
@@ -250,13 +251,13 @@ class TransformerTest(tf.test.TestCase):
print(output)
with self.test_session():
tf.initialize_all_tables().run()
- self.assertAllEqual(output[weighted_ids][0].shape.eval(),
- ids_tensor.shape.eval())
+ self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
+ ids_tensor.dense_shape.eval())
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
ids_tensor.indices.eval())
self.assertAllEqual(output[weighted_ids][0].values.eval(), [2, 2, 0])
- self.assertAllEqual(output[weighted_ids][1].shape.eval(),
- weights_tensor.shape.eval())
+ self.assertAllEqual(output[weighted_ids][1].dense_shape.eval(),
+ weights_tensor.dense_shape.eval())
self.assertAllEqual(output[weighted_ids][1].indices.eval(),
weights_tensor.indices.eval())
self.assertEqual(output[weighted_ids][1].values.dtype, tf.float32)
@@ -395,7 +396,7 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
"ids": tf.SparseTensor(
values=["c", "b", "a"],
indices=[[0, 0], [1, 0], [2, 0]],
- shape=[3, 1]),
+ dense_shape=[3, 1]),
"income": tf.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]])
}
output = tf.contrib.layers.input_from_feature_columns(features,
@@ -487,13 +488,13 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
ids_tensor = tf.SparseTensor(
values=["c", "b", "a", "c"],
indices=[[0, 0], [1, 0], [2, 0], [2, 1]],
- shape=[3, 2])
+ dense_shape=[3, 2])
weighted_ids_column = tf.contrib.layers.weighted_sparse_column(ids_column,
"weights")
weights_tensor = tf.SparseTensor(
values=[10.0, 20.0, 30.0, 40.0],
indices=[[0, 0], [1, 0], [2, 0], [2, 1]],
- shape=[3, 2])
+ dense_shape=[3, 2])
features = {"ids": ids_tensor, "weights": weights_tensor}
one_hot_column = tf.contrib.layers.one_hot_column(weighted_ids_column)
with self.test_session():
@@ -528,7 +529,7 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
ids_tensor = tf.SparseTensor(
values=["c", "b", "a", "c"],
indices=[[0, 0], [1, 0], [2, 0], [2, 1]],
- shape=[3, 2])
+ dense_shape=[3, 2])
one_hot_sparse = tf.contrib.layers.one_hot_column(ids_column)
features = {"ids": ids_tensor}
output = tf.contrib.layers.input_from_feature_columns(features,
@@ -547,7 +548,7 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
features = {"ids": tf.SparseTensor(
values=[2, 1, 0, 2],
indices=[[0, 0], [1, 0], [2, 0], [2, 1]],
- shape=[3, 2])}
+ dense_shape=[3, 2])}
output = tf.contrib.layers.input_from_feature_columns(features,
[one_hot_sparse])
with self.test_session():
@@ -560,7 +561,7 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
wire_tensor = tf.SparseTensor(
values=["a", "b", "c1", "c2"],
indices=[[0, 0], [1, 0], [2, 0], [2, 1]],
- shape=[3, 2])
+ dense_shape=[3, 2])
features = {"feat": wire_tensor}
one_hot_sparse = tf.contrib.layers.one_hot_column(hashed_sparse)
output = tf.contrib.layers.input_from_feature_columns(features,
@@ -575,7 +576,7 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
wire_tensor = tf.SparseTensor(
values=["omar", "stringer", "marlo", "xx", "yy"],
indices=[[0, 0], [1, 0], [1, 1], [2, 0], [3, 0]],
- shape=[4, 2])
+ dense_shape=[4, 2])
features = {"wire": wire_tensor}
embeded_sparse = tf.contrib.layers.embedding_column(hashed_sparse, 10)
output = tf.contrib.layers.input_from_feature_columns(features,
@@ -994,7 +995,7 @@ class SequenceInputFromFeatureColumnTest(tf.test.TestCase):
indices=[[0, 0, 0], [0, 1, 0],
[1, 0, 0], [1, 0, 1], [1, 1, 0],
[3, 2, 0]],
- shape=[4, 3, 2])
+ dense_shape=[4, 3, 2])
ids_column = tf.contrib.layers.sparse_column_with_keys(
"ids", ["a", "b", "c", "unseen"])
@@ -1027,7 +1028,7 @@ class SequenceInputFromFeatureColumnTest(tf.test.TestCase):
indices=[[0, 0, 0], [0, 1, 0],
[1, 0, 0], [1, 0, 1], [1, 1, 0],
[3, 2, 0]],
- shape=[4, 3, 2])
+ dense_shape=[4, 3, 2])
hashed_ids_column = tf.contrib.layers.sparse_column_with_hash_bucket(
"ids", hash_buckets)
@@ -1054,7 +1055,7 @@ class SequenceInputFromFeatureColumnTest(tf.test.TestCase):
indices=[[0, 0, 0], [0, 1, 0],
[1, 0, 0], [1, 0, 1], [1, 1, 0],
[3, 2, 0]],
- shape=[4, 3, 2])
+ dense_shape=[4, 3, 2])
expected_input_shape = np.array([4, 3, embedding_dimension])
@@ -1083,7 +1084,7 @@ class SequenceInputFromFeatureColumnTest(tf.test.TestCase):
indices=[[0, 0, 0], [0, 1, 0],
[1, 0, 0], [1, 0, 1], [1, 1, 0],
[3, 2, 0]],
- shape=[4, 3, 2])
+ dense_shape=[4, 3, 2])
hashed_ids_column = tf.contrib.layers.sparse_column_with_hash_bucket(
"ids", hash_buckets)
@@ -1128,7 +1129,7 @@ class SequenceInputFromFeatureColumnTest(tf.test.TestCase):
indices=[[0, 0, 0], [0, 1, 0],
[1, 0, 0], [1, 0, 1], [1, 1, 0],
[3, 2, 0]],
- shape=[4, 3, 2])
+ dense_shape=[4, 3, 2])
id_tensor = tf.SparseTensor(
values=[2, 5,
26, 123, 1,
@@ -1136,7 +1137,7 @@ class SequenceInputFromFeatureColumnTest(tf.test.TestCase):
indices=[[0, 0, 0], [0, 0, 1], [0, 1, 1],
[1, 0, 0], [1, 1, 0],
[3, 2, 0]],
- shape=[4, 3, 2])
+ dense_shape=[4, 3, 2])
columns_to_tensors = {"measurements": measurement_tensor,
"country": country_tensor,
@@ -1481,7 +1482,7 @@ class WeightedSumTest(tf.test.TestCase):
"movies": tf.SparseTensor(
values=["matrix", "head-on", "winter sleep"],
indices=[[0, 0], [0, 1], [1, 0]],
- shape=[2, 2])
+ dense_shape=[2, 2])
}
output, column_to_variable, _ = (
tf.contrib.layers.weighted_sum_from_feature_columns(features,
@@ -1924,7 +1925,7 @@ class WeightedSumTest(tf.test.TestCase):
"language": tf.SparseTensor(
values=["hindi", "english", "arabic", "russian"],
indices=[[0, 0], [1, 0], [2, 0], [3, 0]],
- shape=[4, 1])
+ dense_shape=[4, 1])
}
output, column_to_variable, _ = (
tf.contrib.layers.weighted_sum_from_feature_columns(features,
@@ -1976,7 +1977,7 @@ class WeightedSumTest(tf.test.TestCase):
"language": tf.SparseTensor(
values=["english", "spanish", "russian", "swahili"],
indices=[[0, 0], [1, 0], [2, 0], [3, 0]],
- shape=[4, 1]),
+ dense_shape=[4, 1]),
"country": tf.SparseTensor(values=["US", "SV", "RU", "KE"],
indices=[[0, 0], [1, 0], [2, 0], [3, 0]],
dense_shape=[4, 1])
@@ -2005,7 +2006,7 @@ class WeightedSumTest(tf.test.TestCase):
"language": tf.SparseTensor(
values=["hindi", "english", "turkish", "turkish", "english"],
indices=[[0, 0], [0, 1], [1, 0], [2, 0], [3, 0]],
- shape=[4, 2])
+ dense_shape=[4, 2])
}
output, column_to_variable, _ = (
tf.contrib.layers.weighted_sum_from_feature_columns(features,
@@ -2135,13 +2136,13 @@ class ParseExampleTest(tf.test.TestCase):
self.assertAllEqual(location_val.indices, np.array([[0]]))
self.assertAllEqual(location_val.values, np.array([b"west_side"]))
- self.assertAllEqual(location_val.shape, np.array([1]))
+ self.assertAllEqual(location_val.dense_shape, np.array([1]))
self.assertAllEqual(wire_cast_val.indices, np.array(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0]]))
self.assertAllEqual(wire_cast_val.values, np.array(
[b"marlo", b"stringer", b"omar", b"stringer", b"marlo", b"marlo"]))
- self.assertAllEqual(wire_cast_val.shape, np.array([3, 3]))
+ self.assertAllEqual(wire_cast_val.dense_shape, np.array([3, 3]))
self.assertAllClose(
measurement_val, np.array([[0.2, 0.3], [0.1, 0.8], [0.5, 0.0]]))
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 12c0bdbcde..53de49f67a 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1182,7 +1182,7 @@ def flatten(inputs,
Returns:
a flattened tensor with shape [batch_size, k].
Raises:
- ValueError: if inputs.shape is wrong.
+ ValueError: if inputs.dense_shape is wrong.
"""
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
@@ -1200,8 +1200,8 @@ def flatten(inputs,
def _sparse_inner_flatten(inputs, new_rank):
"""Helper function for `inner_flatten`."""
- outer_dimensions = inputs.shape[:new_rank - 1]
- inner_dimensions = inputs.shape[new_rank - 1:]
+ outer_dimensions = inputs.dense_shape[:new_rank - 1]
+ inner_dimensions = inputs.dense_shape[new_rank - 1:]
new_shape = array_ops.concat_v2((outer_dimensions,
[math_ops.reduce_prod(inner_dimensions)]), 0)
flattened = sparse_ops.sparse_reshape(inputs, new_shape)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 009fbdc485..c2e4cf3846 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1380,7 +1380,7 @@ class PartialFlattenTest(tf.test.TestCase):
np.testing.assert_array_equal(expected_indices, flattened.indices)
np.testing.assert_array_equal(expected_values, flattened.values)
- np.testing.assert_array_equal(expected_shape, flattened.shape)
+ np.testing.assert_array_equal(expected_shape, flattened.dense_shape)
def testIncompleteShape(self):
"""Test `_inner_flatten` shape inference for incomplete shapes."""
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
index 9245bf7367..3a078b8ae6 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
@@ -32,10 +32,10 @@ class SparseOpsTest(tf.test.TestCase):
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
self.assertEqual(result.values.dtype, np.int32)
- self.assertEqual(result.shape.dtype, np.int64)
+ self.assertEqual(result.dense_shape.dtype, np.int64)
self.assertAllEqual([[0], [2]], result.indices)
self.assertAllEqual([1, 2], result.values)
- self.assertAllEqual([4], result.shape)
+ self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_float(self):
with self.test_session() as sess:
@@ -43,10 +43,10 @@ class SparseOpsTest(tf.test.TestCase):
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
self.assertEqual(result.values.dtype, np.float32)
- self.assertEqual(result.shape.dtype, np.int64)
+ self.assertEqual(result.dense_shape.dtype, np.int64)
self.assertAllEqual([[0], [2]], result.indices)
self.assertAllClose([1.5, 2.3], result.values)
- self.assertAllEqual([4], result.shape)
+ self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_bool(self):
with self.test_session() as sess:
@@ -54,10 +54,10 @@ class SparseOpsTest(tf.test.TestCase):
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
self.assertEqual(result.values.dtype, np.bool)
- self.assertEqual(result.shape.dtype, np.int64)
+ self.assertEqual(result.dense_shape.dtype, np.int64)
self.assertAllEqual([[0], [2]], result.indices)
self.assertAllEqual([True, True], result.values)
- self.assertAllEqual([4], result.shape)
+ self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str(self):
with self.test_session() as sess:
@@ -65,10 +65,10 @@ class SparseOpsTest(tf.test.TestCase):
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
self.assertEqual(result.values.dtype, np.object)
- self.assertEqual(result.shape.dtype, np.int64)
+ self.assertEqual(result.dense_shape.dtype, np.int64)
self.assertAllEqual([[0], [2]], result.indices)
self.assertAllEqual([b'qwe', b'ewq'], result.values)
- self.assertAllEqual([4], result.shape)
+ self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str_special_ignore(self):
with self.test_session() as sess:
@@ -77,10 +77,10 @@ class SparseOpsTest(tf.test.TestCase):
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
self.assertEqual(result.values.dtype, np.object)
- self.assertEqual(result.shape.dtype, np.int64)
+ self.assertEqual(result.dense_shape.dtype, np.int64)
self.assertAllEqual([[1], [2], [3]], result.indices)
self.assertAllEqual([b'', b'ewq', b''], result.values)
- self.assertAllEqual([4], result.shape)
+ self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_2d(self):
with self.test_session() as sess:
@@ -89,7 +89,7 @@ class SparseOpsTest(tf.test.TestCase):
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
result.indices)
self.assertAllEqual([1, 2, 3, 4, 5], result.values)
- self.assertAllEqual([2, 4], result.shape)
+ self.assertAllEqual([2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_3d(self):
with self.test_session() as sess:
@@ -99,7 +99,7 @@ class SparseOpsTest(tf.test.TestCase):
self.assertAllEqual([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [0, 1, 2],
[1, 0, 0], [1, 0, 1], [1, 1, 0]], result.indices)
self.assertAllEqual([1, 2, 3, 4, 5, 7, 8, 9], result.values)
- self.assertAllEqual([2, 2, 4], result.shape)
+ self.assertAllEqual([2, 2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_no_shape(self):
with self.test_session() as sess:
@@ -108,7 +108,7 @@ class SparseOpsTest(tf.test.TestCase):
result = sess.run(st, feed_dict={tensor: [0, 100, 0, 3]})
self.assertAllEqual([[1], [3]], result.indices)
self.assertAllEqual([100, 3], result.values)
- self.assertAllEqual([4], result.shape)
+ self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_3d_no_shape(self):
with self.test_session() as sess:
@@ -122,7 +122,7 @@ class SparseOpsTest(tf.test.TestCase):
self.assertAllEqual([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [0, 1, 2],
[1, 0, 0], [1, 0, 1], [1, 1, 0]], result.indices)
self.assertAllEqual([1, 2, 3, 4, 5, 7, 8, 9], result.values)
- self.assertAllEqual([2, 2, 4], result.shape)
+ self.assertAllEqual([2, 2, 4], result.dense_shape)
def test_convert_to_sparse_undef_shape(self):
with self.test_session():
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py
index d5b55ee359..61abf41339 100644
--- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py
+++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/densify.py
@@ -62,5 +62,4 @@ class Densify(transform.TensorFlowTransform):
# pylint: disable=not-callable
return self.return_type(sparse_ops.sparse_to_dense(
- s.indices, s.shape, s.values, default_value=self.default_value))
-
+ s.indices, s.dense_shape, s.values, default_value=self.default_value))
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index ff059a9727..2540abe248 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -63,7 +63,7 @@ class EmbeddingMultiplierTest(tf.test.TestCase):
tf.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [1, 0], [2, 0]],
- shape=[3, 1]),
+ dense_shape=[3, 1]),
}
labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
with self.assertRaisesRegexp(
@@ -94,12 +94,12 @@ class EmbeddingMultiplierTest(tf.test.TestCase):
tf.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [1, 0], [2, 0]],
- shape=[3, 1]),
+ dense_shape=[3, 1]),
'wire':
tf.SparseTensor(
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [2, 0]],
- shape=[3, 1]),
+ dense_shape=[3, 1]),
}
labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
model_ops = dnn_linear_combined._dnn_linear_combined_model_fn(
@@ -178,7 +178,7 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
features['dummy_sparse_column'] = tf.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [0, 1], [60, 0]],
- shape=[len(iris.target), 2])
+ dense_shape=[len(iris.target), 2])
labels = tf.reshape(tf.constant(iris.target, dtype=tf.int32), [-1, 1])
return features, labels
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 15b56dfb72..068e0a5afb 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -163,7 +163,7 @@ class RegressionModelHeadTest(tf.test.TestCase):
labels = tf.SparseTensor(
indices=tf.constant([[0, 0], [1, 0], [2, 0]], dtype=tf.int64),
values=tf.constant([0., 1., 1.]),
- shape=[3, 1])
+ dense_shape=[3, 1])
with self.assertRaisesRegexp(
ValueError, "SparseTensor is not supported as labels."):
head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN,
@@ -362,7 +362,7 @@ class BinaryClassificationModelHeadTest(tf.test.TestCase):
labels = tf.SparseTensor(
indices=tf.constant([[0, 0], [1, 0], [2, 0]], dtype=tf.int64),
values=tf.constant([0, 1, 1]),
- shape=[3, 1])
+ dense_shape=[3, 1])
with self.assertRaisesRegexp(
ValueError, "SparseTensor is not supported as labels."):
head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm_test.py b/tensorflow/contrib/learn/python/learn/estimators/svm_test.py
index d60a061b87..6d8c9599f6 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm_test.py
@@ -169,7 +169,7 @@ class SVMTest(tf.test.TestCase):
'country': tf.SparseTensor(
values=['IT', 'US', 'GB'],
indices=[[0, 0], [1, 0], [2, 0]],
- shape=[3, 1]),
+ dense_shape=[3, 1]),
}, tf.constant([[0], [1], [1]])
price = tf.contrib.layers.real_valued_column('price')
@@ -220,7 +220,7 @@ class SVMTest(tf.test.TestCase):
'country': tf.SparseTensor(
values=['IT', 'US', 'GB'],
indices=[[0, 0], [1, 3], [2, 1]],
- shape=[3, 5]),
+ dense_shape=[3, 5]),
'weights': tf.constant([[3.0], [1.0], [1.0]])
}, tf.constant([[1], [0], [1]])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py
index 84c0b7c2d2..8620a12e1b 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py
@@ -107,7 +107,7 @@ class TensorSignatureTest(tf.test.TestCase):
def testSparseTensorSignaturePlaceholders(self):
tensor = tf.SparseTensor(values=[1.0, 2.0], indices=[[0, 2], [0, 3]],
- shape=[5, 5])
+ dense_shape=[5, 5])
signature = tensor_signature.create_signatures(tensor)
placeholder = tensor_signature.create_placeholders_from_signatures(
signature)
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/boolean_mask_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/boolean_mask_test.py
index 1e3a069b6d..9a81e3e482 100644
--- a/tensorflow/contrib/learn/python/learn/tests/dataframe/boolean_mask_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/boolean_mask_test.py
@@ -71,7 +71,7 @@ class BooleanMaskTestCase(tf.test.TestCase):
np.testing.assert_array_equal(expected_indices, actual.indices)
np.testing.assert_array_equal(expected_values, actual.values)
- np.testing.assert_array_equal(shape, actual.shape)
+ np.testing.assert_array_equal(shape, actual.dense_shape)
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/example_parser_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/example_parser_test.py
index 099983b467..4d360364f3 100644
--- a/tensorflow/contrib/learn/python/learn/tests/dataframe/example_parser_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/example_parser_test.py
@@ -102,7 +102,8 @@ class ExampleParserTestCase(tf.test.TestCase):
with self.test_session() as sess:
string_feature, int_feature = sess.run(output_tensors)
- np.testing.assert_array_equal(string_feature.shape, np.array([2, 2]))
+ np.testing.assert_array_equal(
+ string_feature.dense_shape, np.array([2, 2]))
np.testing.assert_array_equal(int_feature.shape, np.array([2, 3]))
np.testing.assert_array_equal(self.expected_string_values,
string_feature.values)
@@ -121,7 +122,8 @@ class ExampleParserTestCase(tf.test.TestCase):
with self.test_session() as sess:
int_feature, string_feature = sess.run(output_tensors)
- np.testing.assert_array_equal(string_feature.shape, np.array([2, 2]))
+ np.testing.assert_array_equal(
+ string_feature.dense_shape, np.array([2, 2]))
np.testing.assert_array_equal(int_feature.shape, np.array([2, 3]))
np.testing.assert_array_equal(self.expected_string_values,
string_feature.values)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b0bf232f18..a0a185c6bf 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2294,7 +2294,7 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
sp_labels = tf.SparseTensorValue(
indices=np.array([[0,], [1,], [2,]], np.int64),
values=np.array([2, 7, 8], np.int64),
- shape=np.array([10,], np.int64))
+ dense_shape=np.array([10,], np.int64))
with self.assertRaises(ValueError):
precision, _ = metrics.streaming_sparse_precision_at_top_k(
@@ -2568,7 +2568,7 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# values -1 and 10 are outside the [0, n_classes) range and are ignored.
values=np.array([2, 7, -1, 8,
1, 2, 5, 10], np.int64),
- shape=[2, 4])
+ dense_shape=[2, 4])
# Class 2: 2 labels, 2 correct predictions.
self._test_streaming_sparse_precision_at_k(
@@ -3032,7 +3032,7 @@ class StreamingSparseRecallTest(tf.test.TestCase):
# values -1 and 10 are outside the [0, n_classes) range.
values=np.array([2, 7, -1, 8,
1, 2, 5, 10], np.int64),
- shape=[2, 4])
+ dense_shape=[2, 4])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
@@ -4628,7 +4628,7 @@ class NumRelevantTest(tf.test.TestCase):
(2, 1, 0), (2, 1, 1),
(2, 2, 0)),
values=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13),
- shape=(3, 3, 3))
+ dense_shape=(3, 3, 3))
self.assertAllEqual(
((1, 1, 0), (1, 1, 1), (0, 1, 1)),
metric_ops.num_relevant(labels, k=1).eval())
@@ -4659,7 +4659,7 @@ class ExpandAndTileTest(tf.test.TestCase):
indices=[
(i, j, k) for i in range(3) for j in range(3) for k in range(3)],
values=[1] * 27,
- shape=[3, 3, 3])
+ dense_shape=[3, 3, 3])
with self.assertRaisesRegexp(ValueError, 'nvalid multiple'):
metric_ops.expand_and_tile(x, multiple=0)
@@ -4756,7 +4756,7 @@ class ExpandAndTileTest(tf.test.TestCase):
def _assert_sparse_tensors_equal(self, expected, actual):
self.assertAllEqual(expected.indices, actual.indices)
self.assertAllEqual(expected.values, actual.values)
- self.assertAllEqual(expected.shape, actual.shape)
+ self.assertAllEqual(expected.dense_shape, actual.dense_shape)
# TODO(ptucker): Use @parameterized when it's available in tf.
def testSparseExpandAndTile1x(self):
@@ -4770,7 +4770,7 @@ class ExpandAndTileTest(tf.test.TestCase):
1, 2,
3, 4, 5,
6],
- shape=[3, 3])
+ dense_shape=[3, 3])
with self.test_session():
expected_result_dim0 = tf.SparseTensorValue(
indices=[[0, i[0], i[1]] for i in x.indices], values=x.values,
@@ -4810,12 +4810,12 @@ class ExpandAndTileTest(tf.test.TestCase):
1, 2,
3, 4, 5,
6),
- shape=(3, 3))
+ dense_shape=(3, 3))
with self.test_session():
expected_result_dim0 = tf.SparseTensorValue(
indices=[(d0, i[0], i[1]) for d0 in range(5) for i in x.indices],
values=[v for _ in range(5) for v in x.values],
- shape=(5, 3, 3))
+ dense_shape=(5, 3, 3))
self._assert_sparse_tensors_equal(
expected_result_dim0,
metric_ops.expand_and_tile(x, multiple=5).eval())
@@ -4831,7 +4831,7 @@ class ExpandAndTileTest(tf.test.TestCase):
for d1 in range(5)
for i in x.indices if i[0] == d0],
values=x.values[0:2] * 5 + x.values[2:5] * 5 + x.values[5:] * 5,
- shape=(3, 5, 3))
+ dense_shape=(3, 5, 3))
for dim in (-1, 1):
self._assert_sparse_tensors_equal(
expected_result_dim1,
@@ -4840,7 +4840,7 @@ class ExpandAndTileTest(tf.test.TestCase):
expected_result_dim2 = tf.SparseTensorValue(
indices=[(i[0], i[1], d2) for i in x.indices for d2 in range(5)],
values=[v for v in x.values for _ in range(5)],
- shape=(3, 3, 5))
+ dense_shape=(3, 3, 5))
self._assert_sparse_tensors_equal(
expected_result_dim2,
metric_ops.expand_and_tile(x, multiple=5, dim=2).eval())
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index d842e4380a..a2c71ae334 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -246,7 +246,7 @@ class SparseTensor(ItemHandler):
elif self._shape:
shape = self._shape
else:
- shape = indices.shape
+ shape = indices.dense_shape
indices_shape = array_ops.shape(indices.indices)
rank = indices_shape[1]
ids = math_ops.to_int64(indices.values)
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index f572f9c3e9..1fa826de38 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -471,7 +471,7 @@ class TFExampleDecoderTest(tf.test.TestCase):
labels = tf_labels.eval()
self.assertAllEqual(labels.indices, np_indices)
self.assertAllEqual(labels.values, np_values)
- self.assertAllEqual(labels.shape, np_values.shape)
+ self.assertAllEqual(labels.dense_shape, np_values.shape)
def testDecodeExampleWithSparseTensorWithKeyShape(self):
np_indices = np.array([[1], [2], [5]])
@@ -501,7 +501,7 @@ class TFExampleDecoderTest(tf.test.TestCase):
labels = tf_labels.eval()
self.assertAllEqual(labels.indices, np_indices)
self.assertAllEqual(labels.values, np_values)
- self.assertAllEqual(labels.shape, np_shape)
+ self.assertAllEqual(labels.dense_shape, np_shape)
def testDecodeExampleWithSparseTensorWithGivenShape(self):
np_indices = np.array([[1], [2], [5]])
@@ -529,7 +529,7 @@ class TFExampleDecoderTest(tf.test.TestCase):
labels = tf_labels.eval()
self.assertAllEqual(labels.indices, np_indices)
self.assertAllEqual(labels.values, np_values)
- self.assertAllEqual(labels.shape, np_shape)
+ self.assertAllEqual(labels.dense_shape, np_shape)
def testDecodeExampleWithSparseTensorToDense(self):
np_indices = np.array([1, 2, 5])
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 0a0f473855..d7919ae5d9 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -110,7 +110,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
-1., 2.,
1.,
-2.0],
- shape=[4, 10])
+ dense_shape=[4, 10])
input_labels = [0, 1, 2, 3]
params = tensor_forest.ForestHParams(
@@ -131,7 +131,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
-1., 2.,
1.,
-2.0],
- shape=[4, 10])
+ dense_shape=[4, 10])
params = tensor_forest.ForestHParams(
num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index bb0d0acbf5..13fda7aef4 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -554,7 +554,7 @@ class SessionTest(test_util.TensorFlowTestCase):
array_ops.placeholder(dtype=np.int64, shape=(3,)),)
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
- sp_shape = array_ops.identity(sp.shape)
+ sp_shape = array_ops.identity(sp.dense_shape)
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
@@ -595,7 +595,7 @@ class SessionTest(test_util.TensorFlowTestCase):
sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
- sp_shape = array_ops.identity(sp.shape)
+ sp_shape = array_ops.identity(sp.dense_shape)
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
@@ -626,7 +626,7 @@ class SessionTest(test_util.TensorFlowTestCase):
shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
- sp_shape = array_ops.identity(sp.shape)
+ sp_shape = array_ops.identity(sp.dense_shape)
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
@@ -656,11 +656,11 @@ class SessionTest(test_util.TensorFlowTestCase):
sp = array_ops.sparse_placeholder(dtype=np.float32,
shape=shape,
name='placeholder1')
- self.assertAllEqual(sp.shape.eval(session=s), shape)
- self.assertAllEqual(tensor_util.constant_value(sp.shape), shape)
+ self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
+ self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape)
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
- sp_shape = array_ops.identity(sp.shape)
+ sp_shape = array_ops.identity(sp.dense_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
[sp_indices, sp_values, sp_shape], {sp: (indices, values)})
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index afd815b498..19a2b187b9 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -45,11 +45,11 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
value = sp.eval()
self.assertAllEqual(indices, value.indices)
self.assertAllEqual(values, value.values)
- self.assertAllEqual(shape, value.shape)
+ self.assertAllEqual(shape, value.dense_shape)
sess_run_value = sess.run(sp)
self.assertAllEqual(sess_run_value.indices, value.indices)
self.assertAllEqual(sess_run_value.values, value.values)
- self.assertAllEqual(sess_run_value.shape, value.shape)
+ self.assertAllEqual(sess_run_value.dense_shape, value.dense_shape)
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
@@ -75,7 +75,8 @@ class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
for convertee in [from_value, from_tensor]:
self.assertAllEqual(sparse_tensor_value.indices, convertee.indices)
self.assertAllEqual(sparse_tensor_value.values, convertee.values)
- self.assertAllEqual(sparse_tensor_value.dense_shape, convertee.shape)
+ self.assertAllEqual(
+ sparse_tensor_value.dense_shape, convertee.dense_shape)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 952370d0f0..1a68ff5fba 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -817,7 +817,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase):
sp_value = tf.SparseTensorValue(
indices=((0, 1), (1, 0)),
values=(42, 24),
- shape=(2, 2))
+ dense_shape=(2, 2))
self.assertAllEqual((2, 2), tf.shape(sp_value).eval())
self.assertEqual(4, tf.size(sp_value).eval())
self.assertEqual(2, tf.rank(sp_value).eval())
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 4384dc03ca..3fd2ae4a9e 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -747,13 +747,13 @@ class ControlFlowTest(tf.test.TestCase):
return i < 10
def b(i, x):
return [i + 1, tf.SparseTensor(x.indices, x.values * 2.0,
- x.shape)]
+ x.dense_shape)]
_, r = tf.while_loop(c, b, [i, x])
self.assertEqual(r.dense_shape.get_shape()[0].value, 1)
_, r = tf.while_loop(c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([None])])
- self.assertTrue(r.shape.get_shape()[0].value is None)
+ self.assertTrue(r.dense_shape.get_shape()[0].value is None)
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
_, r = tf.while_loop(c, b, [i, x],
diff --git a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
index 69f47ea6cf..e01edc88c1 100644
--- a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
+++ b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
@@ -77,7 +77,7 @@ class CTCGreedyDecoderTest(tf.test.TestCase):
self.assertEqual([None, truth_st[0].shape[1]],
tf_st.indices.get_shape().as_list())
self.assertEqual([None], tf_st.values.get_shape().as_list())
- self.assertShapeEqual(truth_st[2], tf_st.shape)
+ self.assertShapeEqual(truth_st[2], tf_st.dense_shape)
# Make sure decoded probabilities match
self.assertAllClose(output_log_probability, log_prob_truth, atol=1e-6)
diff --git a/tensorflow/python/kernel_tests/sets_test.py b/tensorflow/python/kernel_tests/sets_test.py
index d5eb8181ba..44ec440d6c 100644
--- a/tensorflow/python/kernel_tests/sets_test.py
+++ b/tensorflow/python/kernel_tests/sets_test.py
@@ -389,7 +389,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual((expected_rows,),
result_sparse_tensor.values.get_shape().as_list())
self.assertAllEqual((expected_rank,),
- result_sparse_tensor.shape.get_shape().as_list())
+ result_sparse_tensor.dense_shape.get_shape().as_list())
def _set_intersection(self, a, b):
# Validate that we get the same results with or without `validate_indices`,
@@ -407,7 +407,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
for i in range(1, 4):
self.assertAllEqual(results[0].indices, results[i].indices)
self.assertAllEqual(results[0].values, results[i].values)
- self.assertAllEqual(results[0].shape, results[i].shape)
+ self.assertAllEqual(results[0].dense_shape, results[i].dense_shape)
return results[0]
def _set_intersection_count(self, a, b):
@@ -761,7 +761,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
for i in range(1, 4):
self.assertAllEqual(results[0].indices, results[i].indices)
self.assertAllEqual(results[0].values, results[i].values)
- self.assertAllEqual(results[0].shape, results[i].shape)
+ self.assertAllEqual(results[0].dense_shape, results[i].dense_shape)
return results[0]
def _set_difference_count(self, a, b, aminusb=True):
@@ -967,7 +967,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
for i in range(1, 4):
self.assertAllEqual(results[0].indices, results[i].indices)
self.assertAllEqual(results[0].values, results[i].values)
- self.assertAllEqual(results[0].shape, results[i].shape)
+ self.assertAllEqual(results[0].dense_shape, results[i].dense_shape)
return results[0]
def _set_union_count(self, a, b):
diff --git a/tensorflow/python/kernel_tests/sparse_add_op_test.py b/tensorflow/python/kernel_tests/sparse_add_op_test.py
index d1c7eeadc9..6184c4edc4 100644
--- a/tensorflow/python/kernel_tests/sparse_add_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_add_op_test.py
@@ -85,7 +85,7 @@ class SparseAddTest(tf.test.TestCase):
self.assertAllEqual(
sum_out.indices, [[0, 1], [1, 0], [2, 0], [2, 1]])
self.assertAllEqual(sum_out.values, [2, 4, 6, 8])
- self.assertAllEqual(sum_out.shape, [3, 3])
+ self.assertAllEqual(sum_out.dense_shape, [3, 3])
def testAddSelfAndNegation(self):
with self.test_session(use_gpu=False) as sess:
@@ -98,7 +98,7 @@ class SparseAddTest(tf.test.TestCase):
self.assertEqual(sp_sum.dense_shape.get_shape(), [2])
self.assertAllEqual(sum_out.indices, np.empty([0, 2]))
self.assertAllEqual(sum_out.values, [])
- self.assertAllEqual(sum_out.shape, [3, 3])
+ self.assertAllEqual(sum_out.dense_shape, [3, 3])
def testSmallValuesShouldVanish(self):
with self.test_session(use_gpu=False) as sess:
@@ -117,7 +117,7 @@ class SparseAddTest(tf.test.TestCase):
self.assertEqual(sp_sum.dense_shape.get_shape(), [2])
self.assertAllEqual(sum_out.indices, [[0, 1], [2, 0]])
self.assertAllEqual(sum_out.values, [2, 6])
- self.assertAllEqual(sum_out.shape, [3, 3])
+ self.assertAllEqual(sum_out.dense_shape, [3, 3])
# only .1 vanishes
sp_sum = tf.sparse_add(sp_a, sp_b, thresh=0.11)
@@ -126,7 +126,7 @@ class SparseAddTest(tf.test.TestCase):
self.assertEqual(sp_sum.dense_shape.get_shape(), [2])
self.assertAllEqual(sum_out.indices, [[0, 1], [2, 0], [2, 1]])
self.assertAllClose(sum_out.values, [2, 6, -.2])
- self.assertAllEqual(sum_out.shape, [3, 3])
+ self.assertAllEqual(sum_out.dense_shape, [3, 3])
def testGradients(self):
np.random.seed(1618) # Make it reproducible.
diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
index b65610dcc7..2b4bd24ff5 100644
--- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
@@ -144,7 +144,7 @@ class SparseConcatTest(tf.test.TestCase):
self.assertAllEqual(concat_out.indices,
[[0, 2], [1, 0], [2, 0], [2, 2]])
self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
- self.assertAllEqual(concat_out.shape, [3, 3])
+ self.assertAllEqual(concat_out.dense_shape, [3, 3])
def testConcat2(self):
with self.test_session(use_gpu=False) as sess:
@@ -167,7 +167,7 @@ class SparseConcatTest(tf.test.TestCase):
[2, 0], [2, 2], [2, 3],
[2, 6], [2, 7]])
self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
- self.assertAllEqual(concat_out.shape, [3, 8])
+ self.assertAllEqual(concat_out.dense_shape, [3, 8])
def testConcatDim0(self):
with self.test_session(use_gpu=False) as sess:
@@ -193,7 +193,7 @@ class SparseConcatTest(tf.test.TestCase):
concat_out.indices,
[[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
- self.assertAllEqual(concat_out.shape, np.array([5, 3]))
+ self.assertAllEqual(concat_out.dense_shape, np.array([5, 3]))
def testConcat3(self):
with self.test_session(use_gpu=False) as sess:
@@ -218,7 +218,7 @@ class SparseConcatTest(tf.test.TestCase):
[2, 0], [2, 2], [2, 3], [2, 6],
[2, 7], [2, 8]])
self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
- self.assertAllEqual(concat_out.shape, [3, 10])
+ self.assertAllEqual(concat_out.dense_shape, [3, 10])
def testConcatNonNumeric(self):
with self.test_session(use_gpu=False) as sess:
@@ -243,7 +243,7 @@ class SparseConcatTest(tf.test.TestCase):
[[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
self.assertAllEqual(concat_out.values,
[b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"])
- self.assertAllEqual(concat_out.shape, [3, 8])
+ self.assertAllEqual(concat_out.dense_shape, [3, 8])
def testMismatchedRank(self):
with self.test_session(use_gpu=False):
diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
index d665e8ed86..5c8c3fb433 100644
--- a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
@@ -52,7 +52,7 @@ class SparseReorderTest(tf.test.TestCase):
output_val = sess.run(sp_output)
self.assertAllEqual(output_val.indices, input_val.indices)
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, input_val.shape)
+ self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)
def testFeedAlreadyInOrder(self):
with self.test_session(use_gpu=False) as sess:
@@ -63,7 +63,7 @@ class SparseReorderTest(tf.test.TestCase):
output_val = sess.run(sp_output, {sp_input: input_val})
self.assertAllEqual(output_val.indices, input_val.indices)
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, input_val.shape)
+ self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)
def testOutOfOrder(self):
expected_output_val = self._SparseTensorValue_5x6(np.arange(6))
@@ -75,7 +75,8 @@ class SparseReorderTest(tf.test.TestCase):
output_val = sess.run(sp_output)
self.assertAllEqual(output_val.indices, expected_output_val.indices)
self.assertAllEqual(output_val.values, expected_output_val.values)
- self.assertAllEqual(output_val.shape, expected_output_val.shape)
+ self.assertAllEqual(
+ output_val.dense_shape, expected_output_val.dense_shape)
def testFeedOutOfOrder(self):
expected_output_val = self._SparseTensorValue_5x6(np.arange(6))
@@ -88,7 +89,8 @@ class SparseReorderTest(tf.test.TestCase):
output_val = sess.run(sp_output, {sp_input: input_val})
self.assertAllEqual(output_val.indices, expected_output_val.indices)
self.assertAllEqual(output_val.values, expected_output_val.values)
- self.assertAllEqual(output_val.shape, expected_output_val.shape)
+ self.assertAllEqual(
+ output_val.dense_shape, expected_output_val.dense_shape)
def testGradients(self):
with self.test_session(use_gpu=False):
diff --git a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
index 4b7e158d54..052a41dda3 100644
--- a/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_reshape_op_test.py
@@ -55,7 +55,7 @@ class SparseReshapeTest(tf.test.TestCase):
output_val = sess.run(sp_output)
self.assertAllEqual(output_val.indices, input_val.indices)
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, input_val.shape)
+ self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)
def testFeedSameShape(self):
with self.test_session(use_gpu=False) as sess:
@@ -66,7 +66,7 @@ class SparseReshapeTest(tf.test.TestCase):
output_val = sess.run(sp_output, {sp_input: input_val})
self.assertAllEqual(output_val.indices, input_val.indices)
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, input_val.shape)
+ self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)
def testFeedSameShapeWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
@@ -77,7 +77,7 @@ class SparseReshapeTest(tf.test.TestCase):
output_val = sess.run(sp_output, {sp_input: input_val})
self.assertAllEqual(output_val.indices, input_val.indices)
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, input_val.shape)
+ self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)
def testFeedNewShapeSameRank(self):
with self.test_session(use_gpu=False) as sess:
@@ -90,7 +90,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 0], [0, 6], [0, 9], [1, 0], [2, 0], [2, 1]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [3, 10])
+ self.assertAllEqual(output_val.dense_shape, [3, 10])
def testFeedNewShapeSameRankWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
@@ -103,7 +103,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 0], [0, 6], [0, 9], [1, 0], [2, 0], [2, 1]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [3, 10])
+ self.assertAllEqual(output_val.dense_shape, [3, 10])
def testUpRank(self):
with self.test_session(use_gpu=False) as sess:
@@ -115,7 +115,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [2, 3, 5])
+ self.assertAllEqual(output_val.dense_shape, [2, 3, 5])
def testFeedUpRank(self):
with self.test_session(use_gpu=False) as sess:
@@ -128,7 +128,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [2, 3, 5])
+ self.assertAllEqual(output_val.dense_shape, [2, 3, 5])
def testFeedUpRankWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
@@ -141,7 +141,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 0, 0], [0, 1, 1], [0, 1, 4], [0, 2, 0], [1, 1, 0], [1, 1, 1]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [2, 3, 5])
+ self.assertAllEqual(output_val.dense_shape, [2, 3, 5])
def testFeedDownRank(self):
with self.test_session(use_gpu=False) as sess:
@@ -154,7 +154,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 1], [1, 0], [1, 2], [3, 3], [4, 1], [4, 3], [5, 2]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [6, 4])
+ self.assertAllEqual(output_val.dense_shape, [6, 4])
def testFeedDownRankWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
@@ -167,7 +167,7 @@ class SparseReshapeTest(tf.test.TestCase):
[0, 1], [1, 0], [1, 2], [3, 3], [4, 1], [4, 3], [5, 2]
]))
self.assertAllEqual(output_val.values, input_val.values)
- self.assertAllEqual(output_val.shape, [6, 4])
+ self.assertAllEqual(output_val.dense_shape, [6, 4])
def testFeedMultipleInferredDims(self):
with self.test_session(use_gpu=False) as sess:
@@ -247,7 +247,7 @@ class SparseReshapeTest(tf.test.TestCase):
output_val = sess.run(sp_output, {sp_input: input_val})
self.assertAllEqual(output_val.indices, new_indices)
self.assertAllEqual(output_val.values, new_values)
- self.assertAllEqual(output_val.shape, new_shape)
+ self.assertAllEqual(output_val.dense_shape, new_shape)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
index 13b7fcc7c0..159c5d9d81 100644
--- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
@@ -127,7 +127,7 @@ class SerializeSparseTest(tf.test.TestCase):
self.assertEqual(serialized_value.shape, (4, 3))
self.assertAllEqual(deserialized_value.indices, indices_value)
self.assertAllEqual(deserialized_value.values, values_value)
- self.assertAllEqual(deserialized_value.shape, shape_value)
+ self.assertAllEqual(deserialized_value.dense_shape, shape_value)
def testDeserializeFailsWrongType(self):
with self.test_session(use_gpu=False) as sess:
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index 240644d228..08c41bc4b6 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -142,7 +142,7 @@ class SparseTensorsMapTest(tf.test.TestCase):
self.assertEqual(handles_value.shape, (4,))
self.assertAllEqual(roundtrip_value.indices, indices_value)
self.assertAllEqual(roundtrip_value.values, values_value)
- self.assertAllEqual(roundtrip_value.shape, shape_value)
+ self.assertAllEqual(roundtrip_value.dense_shape, shape_value)
def testDeserializeFailsInconsistentRank(self):
with self.test_session(use_gpu=False) as sess:
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index ff918a4d0b..e25a8449cb 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1697,8 +1697,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
indices=placeholder(
dtypes.int64, shape=[None, None],
name=(name + "/indices") if name is not None else None),
- shape=shape
- )
+ dense_shape=shape)
# pylint: enable=redefined-outer-name
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index 0b9e79c640..576b78b15f 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -157,7 +157,7 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
values = (array_ops.ones_like(predictions, dtype)
if weights is None else weights)
cm_sparse = sparse_tensor.SparseTensor(
- indices=indices, values=values, shape=math_ops.to_int64(shape))
+ indices=indices, values=values, dense_shape=math_ops.to_int64(shape))
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index 45004c2935..d74a5ded3c 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -154,10 +154,7 @@ def _ExitGrad(op, grad):
raise TypeError("Type %s not supported" % type(grad))
grad_ctxt.AddName(grad.values.name)
grad_ctxt.AddName(grad.indices.name)
- if isinstance(grad, ops.IndexedSlices):
- dense_shape = grad.dense_shape
- else:
- dense_shape = grad.shape
+ dense_shape = grad.dense_shape
if dense_shape is not None:
grad_ctxt.AddName(dense_shape.name)
enter_fn = control_flow_ops._Enter # pylint: disable=protected-access
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 728c9fe7b1..55548e7541 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -251,10 +251,10 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
dense_shape.set_shape(data.dense_shape.get_shape())
return ops.IndexedSlices(values, indices, dense_shape)
else:
- dense_shape = enter(data.shape, frame_name, is_constant,
+ dense_shape = enter(data.dense_shape, frame_name, is_constant,
parallel_iterations, name="dense_shape")
if use_input_shape:
- dense_shape.set_shape(data.shape.get_shape())
+ dense_shape.set_shape(data.dense_shape.get_shape())
return sparse_tensor.SparseTensor(indices, values, dense_shape)
@@ -334,7 +334,7 @@ def switch(data, pred, dtype=None, name=None):
else:
dense_shape = data.dense_shape
dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
- data.shape, pred, name="dense_shape")
+ data.dense_shape, pred, name="dense_shape")
return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f),
sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index f10cc4ed21..ce30937c7d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -296,10 +296,10 @@ def abs(x, name=None):
x_abs = gen_math_ops.complex_abs(
x.values, Tout=x.values.dtype.real_dtype, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_abs, dense_shape=x.shape)
+ indices=x.indices, values=x_abs, dense_shape=x.dense_shape)
x_abs = gen_math_ops._abs(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_abs, dense_shape=x.shape)
+ indices=x.indices, values=x_abs, dense_shape=x.dense_shape)
else:
x = ops.convert_to_tensor(x, name="x")
if x.dtype in (dtypes.complex64, dtypes.complex128):
@@ -336,7 +336,7 @@ def neg(x, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
x_neg = gen_math_ops.neg(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_neg, dense_shape=x.shape)
+ indices=x.indices, values=x_neg, dense_shape=x.dense_shape)
else:
return gen_math_ops.neg(x, name=name)
@@ -360,7 +360,7 @@ def sign(x, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
x_sign = gen_math_ops.sign(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_sign, dense_shape=x.shape)
+ indices=x.indices, values=x_sign, dense_shape=x.dense_shape)
else:
return gen_math_ops.sign(x, name=name)
@@ -382,7 +382,7 @@ def square(x, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
x_square = gen_math_ops.square(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_square, dense_shape=x.shape)
+ indices=x.indices, values=x_square, dense_shape=x.dense_shape)
else:
return gen_math_ops.square(x, name=name)
@@ -404,7 +404,7 @@ def sqrt(x, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
x_sqrt = gen_math_ops.sqrt(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_sqrt, dense_shape=x.shape)
+ indices=x.indices, values=x_sqrt, dense_shape=x.dense_shape)
else:
return gen_math_ops.sqrt(x, name=name)
@@ -424,7 +424,7 @@ def erf(x, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
x_erf = gen_math_ops.erf(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_erf, dense_shape=x.shape)
+ indices=x.indices, values=x_erf, dense_shape=x.dense_shape)
else:
return gen_math_ops.erf(x, name=name)
@@ -663,7 +663,7 @@ def cast(x, dtype, name=None):
with ops.name_scope(name, "Cast", [x]) as name:
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
- return sparse_tensor.SparseTensor(x.indices, values_cast, x.shape)
+ return sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
else:
# TODO(touts): Handle what Josh said.
#
@@ -1039,7 +1039,7 @@ def _mul_dispatch(x, y, name=None):
assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse.
new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
y.dense_shape, x, name)
- return sparse_tensor.SparseTensor(y.indices, new_vals, y.shape)
+ return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
# NOTE(aselle): When integer division is added for sparse_dense_cwise,
# div, truediv, and floordiv should be delegated appropriately for
@@ -1998,7 +1998,7 @@ def tanh(x, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
x_tanh = gen_math_ops._tanh(x.values, name=name)
return sparse_tensor.SparseTensor(
- indices=x.indices, values=x_tanh, dense_shape=x.shape)
+ indices=x.indices, values=x_tanh, dense_shape=x.dense_shape)
else:
return gen_math_ops._tanh(x, name=name)
diff --git a/tensorflow/python/ops/metrics.py b/tensorflow/python/ops/metrics.py
index 35fbbd7ff4..1816fe7988 100644
--- a/tensorflow/python/ops/metrics.py
+++ b/tensorflow/python/ops/metrics.py
@@ -1449,7 +1449,7 @@ def _select_class_id(ids, selected_id):
filled_selected_id_shape, math_ops.to_int64(selected_id))
result = sets.set_intersection(filled_selected_id, ids)
return sparse_tensor.SparseTensor(
- indices=result.indices, values=result.values, shape=ids_shape)
+ indices=result.indices, values=result.values, dense_shape=ids_shape)
def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
@@ -1989,12 +1989,12 @@ def _expand_and_tile(tensor, multiple, dim=0, name=None):
if isinstance(tensor, sparse_tensor.SparseTensor):
if dim < 0:
expand_dims = array_ops.reshape(
- array_ops.size(tensor.shape) + dim, [1])
+ array_ops.size(tensor.dense_shape) + dim, [1])
else:
expand_dims = [dim]
expanded_shape = array_ops.concat_v2(
- (array_ops.slice(tensor.shape, [0], expand_dims), [1],
- array_ops.slice(tensor.shape, expand_dims, [-1])),
+ (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
+ array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
0,
name='expanded_shape')
expanded = sparse_ops.sparse_reshape(
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 0d52004a55..f77293fd13 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -328,15 +328,15 @@ def parse_example(serialized, features, name=None, example_names=None):
"kw": SparseTensor(
indices=[[0, 0], [0, 1], [1, 0]],
values=["knit", "big", "emmy"]
- shape=[2, 2]),
+ dense_shape=[2, 2]),
"dank": SparseTensor(
indices=[[1, 0]],
values=[42],
- shape=[2, 1]),
+ dense_shape=[2, 1]),
"gps": SparseTensor(
indices=[],
values=[],
- shape=[2, 0]),
+ dense_shape=[2, 0]),
}
```
@@ -405,7 +405,7 @@ def parse_example(serialized, features, name=None, example_names=None):
"sparse": SparseTensor(
indices=[[0, 3], [0, 20], [1, 42]],
values=[0.5, -1.0, 0.0]
- shape=[2, 100]),
+ dense_shape=[2, 100]),
}
```
@@ -667,7 +667,7 @@ def _parse_single_example_raw(serialized,
array_ops.slice(outputs[s].indices,
[0, 1], [-1, -1], name="Slice_Indices_%s" % s_name),
outputs[s].values,
- array_ops.slice(outputs[s].shape,
+ array_ops.slice(outputs[s].dense_shape,
[1], [-1], name="Squeeze_Shape_%s" % s_name))
return outputs
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 93aae621b3..ae698c358f 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -426,7 +426,7 @@ class BatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(tf.stack([zero64, zero64 + 1]), [2, 1]),
values=tf.cast(tf.stack([counter, -counter]), tf.float32),
- shape=[2])
+ dense_shape=[2])
if use_dict:
batched = tf.train.batch(
{"c": counter, "s": sparse_counter, "S": "string"},
@@ -452,7 +452,7 @@ class BatchTest(tf.test.TestCase):
expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...]
self.assertAllEqual(results[1].values, expected)
- self.assertAllEqual(results[1].shape, [batch_size, 2])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the limit.
@@ -507,7 +507,7 @@ class BatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
pre_batched = tf.train.batch(
[counter, sparse_counter, "string"], batch_size=2)
batched = tf.train.batch(pre_batched, enqueue_many=True,
@@ -525,7 +525,7 @@ class BatchTest(tf.test.TestCase):
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(
results[1].values, np.arange(i * batch_size, (i + 1) * batch_size))
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the limit.
@@ -545,7 +545,7 @@ class BatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
batched = tf.train.batch(
[counter, sparse_counter, "string"],
batch_size=batch_size, num_threads=4)
@@ -562,7 +562,7 @@ class BatchTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
all_counts.extend(results[0])
self.assertAllEqual(results[2], [b"string"] * batch_size)
self.assertItemsEqual(all_counts, range(num_batches * batch_size))
@@ -584,7 +584,7 @@ class BatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(tf.stack([zero64, zero64 + 1]), [2, 1]),
values=tf.cast(tf.stack([counter, -counter]), tf.float32),
- shape=[2])
+ dense_shape=[2])
batched = tf.train.batch(
[counter, sparse_counter, "string"], batch_size=batch_size,
allow_smaller_final_batch=True)
@@ -604,7 +604,7 @@ class BatchTest(tf.test.TestCase):
expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...]
self.assertAllEqual(results[1].values, expected)
- self.assertAllEqual(results[1].shape, [batch_size, 2])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra_elements.
@@ -616,7 +616,7 @@ class BatchTest(tf.test.TestCase):
results[1].indices,
np.vstack((np.arange(2 * extra_elements) // 2, # 0, 0, 1, 1, ...
[0, 1] * extra_elements)).T)
- self.assertAllEqual(results[1].shape, [extra_elements, 2])
+ self.assertAllEqual(results[1].dense_shape, [extra_elements, 2])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
# Reached the limit.
@@ -637,7 +637,7 @@ class BatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
batched = tf.train.batch(
[counter, sparse_counter, "string"],
batch_size=batch_size, num_threads=4, allow_smaller_final_batch=True)
@@ -654,7 +654,7 @@ class BatchTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
all_counts.extend(results[0])
self.assertAllEqual(results[2], [b"string"] * batch_size)
@@ -666,7 +666,7 @@ class BatchTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(extra_elements), np.zeros(extra_elements))).T)
- self.assertAllEqual(results[1].shape, [extra_elements, 1])
+ self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
all_counts.extend(results[0])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
self.assertItemsEqual(all_counts,
@@ -703,31 +703,31 @@ class BatchTest(tf.test.TestCase):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.batch([sparse], batch_size=2)
- self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
def testBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.batch([sparse], batch_size=2, enqueue_many=True)
- self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
def testBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.batch([sparse], batch_size=2)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.batch([sparse], batch_size=2, enqueue_many=True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testSingleElementDict(self):
x = tf.train.batch({"c": [12, 12]}, batch_size=8)
@@ -742,7 +742,7 @@ class BatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.zeros([1, 1], dtype=tf.int64),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = tf.train.batch(to_batch, 1)
@@ -782,33 +782,33 @@ class BatchTest(tf.test.TestCase):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_batch([sparse], keep_input=True, batch_size=2)
- self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_batch(
[sparse], keep_input=True, batch_size=2, enqueue_many=True)
- self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_batch([sparse], keep_input=True, batch_size=2)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_batch(
[sparse], keep_input=True, batch_size=2, enqueue_many=True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
class BatchJoinTest(tf.test.TestCase):
@@ -823,7 +823,7 @@ class BatchJoinTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# The second generates (99, "b") 90 times and then stops.
num_b = 90
@@ -832,7 +832,7 @@ class BatchJoinTest(tf.test.TestCase):
sparse_ninety_nine = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(ninety_nine, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
@@ -856,7 +856,8 @@ class BatchJoinTest(tf.test.TestCase):
(None, 2), batched_fetch[1].indices.get_shape().as_list())
self.assertAllEqual(
(None,), batched_fetch[1].values.get_shape().as_list())
- self.assertAllEqual((2,), batched_fetch[1].shape.get_shape().as_list())
+ self.assertAllEqual(
+ (2,), batched_fetch[1].dense_shape.get_shape().as_list())
self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())
tf.global_variables_initializer().run()
@@ -877,7 +878,7 @@ class BatchJoinTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
@@ -995,7 +996,7 @@ class BatchJoinTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# The second generates (99, "b") 90 times and then stops.
num_b = 90 + extra_elements
@@ -1004,7 +1005,7 @@ class BatchJoinTest(tf.test.TestCase):
sparse_ninety_nine = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(ninety_nine, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
@@ -1019,7 +1020,7 @@ class BatchJoinTest(tf.test.TestCase):
self.assertAllEqual((None,), batched[0].get_shape().as_list())
self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
- self.assertAllEqual((2,), batched[1].shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
self.assertAllEqual((None,), batched[2].get_shape().as_list())
tf.global_variables_initializer().run()
@@ -1040,7 +1041,7 @@ class BatchJoinTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
@@ -1060,7 +1061,7 @@ class BatchJoinTest(tf.test.TestCase):
results[1].indices,
np.vstack((np.arange(2 * extra_elements),
np.zeros(2 * extra_elements))).T)
- self.assertAllEqual(results[1].shape, [2 * extra_elements, 1])
+ self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
@@ -1211,7 +1212,7 @@ class BatchJoinTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.zeros([1, 1], dtype=tf.int64),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = tf.train.batch(to_batch, 1)
@@ -1252,34 +1253,34 @@ class BatchJoinTest(tf.test.TestCase):
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_batch_join(
[[sparse]], keep_input=True, batch_size=2)
- self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_batch_join(
[[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
- self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_batch_join(
[[sparse]], keep_input=True, batch_size=2)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_batch_join(
[[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
class ShuffleBatchTest(tf.test.TestCase):
@@ -1294,7 +1295,7 @@ class ShuffleBatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
if use_dict:
batched = tf.train.shuffle_batch(
{"c": counter, "s": sparse_counter, "S": "string"},
@@ -1320,7 +1321,7 @@ class ShuffleBatchTest(tf.test.TestCase):
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Results scrambled, but include all the expected numbers.
deltas = [all_counts[i + 1] - all_counts[i]
@@ -1352,7 +1353,7 @@ class ShuffleBatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
batched = tf.train.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size, capacity=32,
@@ -1372,12 +1373,12 @@ class ShuffleBatchTest(tf.test.TestCase):
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra elements.
results = sess.run(batched)
- self.assertAllEqual(results[1].shape, [extra_elements, 1])
+ self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
all_counts.extend(results[0])
@@ -1403,7 +1404,7 @@ class ShuffleBatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
batched = tf.train.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size, capacity=32,
@@ -1422,7 +1423,7 @@ class ShuffleBatchTest(tf.test.TestCase):
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Results scrambled, but include all the expected numbers.
deltas = [all_counts[i + 1] - all_counts[i]
@@ -1448,7 +1449,7 @@ class ShuffleBatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
batched = tf.train.shuffle_batch(
[counter, sparse_counter, "string"],
batch_size=batch_size, capacity=32,
@@ -1468,13 +1469,13 @@ class ShuffleBatchTest(tf.test.TestCase):
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
self.assertAllEqual(results[0], results[1].values)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the final batch with extra elements.
results = sess.run(batched)
self.assertAllEqual(results[0].shape, [extra_elements])
- self.assertAllEqual(results[1].shape, [extra_elements, 1])
+ self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
self.assertAllEqual(results[2], [b"string"] * extra_elements)
all_counts.extend(results[0])
@@ -1516,7 +1517,7 @@ class ShuffleBatchTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.zeros([1, 1], dtype=tf.int64),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = tf.train.batch(to_batch, 1)
@@ -1556,33 +1557,33 @@ class ShuffleBatchTest(tf.test.TestCase):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_shuffle_batch([sparse], 2, 10, 1, True)
- self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_shuffle_batch(
[sparse], 2, 10, 1, True, enqueue_many=True)
- self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_shuffle_batch([sparse], 2, 10, 1, True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_shuffle_batch(
[sparse], 2, 10, 1, True, enqueue_many=True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
class ShuffleBatchJoinTest(tf.test.TestCase):
@@ -1597,7 +1598,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# The second generates (99, "b") 35 times and then stops.
num_b = 35
@@ -1606,7 +1607,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
sparse_ninety_nine = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(ninety_nine, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
@@ -1632,7 +1633,8 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
(None, 2), batched_fetch[1].indices.get_shape().as_list())
self.assertAllEqual(
(None,), batched_fetch[1].values.get_shape().as_list())
- self.assertAllEqual((2,), batched_fetch[1].shape.get_shape().as_list())
+ self.assertAllEqual(
+ (2,), batched_fetch[1].dense_shape.get_shape().as_list())
self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())
tf.global_variables_initializer().run()
@@ -1653,7 +1655,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
@@ -1696,7 +1698,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# The second generates (99, "b") 37 times and then stops.
num_b = 35 + extra_elements
@@ -1705,7 +1707,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
sparse_ninety_nine = tf.SparseTensor(
indices=tf.reshape(zero64, [1, 1]),
values=tf.stack([tf.cast(ninety_nine, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
@@ -1720,7 +1722,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
self.assertAllEqual((None,), batched[0].get_shape().as_list())
self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
- self.assertAllEqual((2,), batched[1].shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
self.assertAllEqual((None,), batched[2].get_shape().as_list())
tf.global_variables_initializer().run()
@@ -1741,7 +1743,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
self.assertAllEqual(
results[1].indices,
np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
- self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
@@ -1754,7 +1756,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
# Reached end with 2 * extra_elements left
results = sess.run(batched)
self.assertEqual(len(results[0]), 2 * extra_elements)
- self.assertAllEqual(results[1].shape, [2 * extra_elements, 1])
+ self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
self.assertEqual(len(results[2]), 2 * extra_elements)
self.assertAllEqual(results[0], results[1].values)
self.assertAllEqual(
@@ -1823,7 +1825,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
sparse_counter = tf.SparseTensor(
indices=tf.zeros([1, 1], dtype=tf.int64),
values=tf.stack([tf.cast(counter, tf.float32)]),
- shape=[1])
+ dense_shape=[1])
to_batch = [counter, sparse_counter, "string"]
if enqueue_many:
to_batch = tf.train.batch(to_batch, 1)
@@ -1863,33 +1865,33 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
- self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[1])
self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
batched = tf.train.maybe_shuffle_batch_join(
[[sparse]], 2, 10, 1, True, enqueue_many=True)
- self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+ self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
sparse = tf.SparseTensor(
indices=tf.placeholder(tf.int64),
values=tf.placeholder(tf.float32),
- shape=tf.placeholder(tf.int64))
+ dense_shape=tf.placeholder(tf.int64))
self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
batched = tf.train.maybe_shuffle_batch_join(
[[sparse]], 2, 10, 1, True, enqueue_many=True)
- self.assertIs(None, batched.shape.get_shape().num_elements())
+ self.assertIs(None, batched.dense_shape.get_shape().num_elements())
if __name__ == "__main__":