aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-02-01 07:03:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-01 07:27:43 -0800
commitff8e4dcee872e017cc8f2f23ed75b070e08382ff (patch)
tree9d26f1bb2d66f20e0051bea924def4fd38a3c720
parent0392fe4438527c10961aed96193e2c6cc351eb60 (diff)
Fix bug in computing shape for SDCA fake bias column
Change: 146238867
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index a601ba42af..3668d087c2 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -35,6 +35,8 @@ from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients
@@ -76,8 +78,12 @@ def _add_bias_column(feature_columns, columns_to_tensors, bias_variable,
if not feature_columns:
raise ValueError("feature_columns can't be empty.")
- # Using a arbitrary input tensor to figure out batch_size.
- batch_size = array_ops.shape(next(iter(columns_to_tensors.values())))[0]
+ # Using an arbitrary input tensor to figure out batch_size.
+ some_input = next(iter(columns_to_tensors.values()))
+ if isinstance(some_input, sparse_tensor.SparseTensor):
+ batch_size = tensor_util.constant_value(some_input.dense_shape)[0]
+ else:
+ batch_size = array_ops.shape(some_input)[0]
bias_column = layers.real_valued_column(bias_column_name)
columns_to_tensors[bias_column] = array_ops.ones([batch_size, 1],