diff options
author | 2017-02-01 07:03:57 -0800 | |
---|---|---|
committer | 2017-02-01 07:27:43 -0800 | |
commit | ff8e4dcee872e017cc8f2f23ed75b070e08382ff (patch) | |
tree | 9d26f1bb2d66f20e0051bea924def4fd38a3c720 | |
parent | 0392fe4438527c10961aed96193e2c6cc351eb60 (diff) |
Fix bug in computing shape for SDCA fake bias column
Change: 146238867
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/linear.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index a601ba42af..3668d087c2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -35,6 +35,8 @@ from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.contrib.linear_optimizer.python import sdca_optimizer from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gradients @@ -76,8 +78,12 @@ def _add_bias_column(feature_columns, columns_to_tensors, bias_variable, if not feature_columns: raise ValueError("feature_columns can't be empty.") - # Using a arbitrary input tensor to figure out batch_size. - batch_size = array_ops.shape(next(iter(columns_to_tensors.values())))[0] + # Using an arbitrary input tensor to figure out batch_size. + some_input = next(iter(columns_to_tensors.values())) + if isinstance(some_input, sparse_tensor.SparseTensor): + batch_size = tensor_util.constant_value(some_input.dense_shape)[0] + else: + batch_size = array_ops.shape(some_input)[0] bias_column = layers.real_valued_column(bias_column_name) columns_to_tensors[bias_column] = array_ops.ones([batch_size, 1], |