diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 09:31:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 09:31:07 -0700 |
commit | af959c2f25e29ff1a76d8ad2a8100780e1bca8cf (patch) | |
tree | c393bef2c2fb02831dd4c99a20510502c4920f0b /tensorflow/contrib/layers | |
parent | e2ce9787d9927e4a6574e6ac4606a47712320170 (diff) | |
parent | 413ac36f33deb0c354dd687963d2410eab048970 (diff) |
Merge pull request #18567 from imsheridan:fix_expand_dims
PiperOrigin-RevId: 214278672
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/target_column.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 69bb6be814..8a6b4f68a8 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -396,7 +396,7 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn): def _mean_squared_loss(logits, target): # To prevent broadcasting inside "-". if len(target.get_shape()) == 1: - target = array_ops.expand_dims(target, dim=[1]) + target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) return math_ops.square(logits - math_ops.to_float(target)) @@ -405,7 +405,7 @@ def _mean_squared_loss(logits, target): def _log_loss_with_two_classes(logits, target): # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target. if len(target.get_shape()) == 1: - target = array_ops.expand_dims(target, dim=[1]) + target = array_ops.expand_dims(target, axis=1) loss_vec = nn.sigmoid_cross_entropy_with_logits( labels=math_ops.to_float(target), logits=logits) return loss_vec |