diff options
3 files changed, 43 insertions, 9 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 917c07398f..4329e22f48 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1649,6 +1649,15 @@ class _RealValuedVarLenColumn(_FeatureColumn, collections.namedtuple( input_tensor = self._normalized_input_tensor(columns_to_tensors[self.name]) columns_to_tensors[self] = math_ops.to_float(input_tensor) + # pylint: disable=unused-argument + def _to_dnn_input_layer(self, + input_tensor, + weight_collections=None, + trainable=True, + output_rank=2): + return _reshape_real_valued_tensor( + self._to_dense_tensor(input_tensor), output_rank, self.name) + def _to_dense_tensor(self, input_tensor): if not self.is_sparse: return input_tensor diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index d010ae6df1..fa0047f05d 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -856,7 +856,8 @@ def _add_variable_collection(weight_collections): # pylint: disable=protected-access _SUPPORTED_SEQUENCE_COLUMNS = (fc._OneHotColumn, fc._EmbeddingColumn, - fc._RealValuedColumn) + fc._RealValuedColumn, + fc._RealValuedVarLenColumn) _FORBIDDEN_SEQUENCE_COLUMNS = (fc._ScatteredEmbeddingColumn, fc._BucketizedColumn, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 5e69d41621..a9698163dd 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -633,16 +633,27 @@ class CreateInputLayersForDNNsTest(test.TestCase): fc_core.make_input_layer(features, [real_valued]).eval()) - def testRealValuedColumnSparse(self): - sparse_real_valued = feature_column._real_valued_var_len_column( + def testRealValuedColumnDense(self): + var_len_real_valued = feature_column._real_valued_var_len_column( "rating", default_value=-1) - rating = [[2.0], [-1.0], [5.0]] + rating = np.array([[0., 1., 2., -1.], + [3., 4., 5., 6.]]) features = {"rating": constant_op.constant(rating)} - with self.assertRaisesRegexp( - ValueError, - "Error creating input layer for column: rating.*"): - feature_column_ops.input_from_feature_columns(features, - [sparse_real_valued]) + with self.test_session() as sess: + output = sess.run(feature_column_ops.input_from_feature_columns( + features, [var_len_real_valued])) + self.assertAllClose(rating, output) + + def testRealValuedColumnTypeConversion(self): + var_len_real_valued = feature_column._real_valued_var_len_column( + "rating", default_value=-1) + rating = np.array([[0, 1, 2, -1], + [3, 4, 5, 6]]) + features = {"rating": constant_op.constant(rating, dtype=dtypes.int64)} + with self.test_session() as sess: + output = sess.run(feature_column_ops.input_from_feature_columns( + features, [var_len_real_valued])) + self.assertAllClose(rating.astype(np.float32), output) def testRealValuedColumnWithNormalizer(self): real_valued = feature_column.real_valued_column( @@ -1267,6 +1278,19 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): model_inputs = sess.run(model_input_tensor) self.assertAllClose(measurement_input, model_inputs) + def testRealValuedVarLenColumn(self): + var_len_real_valued = feature_column._real_valued_var_len_column( + "rating", default_value=-1) + rating = np.array([[0., 1., 2., -1.], + [3., 4., 5., 6.]]) + features = {"rating": constant_op.constant(rating)} + with self.test_session() as sess: + output = sess.run( + feature_column_ops.sequence_input_from_feature_columns( + features, [var_len_real_valued])) + reshaped_rating = np.reshape(rating, [2, 4, 1]) + self.assertAllClose(reshaped_rating, output) + def testRealValuedColumnWithExtraDimensions(self): batch_size = 4 sequence_length = 8 |