author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-03-24 08:19:12 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-03-24 09:33:05 -0700
commit    5b5b8412f0684a548e1e9001421e5d095cda0142 (patch)
tree      88d3c7419498f7a6991113bac146320b833b601b
parent    c404448d3b1e44fddc2d6e1c6da9862443112721 (diff)
In tf.contrib, move the framework and loss packages up a level.
Simplify the loss ops so that the public API returns only a scalar (see discussion at go/tf_contrib_ops). Add a softmax loss. Change: 118034691
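
Each public loss now returns a single scalar: per-element losses are summed across all non-batch dimensions, then averaged across the batch. A minimal usage sketch of the reorganized packages (assuming the tf.contrib API as of this commit; the placeholder shapes are invented for illustration):

    import tensorflow as tf

    predicted = tf.placeholder(tf.float32, shape=[None, 10])
    target = tf.placeholder(tf.float32, shape=[None, 10])

    # Formerly tf.contrib.layers.scalar_squared_loss(predicted, target);
    # the renamed op returns a scalar: sum per example, mean over the batch.
    loss = tf.contrib.losses.squared(predicted, target)

    # The framework utilities move up a level as well.
    dtype = tf.contrib.framework.assert_same_float_dtype([predicted, target])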
-rw-r--r--  tensorflow/contrib/BUILD                                      2
-rw-r--r--  tensorflow/contrib/__init__.py                                2
-rw-r--r--  tensorflow/contrib/framework/BUILD                           42
-rw-r--r--  tensorflow/contrib/framework/__init__.py                     30
-rw-r--r--  tensorflow/contrib/framework/python/framework/__init__.py    22
-rw-r--r--  tensorflow/contrib/framework/python/framework/tensor_util.py (renamed from tensorflow/contrib/layers/python/framework/tensor_util.py)    8
-rw-r--r--  tensorflow/contrib/framework/python/framework/tensor_util_test.py (renamed from tensorflow/contrib/layers/python/framework/tensor_util_test.py)   49
-rw-r--r--  tensorflow/contrib/layers/BUILD                              29
-rw-r--r--  tensorflow/contrib/layers/__init__.py                        17
-rw-r--r--  tensorflow/contrib/layers/python/ops/loss_ops_test.py       310
-rw-r--r--  tensorflow/contrib/losses/BUILD                              42
-rw-r--r--  tensorflow/contrib/losses/__init__.py                        25
-rw-r--r--  tensorflow/contrib/losses/python/losses/__init__.py (renamed from tensorflow/contrib/layers/python/ops/__init__.py)    2
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py (renamed from tensorflow/contrib/layers/python/ops/loss_ops.py)  246
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops_test.py    272
15 files changed, 535 insertions(+), 563 deletions(-)
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index ece038a572..366af6b2c6 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -15,9 +15,11 @@ py_library(
deps = [
"//tensorflow/contrib/ctc:ctc_py",
"//tensorflow/contrib/distributions:distributions_py",
+ "//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
+ "//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/skflow",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/util:util_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index eab8b457d4..794ce70299 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -21,8 +21,10 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import ctc
from tensorflow.contrib import distributions
+from tensorflow.contrib import framework
from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
+from tensorflow.contrib import losses
from tensorflow.contrib import testing
from tensorflow.contrib import util
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
new file mode 100644
index 0000000000..4d83b1956e
--- /dev/null
+++ b/tensorflow/contrib/framework/BUILD
@@ -0,0 +1,42 @@
+# Description:
+# Contains parts of TensorFlow that are experimental or unstable, and which are not supported.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "framework_py",
+ srcs = [
+ "__init__.py",
+ "python/framework/__init__.py",
+ "python/framework/tensor_util.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "tensor_util_test",
+ srcs = glob(["python/framework/tensor_util_test.py"]),
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
new file mode 100644
index 0000000000..be0cc3eb93
--- /dev/null
+++ b/tensorflow/contrib/framework/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Framework utilities.
+
+@@assert_same_float_dtype
+@@is_numeric_tensor
+@@assert_scalar_int
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.framework.python.framework import *
+from tensorflow.python.util.all_util import make_all
diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py
new file mode 100644
index 0000000000..4a17474b63
--- /dev/null
+++ b/tensorflow/contrib/framework/python/framework/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A module containing TensorFlow ops whose API may change in the future."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.framework.python.framework.tensor_util import *
diff --git a/tensorflow/contrib/layers/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 1a5450630c..6b85c38f1a 100644
--- a/tensorflow/contrib/layers/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -13,14 +13,18 @@
# limitations under the License.
# ==============================================================================
-"""Tensor utility functions."""
+"""Tensor utility functions.
+
+@@assert_same_float_dtype
+@@is_numeric_tensor
+@@assert_scalar_int
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import Tensor
-
__all__ = ['assert_same_float_dtype', 'is_numeric_tensor', 'assert_scalar_int']
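
The moved utilities keep their semantics; only the import path changes from tf.contrib.layers to tf.contrib.framework. A short sketch of the relocated helpers (assuming the API as of this commit, mirroring the updated tests below):

    import tensorflow as tf

    x = tf.constant([1.0, 2.0], dtype=tf.float32)
    y = tf.constant([3.0, 4.0], dtype=tf.float32)

    # Returns the shared float dtype (tf.float32 here); raises ValueError
    # if any tensor has a conflicting dtype.
    dtype = tf.contrib.framework.assert_same_float_dtype([x, y])

    # Passes for scalar int32/int64 tensors, raises ValueError otherwise.
    tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.int32))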
diff --git a/tensorflow/contrib/layers/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 6785ab4938..644fa9905b 100644
--- a/tensorflow/contrib/layers/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -1,4 +1,4 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
+# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,72 +21,73 @@ from __future__ import print_function
import tensorflow as tf
-import tensorflow.python.framework
-
class FloatDTypeTest(tf.test.TestCase):
def test_assert_same_float_dtype(self):
self.assertIs(
- tf.float32, tf.contrib.layers.assert_same_float_dtype(None, None))
+ tf.float32, tf.contrib.framework.assert_same_float_dtype(None, None))
self.assertIs(
- tf.float32, tf.contrib.layers.assert_same_float_dtype([], None))
+ tf.float32, tf.contrib.framework.assert_same_float_dtype([], None))
self.assertIs(
- tf.float32, tf.contrib.layers.assert_same_float_dtype([], tf.float32))
+ tf.float32,
+ tf.contrib.framework.assert_same_float_dtype([], tf.float32))
self.assertIs(
tf.float32,
- tf.contrib.layers.assert_same_float_dtype(None, tf.float32))
+ tf.contrib.framework.assert_same_float_dtype(None, tf.float32))
self.assertIs(
tf.float32,
- tf.contrib.layers.assert_same_float_dtype([None, None], None))
+ tf.contrib.framework.assert_same_float_dtype([None, None], None))
self.assertIs(
tf.float32,
- tf.contrib.layers.assert_same_float_dtype([None, None], tf.float32))
+ tf.contrib.framework.assert_same_float_dtype([None, None], tf.float32))
const_float = tf.constant(3.0, dtype=tf.float32)
self.assertIs(
tf.float32,
- tf.contrib.layers.assert_same_float_dtype([const_float], tf.float32))
+ tf.contrib.framework.assert_same_float_dtype([const_float], tf.float32))
self.assertRaises(
ValueError,
- tf.contrib.layers.assert_same_float_dtype, [const_float], tf.int32)
+ tf.contrib.framework.assert_same_float_dtype, [const_float], tf.int32)
sparse_float = tf.SparseTensor(
tf.constant([[111], [232]], tf.int64),
tf.constant([23.4, -43.2], tf.float32),
tf.constant([500], tf.int64))
- self.assertIs(tf.float32, tf.contrib.layers.assert_same_float_dtype(
+ self.assertIs(tf.float32, tf.contrib.framework.assert_same_float_dtype(
[sparse_float], tf.float32))
self.assertRaises(
ValueError,
- tf.contrib.layers.assert_same_float_dtype, [sparse_float], tf.int32)
+ tf.contrib.framework.assert_same_float_dtype, [sparse_float], tf.int32)
self.assertRaises(
- ValueError, tf.contrib.layers.assert_same_float_dtype,
+ ValueError, tf.contrib.framework.assert_same_float_dtype,
[const_float, None, sparse_float], tf.float64)
self.assertIs(
tf.float32,
- tf.contrib.layers.assert_same_float_dtype([const_float, sparse_float]))
- self.assertIs(tf.float32, tf.contrib.layers.assert_same_float_dtype(
+ tf.contrib.framework.assert_same_float_dtype(
+ [const_float, sparse_float]))
+ self.assertIs(tf.float32, tf.contrib.framework.assert_same_float_dtype(
[const_float, sparse_float], tf.float32))
const_int = tf.constant(3, dtype=tf.int32)
- self.assertRaises(ValueError, tf.contrib.layers.assert_same_float_dtype,
+ self.assertRaises(ValueError, tf.contrib.framework.assert_same_float_dtype,
[sparse_float, const_int])
- self.assertRaises(ValueError, tf.contrib.layers.assert_same_float_dtype,
+ self.assertRaises(ValueError, tf.contrib.framework.assert_same_float_dtype,
[sparse_float, const_int], tf.int32)
- self.assertRaises(ValueError, tf.contrib.layers.assert_same_float_dtype,
+ self.assertRaises(ValueError, tf.contrib.framework.assert_same_float_dtype,
[sparse_float, const_int], tf.float32)
self.assertRaises(
- ValueError, tf.contrib.layers.assert_same_float_dtype, [const_int])
+ ValueError, tf.contrib.framework.assert_same_float_dtype, [const_int])
def test_assert_scalar_int(self):
- tf.contrib.layers.assert_scalar_int(tf.constant(3, dtype=tf.int32))
- tf.contrib.layers.assert_scalar_int(tf.constant(3, dtype=tf.int64))
+ tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.int32))
+ tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.int64))
with self.assertRaisesRegexp(ValueError, "Unexpected type"):
- tf.contrib.layers.assert_scalar_int(tf.constant(3, dtype=tf.float32))
+ tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.float32))
with self.assertRaisesRegexp(ValueError, "Unexpected shape"):
- tf.contrib.layers.assert_scalar_int(tf.constant([3, 4], dtype=tf.int32))
+ tf.contrib.framework.assert_scalar_int(
+ tf.constant([3, 4], dtype=tf.int32))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index a9ea2c745f..0a3653d7fa 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -11,15 +11,12 @@ py_library(
name = "layers_py",
srcs = [
"__init__.py",
- "python/framework/tensor_util.py",
"python/layers/__init__.py",
"python/layers/initializers.py",
"python/layers/layers.py",
"python/layers/optimizers.py",
"python/layers/regularizers.py",
"python/layers/summaries.py",
- "python/ops/__init__.py",
- "python/ops/loss_ops.py",
],
srcs_version = "PY2AND3",
)
@@ -76,19 +73,6 @@ py_test(
)
py_test(
- name = "loss_ops_test",
- size = "small",
- srcs = glob(["python/ops/loss_ops_test.py"]),
- srcs_version = "PY2AND3",
- deps = [
- ":layers_py",
- "//tensorflow:tensorflow_py",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform_test",
- ],
-)
-
-py_test(
name = "summaries_test",
size = "small",
srcs = glob(["python/layers/summaries_test.py"]),
@@ -101,19 +85,6 @@ py_test(
],
)
-py_test(
- name = "tensor_util_test",
- size = "small",
- srcs = glob(["python/framework/tensor_util_test.py"]),
- srcs_version = "PY2AND3",
- deps = [
- ":layers_py",
- "//tensorflow:tensorflow_py",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform_test",
- ],
-)
-
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index bbb3c2a351..ddf7f43b73 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -57,13 +57,6 @@ The layers module defines convenience functions `summarize_variables`,
of `summarize_collection` to `VARIABLES`, `WEIGHTS` and `BIASES`, respectively.
@@summarize_activations
-
-## Utilities
-
-@@assert_same_float_dtype
-@@assert_scalar_int
-@@is_numeric_tensor
-
"""
from __future__ import absolute_import
@@ -73,15 +66,5 @@ from __future__ import print_function
import sys
# pylint: disable=unused-import,wildcard-import
-from tensorflow.contrib.layers.python.framework.tensor_util import *
from tensorflow.contrib.layers.python.layers import *
-from tensorflow.contrib.layers.python.ops import *
-from tensorflow.contrib.layers.python.ops import loss_ops
from tensorflow.python.util.all_util import make_all
-
-
-# Include loss_ops to get the symbols in - but they are not documented in main
-# docs yet.
-# TODO(cwhipkey): get the loss_ops documented in the main documentation and do
-# this in a better way.
-__all__ = make_all(__name__, [sys.modules[__name__], loss_ops])
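
For context on the removed export hack: make_all builds __all__ by scanning module docstrings for @@symbol references, so once the loss symbols move out of this package there is nothing left to splice in manually. A hedged sketch of the pattern (make_all is tensorflow.python.util.all_util.make_all, as imported above):

    from tensorflow.python.util.all_util import make_all

    # Collects every name referenced as @@name in this module's docstring
    # into __all__, so only documented symbols are exported.
    __all__ = make_all(__name__)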
diff --git a/tensorflow/contrib/layers/python/ops/loss_ops_test.py b/tensorflow/contrib/layers/python/ops/loss_ops_test.py
deleted file mode 100644
index 1453af5331..0000000000
--- a/tensorflow/contrib/layers/python/ops/loss_ops_test.py
+++ /dev/null
@@ -1,310 +0,0 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Tests for contrib.layers.python.ops.loss_ops."""
-# pylint: disable=unused-import,g-bad-import-order
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.contrib.layers.python.framework import tensor_util
-
-pi = 3.14
-indiana_pi = 3.2 # https://en.wikipedia.org/wiki/Indiana_Pi_Bill
-
-
-class ReduceBatchSumTest(tf.test.TestCase):
-
- def testDimensionNone(self):
- with self.test_session():
- input_array = np.array([
- [1.0, 2.0],
- [-1.0, -2.0]
- ], dtype=np.float32)
- placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec")
- expected_result = np.array([3.0, -3.0])
- actual_result = tf.contrib.layers.reduce_batch_sum(placeholder_vec)
- self.assertEqual(actual_result.get_shape().as_list(), [None])
- self.assertAllClose(expected_result, actual_result.eval(feed_dict={
- placeholder_vec: input_array
- }))
-
- def testDimension0(self):
- with self.test_session():
- input_vec = tf.constant(2.0)
- with self.assertRaises(ValueError):
- tf.contrib.layers.reduce_batch_sum(input_vec)
-
- def testDimension1(self):
- with self.test_session():
- input_vec = tf.constant([1.0, 2.0])
- expected_result = np.array([1.0, 2.0])
- actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
- self.assertAllClose(expected_result, actual_result.eval())
-
- def testDimension2(self):
- with self.test_session():
- input_vec = tf.constant([
- [1.0, 2.0],
- [-1.0, -2.0]
- ])
- expected_result = np.array([3.0, -3.0])
- actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
- self.assertAllClose(expected_result, actual_result.eval())
-
- def testReturnShape(self):
- with self.test_session():
- input_vec = tf.constant([
- [1.0, 2.0],
- [-1.0, -2.0]
- ])
- expected_result = np.array([3.0, -3.0])
- actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
- self.assertShapeEqual(expected_result, actual_result)
-
- def testDimensionN(self):
- with self.test_session():
- input_vec = tf.constant([
- [
- [1.0, 2.0],
- [3.0, 4.0]
- ],
- [
- [5.0, 6.0],
- [7.0, 8.0]
- ]
- ])
- expected_result = np.array([10.0, 26.0])
- actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
- self.assertAllClose(expected_result, actual_result.eval())
-
-
-class AbsoluteLossTest(tf.test.TestCase):
-
- def _getTestVectors(self):
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
- name="predicted")
- expected_loss = np.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2)
- return target, predicted, expected_loss
-
- def testAbsoluteLoss(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.absolute_loss(predicted, target)
- self.assertAllClose(expected_loss, result.eval())
-
- def testAbsoluteLossReturnShape(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.absolute_loss(predicted, target)
- self.assertShapeEqual(expected_loss, result)
-
- def testInvalidShapesValueError(self):
- with self.test_session():
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
- name="incompatible_shape")
- with self.assertRaises(ValueError):
- tf.contrib.layers.absolute_loss(incompatible_shape, target)
-
-
-class SquaredLossTest(tf.test.TestCase):
-
- def _getTestVectors(self):
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
- name="predicted")
- expected_loss = np.array([0.005, 0.02, 0.045, 0.08]).reshape(2, 2)
- return target, predicted, expected_loss
-
- def testSquaredLoss(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.squared_loss(predicted, target)
- self.assertAllClose(expected_loss, result.eval())
-
- def testSquaredLossReturnShape(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.squared_loss(predicted, target)
- self.assertShapeEqual(expected_loss, result)
-
- def testInvalidShapesValueError(self):
- with self.test_session():
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
- name="incompatible_shape")
- with self.assertRaises(ValueError):
- tf.contrib.layers.squared_loss(incompatible_shape, target)
-
-
-class SumSquaredLossTest(tf.test.TestCase):
-
- def _getTestVectors(self):
- target = tf.constant([[0.0, 1.0],
- [3.0, 2.0]],
- shape=[2, 2],
- name="target")
- predicted = tf.constant([[3.0, -2.0],
- [1.0, 2.0]],
- shape=[2, 2],
- name="predicted")
- expected_loss = np.array([9.0, 2.0])
- return target, predicted, expected_loss
-
- def testSumSquaredLoss(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.sum_squared_loss(predicted, target)
- self.assertAllClose(expected_loss, result.eval())
-
- def testSumSquaredLossReturnShape(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.sum_squared_loss(predicted, target)
- self.assertShapeEqual(expected_loss, result)
-
- def testInvalidShapesValueError(self):
- with self.test_session():
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
- name="incompatible_shape")
- with self.assertRaises(ValueError):
- tf.contrib.layers.sum_squared_loss(incompatible_shape, target)
-
-
-class ScalarAbsoluteLossTest(tf.test.TestCase):
-
- def testScalarAbsoluteLoss(self):
- with self.test_session():
- actual = tf.constant([pi], name="pi")
- actual_placeholder = tf.placeholder(tf.float32)
- label = tf.constant([indiana_pi], name="lbl")
- label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
- expected_loss = abs(indiana_pi - pi)
-
- # Both shapes are set.
- both_shapes_loss = tf.contrib.layers.scalar_absolute_loss(actual, label)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- both_shapes_loss.eval(), expected_loss, decimal=6)
-
- # No shape for 'actual' - check that the loss layer can be created.
- no_actual_shape_loss = tf.contrib.layers.scalar_absolute_loss(
- actual_placeholder, label)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- no_actual_shape_loss.eval({actual_placeholder: [pi]}),
- expected_loss, decimal=6)
-
- # No shape for 'label' - check that the loss layer can be created.
- no_label_shape_loss = tf.contrib.layers.scalar_absolute_loss(
- actual, label_placeholder)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
- expected_loss, decimal=6)
-
- # No shapes.
- no_shape_loss = tf.contrib.layers.scalar_absolute_loss(
- actual_placeholder, label_placeholder)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- no_shape_loss.eval({label_placeholder: [indiana_pi],
- actual_placeholder: [pi]}),
- expected_loss, decimal=6)
-
- # Evaluate the previous one again, but this time with different
- # (matching) shapes. This should still work.
- np.testing.assert_almost_equal(
- no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
- actual_placeholder: [pi, pi]}),
- expected_loss, decimal=6)
-
-
-class ScalarSquaredLossTest(tf.test.TestCase):
-
- def testScalarSquaredLoss(self):
- with self.test_session():
- actual = tf.constant([pi], name="pi")
- actual_placeholder = tf.placeholder(tf.float32)
- label = tf.constant([indiana_pi], name="lbl")
- label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
- expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2
-
- # Both shapes are set.
- both_shapes_loss = tf.contrib.layers.scalar_squared_loss(actual, label)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- both_shapes_loss.eval(), expected_loss, decimal=6)
-
- # No shape for 'actual' - check that the loss layer can be created.
- no_actual_shape_loss = tf.contrib.layers.scalar_squared_loss(
- actual_placeholder, label)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- no_actual_shape_loss.eval({actual_placeholder: [pi]}),
- expected_loss, decimal=6)
-
- # No shape for 'label' - check that the loss layer can be created.
- no_label_shape_loss = tf.contrib.layers.scalar_squared_loss(
- actual, label_placeholder)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
- expected_loss,
- decimal=6)
-
- # No shapes.
- no_shape_loss = tf.contrib.layers.scalar_squared_loss(
- actual_placeholder, label_placeholder)
- tf.initialize_all_variables().run()
- np.testing.assert_almost_equal(
- no_shape_loss.eval({label_placeholder: [indiana_pi],
- actual_placeholder: [pi]}),
- expected_loss, decimal=6)
-
- # Evaluate the previous one again, but this time with different
- # (matching) shapes. This should still work.
- np.testing.assert_almost_equal(
- no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
- actual_placeholder: [pi, pi]}),
- expected_loss, decimal=6)
-
-
-class ScalarLogisticLossTest(tf.test.TestCase):
-
- def _expected_loss(self, logit, target):
- sigmoid = 1.0 / (1.0 + np.exp(-logit))
- logistic_loss = (target * -np.log(sigmoid)) - (
- (1.0 - target) * np.log(1.0 - sigmoid))
- batch_losses = np.sum(logistic_loss, 1)
-
- return np.sum(batch_losses) / len(batch_losses)
-
- def test_scalar_logistic_loss(self):
- logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]])
- target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]])
- with self.test_session():
- result = tf.contrib.layers.scalar_logistic_loss(
- tf.constant(logit), tf.constant(target))
- self.assertAllClose(self._expected_loss(logit, target), result.eval())
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD
new file mode 100644
index 0000000000..8452132c45
--- /dev/null
+++ b/tensorflow/contrib/losses/BUILD
@@ -0,0 +1,42 @@
+# Description:
+# Contains parts of TensorFlow that are experimental or unstable, and which are not supported.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "losses_py",
+ srcs = [
+ "__init__.py",
+ "python/losses/__init__.py",
+ "python/losses/loss_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "loss_ops_test",
+ srcs = glob(["python/losses/loss_ops_test.py"]),
+ srcs_version = "PY2AND3",
+ deps = [
+ ":losses_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py
new file mode 100644
index 0000000000..ec796a6ca7
--- /dev/null
+++ b/tensorflow/contrib/losses/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for building neural network losses."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.losses.python.losses import *
+from tensorflow.python.util.all_util import make_all
diff --git a/tensorflow/contrib/layers/python/ops/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py
index 09bfbe4dc5..d5e426e389 100644
--- a/tensorflow/contrib/layers/python/ops/__init__.py
+++ b/tensorflow/contrib/losses/python/losses/__init__.py
@@ -19,4 +19,4 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
-from tensorflow.contrib.layers.python.ops.loss_ops import *
+from tensorflow.contrib.losses.python.losses.loss_ops import *
diff --git a/tensorflow/contrib/layers/python/ops/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index c451fc81d4..c704222c4b 100644
--- a/tensorflow/contrib/layers/python/ops/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -21,35 +21,23 @@ or for regularization purposes (e.g., weight decay).
These loss ops are, by design, minimal, enabling flexibility in how
their output can be used.
-@@reduce_batch_sum
-
-@@absolute_loss
-@@squared_loss
-@@logistic_loss
-
-@@sum_absolute_loss
-@@sum_squared_loss
-@@sum_logistic_loss
-
-@@scalar_absolute_loss
-@@scalar_squared_loss
-@@scalar_logistic_loss
+@@absolute
+@@squared
+@@logistic
+@@softmax
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.layers.python.framework import tensor_util
+from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
-__all__ = ["reduce_batch_sum", "absolute_loss", "squared_loss", "logistic_loss",
- "sum_absolute_loss", "sum_squared_loss", "sum_logistic_loss",
- "scalar_absolute_loss", "scalar_squared_loss",
- "scalar_logistic_loss"]
+__all__ = ["absolute", "squared", "logistic", "softmax"]
def _reduce_batch(x, reduce_fn, name=None):
@@ -95,15 +83,14 @@ def _reduce_batch(x, reduce_fn, name=None):
return result
-def reduce_batch_sum(x, name=None):
+def _reduce_batch_sum(x, name=None):
"""Given a tensor `x`, sums across all dimensions except dimension 0.
- Given a tensor with the number of dimensions > 1, reduce_batch_sum
- will sum across all dimensions except for dimension 0. This function
- is useful for summing the loss (error) across all examples in a
- batch when training. As an example, given a tensor of shape
- [batch_size, d1, d2], this function will sum across dimensions d1
- and d2, returning a tensor of shape [batch_size].
+ Given a tensor with the number of dimensions > 1, this will sum across all
+ dimensions except for dimension 0. This function is useful for summing the
+ loss (error) across all examples in a batch when training. As an example,
+ given a tensor of shape [batch_size, d1, d2], this function will sum across
+ dimensions d1 and d2, returning a tensor of shape [batch_size].
Tensors of dimension 1 are returned as-is, while tensors of dimension 0
raise a ValueError.
@@ -122,6 +109,23 @@ def reduce_batch_sum(x, name=None):
return _reduce_batch(x, math_ops.reduce_sum, name)
+def _reduce_to_scalar(x, name=None):
+ """Reduces losses to a scalar.
+
+ Given a tensor `x`, sums across all dimensions except dimension 0, then
+ averages across dimension 0.
+
+ Args:
+ x: A `Tensor` with dimension > 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ Calculate sum of losses per example, then average across batch.
+ """
+ with ops.op_scope([x], name, "scalar") as scope:
+ return math_ops.reduce_mean(_reduce_batch_sum(x), name=scope)
+
+
def _validate_predicted_and_target(predicted, target):
# TODO(ptucker): Optionally add assert op for shape check, for cases when
# shape is not fully defined at graph construction time?
@@ -129,7 +133,7 @@ def _validate_predicted_and_target(predicted, target):
tensor_util.assert_same_float_dtype([predicted, target])
-def absolute_loss(predicted, target, name=None):
+def _raw_absolute(predicted, target, name=None):
"""Computes and returns the per-example absolute loss.
Computes the per-example absolute value of the difference between
@@ -158,7 +162,7 @@ def absolute_loss(predicted, target, name=None):
return math_ops.abs(target - predicted, name=scope)
-def squared_loss(predicted, target, name=None):
+def _raw_squared(predicted, target, name=None):
"""Computes and returns the per-example squared loss, divided by 2.
Computes the per-example squared difference between the target and
@@ -186,59 +190,8 @@ def squared_loss(predicted, target, name=None):
return math_ops.div(math_ops.square(target - predicted), 2.0, name=scope)
-def logistic_loss(logit, target, name=None):
- """Calculates the logistic cross-entropy loss.
-
- **WARNING:** `logit` must be unscaled, while the `target` should be a
- normalized probability prediction. See
- `tf.nn.sigmoid_cross_entropy_with_logits` for more details.
-
- Args:
- logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
- of predicted logit values.
- target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
- target values. The shape of the target tensor should match the
- `predicted` tensor.
- name: A name for the operation (optional).
-
- Returns:
- A `Tensor` of the logistic cross-entropy loss.
- """
- return nn.sigmoid_cross_entropy_with_logits(logit, target, name=name)
-
-
-def _sum_loss(predicted, target, loss_fn, name="sum_loss"):
- """Apply loss function, then sum across all non-batch dimensions.
-
- Args:
- predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
- of predicted values.
- target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
- target values. The shape of the target tensor should match the
- `predicted` tensor.
- loss_fn: Loss to apply, takes 2 tensors as parameters and returns a tensor.
- name: A name for the operation (optional).
-
- Returns:
- A `[batch_size]` tensor of losses, averaged across all dimensions except
- dimension 0.
- """
- return reduce_batch_sum(loss_fn(predicted, target), name=name)
-
-
-def sum_absolute_loss(predicted, target, name="sum_absolute_loss"):
- """Calculates the sum of absolute losses across batches.
-
- Computes the absolute difference between the target and predicted
- tensors, averaged across all dimensions except dimension 0:
-
- losses = reduce_batch_sum(absolute_loss(predicted, target))
-
- where `losses` is a tensor with dimensions [batch_size].
-
- The tensors must have the same shape.
-
- This loss function is a form of L1 loss.
+def absolute(predicted, target, name=None):
+ """Reduces absolute losses to a scalar.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
@@ -249,32 +202,14 @@ def sum_absolute_loss(predicted, target, name="sum_absolute_loss"):
name: A name for the operation (optional).
Returns:
- A `[batch_size]` tensor of absolute differences, averaged across all
- dimensions except dimension 0.
-
- Raises:
- ValueError: If `predicted` and `target` shapes do not match.
-
+ Calculate sum of absolute losses per example, then average across batch.
"""
- return _sum_loss(predicted, target, absolute_loss, name=name)
-
-
-def sum_squared_loss(predicted, target, name="sum_squared_loss"):
- """Calculates the sum of the squared loss across batches.
-
- Computes the squared difference between the target and predicted
- tensors, sums across all dimensions except dimension 0.
-
- losses = reduce_batch_sum(squared_loss(predicted, target))
-
- where `losses` is a tensor with dimensions [batch_size].
+ with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
+ return _reduce_to_scalar(_raw_absolute(predicted, target), name=scope)
- The tensors must have the same shape.
- This function is equivalent to typical formulations of L2 loss, and
- similar to TensorFlow's l2_loss function. It differs from the
- l2_loss function by allowing the caller to specify both the
- predicted and target tensors.
+def squared(predicted, target, name=None):
+ """Reduces squared losses to a scalar.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
@@ -285,21 +220,14 @@ def sum_squared_loss(predicted, target, name="sum_squared_loss"):
name: A name for the operation (optional).
Returns:
- A `[batch_size]` tensor of squared losses summed across all dimensions
- except dimension 0.
-
- Raises:
- ValueError: If `predicted` and `target` shapes do not match.
-
+ Calculate sum of squared losses per example, then average across batch.
"""
- return _sum_loss(predicted, target, squared_loss, name=name)
-
+ with ops.op_scope([predicted, target], name, "squared_loss") as scope:
+ return _reduce_to_scalar(_raw_squared(predicted, target), name=scope)
-def sum_logistic_loss(logit, target, name="sum_logistic_loss"):
- """Calculates the sum of the logistic loss across batches.
- Computes the logistic between logit and predicted tensors, summed across all
- dimensions except dimension 0.
+def logistic(logit, target, name=None):
+ """Calculates the logistic cross-entropy loss, averaged across batches.
**WARNING:** `logit` must be unscaled, while the `target` should be a
normalized probability prediction. See
@@ -310,91 +238,49 @@ def sum_logistic_loss(logit, target, name="sum_logistic_loss"):
of predicted logit values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
- `predicted` tensor.
- name: A name for the operation (optional).
-
- Returns:
- A `[batch_size]` tensor of logistic losses summed across all dimensions
- except dimension 0.
- """
- return _sum_loss(logit, target, logistic_loss, name=name)
-
-
-def _scalar_loss(predicted, target, loss_fn, name=None):
- """Reduces losses to a scalar.
-
- Args:
- predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
- of predicted values.
- target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
- target values. The shape of the target tensor should match the
- `predicted` tensor.
- loss_fn: Loss to apply, takes 2 tensors as parameters and returns a tensor.
+ `logit` tensor.
name: A name for the operation (optional).
Returns:
- Caculate sum of losses per example, then average across batch.
- """
- with ops.op_scope([predicted, target], name, "scalar_loss") as scope:
- return math_ops.reduce_mean(
- _sum_loss(predicted, target, loss_fn), name=scope)
-
-
-def scalar_absolute_loss(predicted, target, name="scalar_absolute_loss"):
- """Reduces absolute losses to a scalar.
-
- Args:
- predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
- of predicted values.
- target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
- target values. The shape of the target tensor should match the
- `predicted` tensor.
- name: A name for the operation (optional).
-
- Returns:
- Caculate sum of absolute losses per example, then average across batch.
- """
- return _scalar_loss(predicted, target, loss_fn=absolute_loss, name=name)
-
-
-def scalar_squared_loss(predicted, target, name="scalar_squared_loss"):
- """Reduces squared losses to a scalar.
-
- Args:
- predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
- of predicted values.
- target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
- target values. The shape of the target tensor should match the
- `predicted` tensor.
- name: A name for the operation (optional).
+ A scalar `tensor` of the logistic cross-entropy loss, averaged across
+ batches.
- Returns:
- Caculate sum of squared losses per example, then average across batch.
+ Raises:
+ ValueError: If `logit` and `target` shapes do not match.
"""
- return _scalar_loss(predicted, target, loss_fn=squared_loss, name=name)
+ with ops.op_scope([logit, target], name, "logistic_loss") as scope:
+ return _reduce_to_scalar(
+ nn.sigmoid_cross_entropy_with_logits(logit, target), name=scope)
-def scalar_logistic_loss(logit, target, name="scalar_logistic_loss"):
- """Calculates the logistic cross-entropy loss, averaged across batches.
+def softmax(logit, target, name=None):
+ """Calculates the softmax cross-entropy loss, averaged across batches.
**WARNING:** `logit` must be unscaled, while the `target` should be a
normalized probability prediction. See
`tf.nn.sigmoid_cross_entropy_with_logits` for more details.
Args:
- logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
- of predicted logit values.
+ logit: Tensor of actual values. Shape must have rank 2, generally
+ (batch, num_classes). num_classes must be > 1. For single-class
+ regression, use `logistic`. Type must be `tf.float32` or `tf.float64`.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
- `predicted` tensor.
+ `logit` tensor.
name: A name for the operation (optional).
Returns:
- A scalar `tensor` of the logistic cross-entropy loss, averaged across
+ A scalar `tensor` of the softmax cross-entropy loss, averaged across
batches.
Raises:
ValueError: If `logit` and `target` shapes do not match.
"""
- return _scalar_loss(logit, target, loss_fn=logistic_loss, name=name)
-
+ with ops.op_scope([logit, target], name, "softmax_loss") as scope:
+ shape = logit.get_shape().with_rank(2)
+ if shape.dims[1] and shape.dims[1] < 2:
+ raise ValueError(
+ "Invalid shape %s; use logistic() instead for only 1 class." %
+ shape)
+ return _reduce_to_scalar(
+ nn.softmax_cross_entropy_with_logits(logit, target), name=scope)
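
The private _reduce_to_scalar helper defines the arithmetic every public loss now shares: sum each example's per-element losses across the non-batch dimensions, then take the mean over the batch. A NumPy sketch of that reduction (illustrative values only):

    import numpy as np

    per_element = np.array([[0.1, 0.2],   # example 0
                            [0.3, 0.4]])  # example 1
    per_example = per_element.reshape(per_element.shape[0], -1).sum(axis=1)
    scalar = per_example.mean()  # (0.3 + 0.7) / 2 == 0.5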
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
new file mode 100644
index 0000000000..71e464362f
--- /dev/null
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -0,0 +1,272 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for contrib.losses.python.losses.loss_ops."""
+# pylint: disable=unused-import,g-bad-import-order
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.framework.python.framework import tensor_util
+
+pi = 3.14
+indiana_pi = 3.2 # https://en.wikipedia.org/wiki/Indiana_Pi_Bill
+
+
+class AbsoluteLossTest(tf.test.TestCase):
+
+ def testAbsoluteLoss(self):
+ with self.test_session():
+ actual = tf.constant([pi], name="pi")
+ actual_placeholder = tf.placeholder(tf.float32)
+ label = tf.constant([indiana_pi], name="lbl")
+ label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
+ expected_loss = abs(indiana_pi - pi)
+
+ # Both shapes are set.
+ both_shapes_loss = tf.contrib.losses.absolute(actual, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ both_shapes_loss.eval(), expected_loss, decimal=6)
+
+ # No shape for 'actual' - check that the loss layer can be created.
+ no_actual_shape_loss = tf.contrib.losses.absolute(
+ actual_placeholder, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_actual_shape_loss.eval({actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # No shape for 'label' - check that the loss layer can be created.
+ no_label_shape_loss = tf.contrib.losses.absolute(
+ actual, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
+ expected_loss, decimal=6)
+
+ # No shapes.
+ no_shape_loss = tf.contrib.losses.absolute(
+ actual_placeholder, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi],
+ actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # Evaluate the previous one again, but this time with different
+ # (matching) shapes. This should still work.
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
+ actual_placeholder: [pi, pi]}),
+ expected_loss, decimal=6)
+
+
+class SquaredLossTest(tf.test.TestCase):
+
+ def testSquaredLoss(self):
+ with self.test_session():
+ actual = tf.constant([pi], name="pi")
+ actual_placeholder = tf.placeholder(tf.float32)
+ label = tf.constant([indiana_pi], name="lbl")
+ label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
+ expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2
+
+ # Both shapes are set.
+ both_shapes_loss = tf.contrib.losses.squared(actual, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ both_shapes_loss.eval(), expected_loss, decimal=6)
+
+ # No shape for 'actual' - check that the loss layer can be created.
+ no_actual_shape_loss = tf.contrib.losses.squared(
+ actual_placeholder, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_actual_shape_loss.eval({actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # No shape for 'label' - check that the loss layer can be created.
+ no_label_shape_loss = tf.contrib.losses.squared(
+ actual, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
+ expected_loss,
+ decimal=6)
+
+ # No shapes.
+ no_shape_loss = tf.contrib.losses.squared(
+ actual_placeholder, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi],
+ actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # Evaluate the previous one again, but this time with different
+ # (matching) shapes. This should still work.
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
+ actual_placeholder: [pi, pi]}),
+ expected_loss, decimal=6)
+
+
+class LogisticTest(tf.test.TestCase):
+
+ def _expected_loss(self, logit, target):
+ sigmoid = 1.0 / (1.0 + np.exp(-logit))
+ logistic_loss = (target * -np.log(sigmoid)) - (
+ (1.0 - target) * np.log(1.0 - sigmoid))
+ batch_losses = np.sum(logistic_loss, 1)
+
+ return np.sum(batch_losses) / len(batch_losses)
+
+ def testSimple(self):
+ logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]])
+ target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]])
+ with self.test_session():
+ loss = tf.contrib.losses.logistic(tf.constant(logit), tf.constant(target))
+ self.assertAllClose(self._expected_loss(logit, target), loss.eval())
+
+ def testComplex(self):
+ with self.test_session():
+ # [batch] and [batch,1] work the same.
+ loss3x0 = tf.contrib.losses.logistic(
+ tf.constant([-1.0, 3.0, -3.0]),
+ tf.constant([0.3, 0.1, 0.4]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(1.536812, loss3x0.eval())
+
+ expected_loss = 1.536812
+ actual3x1 = [[-1.0], [3.0], [-3.0]]
+ label3x1 = [[0.3], [0.1], [0.4]]
+ loss3x1 = tf.contrib.losses.logistic(
+ tf.constant(actual3x1), tf.constant(label3x1))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, loss3x1.eval())
+
+ # Batch average stays the same with repeats of the same examples.
+ loss9x1 = tf.contrib.losses.logistic(
+ tf.constant(actual3x1 * 3), tf.constant(label3x1 * 3))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, loss9x1.eval())
+
+ # Loss stays the same when adding another class with 0 loss.
+ loss3x2 = tf.contrib.losses.logistic(
+ tf.constant([[-1.0, 100.0], [3.0, -100.0], [-3.0, -100.0]]),
+ tf.constant([[0.3, 1.0], [0.1, 0.0], [0.4, 0.0]]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, loss3x2.eval())
+
+ # Loss stays the same with additional x1 dimension.
+ loss3x1x2 = tf.contrib.losses.logistic(
+ tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
+ tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[0.4, 0.0]]]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, loss3x1x2.eval())
+
+ # One label value is out of range (the -0.4); expect no crash
+ # since we did not set validate=True.
+ loss = tf.contrib.losses.logistic(
+ tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
+ tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[-0.4, 0.0]]]))
+ tf.initialize_all_variables().run()
+ loss.eval()
+
+ def testLogisticVsSoftmax(self):
+ with self.test_session():
+ # Each logit = L and target = T used for logistic_loss corresponds to
+ # logits [a, b] where a - b = L and targets [T, 1 - T] for
+ # softmax_loss.
+
+ expected_loss = (0.69314718 + 1.01326168 + 2.10692811) / 3.0
+
+ logistic_loss = tf.contrib.losses.logistic(
+ tf.constant([0.0, 1.0, 2.0]),
+ tf.constant([0.5, 0.3, 0.01]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, logistic_loss.eval())
+
+ softmax_loss = tf.contrib.losses.softmax(
+ tf.constant([[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]),
+ tf.constant([[0.5, 0.5], [0.3, 0.7], [0.01, 0.99]]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, softmax_loss.eval())
+
+
+class SoftmaxTest(tf.test.TestCase):
+
+ def testAllCorrect(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0]])
+ loss = tf.contrib.losses.softmax(logits, labels)
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testAllWrong(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0]])
+ loss = tf.contrib.losses.softmax(logits, labels)
+ self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+ def testSoftmax(self):
+ with self.test_session():
+ # [batch] and [batch,1] fail, softmax_loss is only for multiclass.
+ self.assertRaisesRegexp(
+ ValueError, "must have rank 2", tf.contrib.losses.softmax,
+ tf.constant([-100.0, 10.0, 0.0]),
+ tf.constant([1.0, 1.0, 1.0]))
+
+ self.assertRaisesRegexp(
+ ValueError, "only 1 class", tf.contrib.losses.softmax,
+ tf.constant([[-100.0], [10.0], [0.0]]),
+ tf.constant([[1.0], [1.0], [1.0]]))
+
+ expected_loss = 3.173363
+ loss3x2 = tf.contrib.losses.softmax(
+ tf.constant([[-1.0, 1.0], [0.0, 0.0], [10.0, -1.0]]),
+ tf.constant([[0.5, 0.5], [0.3, 0.7], [0.3, 0.7]]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, loss3x2.eval())
+
+ # Loss stays the same when adding another negative class.
+ loss3x3 = tf.contrib.losses.softmax(
+ tf.constant(
+ [[-1.0, 1.0, -100.0], [0.0, 0.0, -100.0], [10.0, -1.0, -100.0]]),
+ tf.constant([[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.3, 0.7, 0.0]]))
+ tf.initialize_all_variables().run()
+ self.assertAllClose(expected_loss, loss3x3.eval())
+
+ # Fails for rank > 2.
+ self.assertRaisesRegexp(
+ ValueError, "must have rank 2", tf.contrib.losses.softmax,
+ tf.constant([[[-1.0, 1.0]], [[0.0, 0.0]], [[10.0, -1.0]]]),
+ tf.constant([[[0.5, 0.5]], [[0.3, 0.7]], [[0.3, 0.7]]]))
+
+
+if __name__ == "__main__":
+ tf.test.main()