aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 18:43:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 18:49:39 -0700
commit96f3428e33e18477661b8d8cf78f2db457c8881b (patch)
tree3c2c68337369a259feca09db0dfc2b4cbdfdf147 /tensorflow/python/feature_column
parenta332fea0be8def4aa5985499ad807ef78d029142 (diff)
Let feature columns correctly handle rank-1 sparse tensors from an empty batch.
reshape can't determine the size of the last dimension when reshaping shape (0) to (0, 1). PiperOrigin-RevId: 214872677
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py2
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py2
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py16
4 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 226e273660..618e70f3a5 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2318,7 +2318,7 @@ class _LazyBuilder(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index abb79efa68..1ae510250c 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -169,6 +169,18 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
builder.get(NotAFeatureColumn())
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ builder = _LazyBuilder(features={'a': sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))})
+ with self.cached_session():
+ spv = builder.get('a').eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 289f6d0d14..538641c251 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -2341,7 +2341,7 @@ class FeatureTransformationCache(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 58168e0f9e..2970431167 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -177,6 +177,22 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
transformation_cache.get(NotAFeatureColumn(), None)
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ transformation_cache = FeatureTransformationCache(
+ features={
+ 'a':
+ sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))
+ })
+ with self.cached_session():
+ spv = transformation_cache.get('a', None).eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):