aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 09:31:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 09:31:07 -0700
commitaf959c2f25e29ff1a76d8ad2a8100780e1bca8cf (patch)
treec393bef2c2fb02831dd4c99a20510502c4920f0b /tensorflow/contrib/layers
parente2ce9787d9927e4a6574e6ac4606a47712320170 (diff)
parent413ac36f33deb0c354dd687963d2410eab048970 (diff)
Merge pull request #18567 from imsheridan:fix_expand_dims
PiperOrigin-RevId: 214278672
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--tensorflow/contrib/layers/python/layers/target_column.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 69bb6be814..8a6b4f68a8 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -396,7 +396,7 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
def _mean_squared_loss(logits, target):
# To prevent broadcasting inside "-".
if len(target.get_shape()) == 1:
- target = array_ops.expand_dims(target, dim=[1])
+ target = array_ops.expand_dims(target, axis=1)
logits.get_shape().assert_is_compatible_with(target.get_shape())
return math_ops.square(logits - math_ops.to_float(target))
@@ -405,7 +405,7 @@ def _mean_squared_loss(logits, target):
def _log_loss_with_two_classes(logits, target):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
if len(target.get_shape()) == 1:
- target = array_ops.expand_dims(target, dim=[1])
+ target = array_ops.expand_dims(target, axis=1)
loss_vec = nn.sigmoid_cross_entropy_with_logits(
labels=math_ops.to_float(target), logits=logits)
return loss_vec