From 5b5b8412f0684a548e1e9001421e5d095cda0142 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 24 Mar 2016 08:19:12 -0800 Subject: In tf.contrib, move framework and loss packages up a level. Simplify loss ops, only return scalar in public API (see discussion at go/tf_contrib_ops). Add softmax loss. Change: 118034691 --- tensorflow/contrib/BUILD | 2 + tensorflow/contrib/__init__.py | 2 + tensorflow/contrib/framework/BUILD | 42 +++ tensorflow/contrib/framework/__init__.py | 30 ++ .../contrib/framework/python/framework/__init__.py | 22 ++ .../framework/python/framework/tensor_util.py | 115 ++++++ .../framework/python/framework/tensor_util_test.py | 94 +++++ tensorflow/contrib/layers/BUILD | 29 -- tensorflow/contrib/layers/__init__.py | 17 - .../contrib/layers/python/framework/tensor_util.py | 111 ------ .../layers/python/framework/tensor_util_test.py | 93 ----- tensorflow/contrib/layers/python/ops/__init__.py | 22 -- tensorflow/contrib/layers/python/ops/loss_ops.py | 400 --------------------- .../contrib/layers/python/ops/loss_ops_test.py | 310 ---------------- tensorflow/contrib/losses/BUILD | 42 +++ tensorflow/contrib/losses/__init__.py | 25 ++ .../contrib/losses/python/losses/__init__.py | 22 ++ .../contrib/losses/python/losses/loss_ops.py | 286 +++++++++++++++ .../contrib/losses/python/losses/loss_ops_test.py | 272 ++++++++++++++ 19 files changed, 954 insertions(+), 982 deletions(-) create mode 100644 tensorflow/contrib/framework/BUILD create mode 100644 tensorflow/contrib/framework/__init__.py create mode 100644 tensorflow/contrib/framework/python/framework/__init__.py create mode 100644 tensorflow/contrib/framework/python/framework/tensor_util.py create mode 100644 tensorflow/contrib/framework/python/framework/tensor_util_test.py delete mode 100644 tensorflow/contrib/layers/python/framework/tensor_util.py delete mode 100644 tensorflow/contrib/layers/python/framework/tensor_util_test.py delete mode 100644 tensorflow/contrib/layers/python/ops/__init__.py delete mode 100644 tensorflow/contrib/layers/python/ops/loss_ops.py delete mode 100644 tensorflow/contrib/layers/python/ops/loss_ops_test.py create mode 100644 tensorflow/contrib/losses/BUILD create mode 100644 tensorflow/contrib/losses/__init__.py create mode 100644 tensorflow/contrib/losses/python/losses/__init__.py create mode 100644 tensorflow/contrib/losses/python/losses/loss_ops.py create mode 100644 tensorflow/contrib/losses/python/losses/loss_ops_test.py diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index ece038a572..366af6b2c6 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -15,9 +15,11 @@ py_library( deps = [ "//tensorflow/contrib/ctc:ctc_py", "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/lookup:lookup_py", + "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/skflow", "//tensorflow/contrib/testing:testing_py", "//tensorflow/contrib/util:util_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index eab8b457d4..794ce70299 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,8 +21,10 @@ from __future__ import print_function # Add projects here, they will show up under tf.contrib. 
from tensorflow.contrib import ctc from tensorflow.contrib import distributions +from tensorflow.contrib import framework from tensorflow.contrib import layers from tensorflow.contrib import linear_optimizer from tensorflow.contrib import lookup +from tensorflow.contrib import losses from tensorflow.contrib import testing from tensorflow.contrib import util diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD new file mode 100644 index 0000000000..4d83b1956e --- /dev/null +++ b/tensorflow/contrib/framework/BUILD @@ -0,0 +1,42 @@ +# Description: +# contains parts of TensorFlow that are experimental or unstable and which are not supported. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +py_library( + name = "framework_py", + srcs = [ + "__init__.py", + "python/framework/__init__.py", + "python/framework/tensor_util.py", + ], + srcs_version = "PY2AND3", +) + +py_test( + name = "tensor_util_test", + srcs = glob(["python/framework/tensor_util_test.py"]), + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py new file mode 100644 index 0000000000..be0cc3eb93 --- /dev/null +++ b/tensorflow/contrib/framework/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Framework utilities. + +@@assert_same_float_dtype +@@is_numeric_tensor +@@assert_scalar_int +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.framework.python.framework import * +from tensorflow.python.util.all_util import make_all diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py new file mode 100644 index 0000000000..4a17474b63 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A module containing TensorFlow ops whose API may change in the future.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.framework.python.framework.tensor_util import * diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py new file mode 100644 index 0000000000..6b85c38f1a --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -0,0 +1,115 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tensor utility functions. + +@@assert_same_float_dtype +@@is_numeric_tensor +@@assert_scalar_int +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework.ops import Tensor + +__all__ = ['assert_same_float_dtype', 'is_numeric_tensor', 'assert_scalar_int'] + + +NUMERIC_TYPES = frozenset([dtypes.float32, dtypes.float64, dtypes.int8, + dtypes.int16, dtypes.int32, dtypes.int64, + dtypes.uint8, dtypes.qint8, dtypes.qint32, + dtypes.quint8, dtypes.complex64]) + + +def is_numeric_tensor(tensor): + return isinstance(tensor, Tensor) and tensor.dtype in NUMERIC_TYPES + + +def _assert_same_base_type(items, expected_type=None): + """Asserts all items are of the same base type. + + Args: + items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, + `Operation`, or `IndexedSlices`). Can include `None` elements, which + will be ignored. + expected_type: Expected type. If not specified, assert all items are + of the same base type. + Returns: + Validated type, or none if neither expected_type nor items provided. + + Raises: + ValueError: If any types do not match. + """ + original_item_str = None + for item in items: + if item is not None: + item_type = item.dtype.base_dtype + if not expected_type: + expected_type = item_type + original_item_str = item.name if hasattr(item, 'name') else str(item) + elif expected_type != item_type: + raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( + item.name if hasattr(item, 'name') else str(item), + item_type, expected_type, + (' as %s' % original_item_str) if original_item_str else '')) + return expected_type + + +def assert_same_float_dtype(tensors=None, dtype=None): + """Validate and return float type based on `tensors` and `dtype`. + + For ops such as matrix multiplication, inputs and weights must be of the + same float type. This function validates that all `tensors` are the same type, + validates that type is `dtype` (if supplied), and returns the type. 
Type must + be `dtypes.float32` or `dtypes.float64`. If neither `tensors` nor + `dtype` is supplied, default to `dtypes.float32`. + + Args: + tensors: Tensors of input values. Can include `None` elements, which will be + ignored. + dtype: Expected type. + Returns: + Validated type. + Raises: + ValueError: if neither `tensors` nor `dtype` is supplied, or result is not + float. + """ + if tensors: + dtype = _assert_same_base_type(tensors, dtype) + if not dtype: + dtype = dtypes.float32 + elif not dtype.is_floating: + raise ValueError('Expected float, got %s.' % dtype) + return dtype + + +def assert_scalar_int(tensor): + """Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. + + Args: + tensor: Tensor to test. + Returns: + `tensor`, for chaining. + Raises: + ValueError: if `tensor` is not 0-D, of type `tf.int32` or `tf.int64`. + """ + data_type = tensor.dtype + if data_type.base_dtype not in [dtypes.int32, dtypes.int64]: + raise ValueError('Unexpected type %s for %s.' % (data_type, tensor.name)) + shape = tensor.get_shape() + if shape.ndims != 0: + raise ValueError('Unexpected shape %s for %s.' % (shape, tensor.name)) + return tensor diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py new file mode 100644 index 0000000000..644fa9905b --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -0,0 +1,94 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""DType tests.""" + +# pylint: disable=unused-import +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class FloatDTypeTest(tf.test.TestCase): + + def test_assert_same_float_dtype(self): + self.assertIs( + tf.float32, tf.contrib.framework.assert_same_float_dtype(None, None)) + self.assertIs( + tf.float32, tf.contrib.framework.assert_same_float_dtype([], None)) + self.assertIs( + tf.float32, + tf.contrib.framework.assert_same_float_dtype([], tf.float32)) + self.assertIs( + tf.float32, + tf.contrib.framework.assert_same_float_dtype(None, tf.float32)) + self.assertIs( + tf.float32, + tf.contrib.framework.assert_same_float_dtype([None, None], None)) + self.assertIs( + tf.float32, + tf.contrib.framework.assert_same_float_dtype([None, None], tf.float32)) + + const_float = tf.constant(3.0, dtype=tf.float32) + self.assertIs( + tf.float32, + tf.contrib.framework.assert_same_float_dtype([const_float], tf.float32)) + self.assertRaises( + ValueError, + tf.contrib.framework.assert_same_float_dtype, [const_float], tf.int32) + + sparse_float = tf.SparseTensor( + tf.constant([[111], [232]], tf.int64), + tf.constant([23.4, -43.2], tf.float32), + tf.constant([500], tf.int64)) + self.assertIs(tf.float32, tf.contrib.framework.assert_same_float_dtype( + [sparse_float], tf.float32)) + self.assertRaises( + ValueError, + tf.contrib.framework.assert_same_float_dtype, [sparse_float], tf.int32) + self.assertRaises( + ValueError, tf.contrib.framework.assert_same_float_dtype, + [const_float, None, sparse_float], tf.float64) + + self.assertIs( + tf.float32, + tf.contrib.framework.assert_same_float_dtype( + [const_float, sparse_float])) + self.assertIs(tf.float32, tf.contrib.framework.assert_same_float_dtype( + [const_float, sparse_float], tf.float32)) + + const_int = tf.constant(3, dtype=tf.int32) + self.assertRaises(ValueError, tf.contrib.framework.assert_same_float_dtype, + [sparse_float, const_int]) + self.assertRaises(ValueError, tf.contrib.framework.assert_same_float_dtype, + [sparse_float, const_int], tf.int32) + self.assertRaises(ValueError, tf.contrib.framework.assert_same_float_dtype, + [sparse_float, const_int], tf.float32) + self.assertRaises( + ValueError, tf.contrib.framework.assert_same_float_dtype, [const_int]) + + def test_assert_scalar_int(self): + tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.int32)) + tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.int64)) + with self.assertRaisesRegexp(ValueError, "Unexpected type"): + tf.contrib.framework.assert_scalar_int(tf.constant(3, dtype=tf.float32)) + with self.assertRaisesRegexp(ValueError, "Unexpected shape"): + tf.contrib.framework.assert_scalar_int( + tf.constant([3, 4], dtype=tf.int32)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index a9ea2c745f..0a3653d7fa 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -11,15 +11,12 @@ py_library( name = "layers_py", srcs = [ "__init__.py", - "python/framework/tensor_util.py", "python/layers/__init__.py", "python/layers/initializers.py", "python/layers/layers.py", "python/layers/optimizers.py", "python/layers/regularizers.py", "python/layers/summaries.py", - "python/ops/__init__.py", - "python/ops/loss_ops.py", ], srcs_version = "PY2AND3", ) @@ -75,19 +72,6 @@ py_test( ], ) -py_test( - name 
= "loss_ops_test", - size = "small", - srcs = glob(["python/ops/loss_ops_test.py"]), - srcs_version = "PY2AND3", - deps = [ - ":layers_py", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - py_test( name = "summaries_test", size = "small", @@ -101,19 +85,6 @@ py_test( ], ) -py_test( - name = "tensor_util_test", - size = "small", - srcs = glob(["python/framework/tensor_util_test.py"]), - srcs_version = "PY2AND3", - deps = [ - ":layers_py", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index bbb3c2a351..ddf7f43b73 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -57,13 +57,6 @@ The layers module defines convenience functions `summarize_variables`, of `summarize_collection` to `VARIABLES`, `WEIGHTS` and `BIASES`, respectively. @@summarize_activations - -## Utilities - -@@assert_same_float_dtype -@@assert_scalar_int -@@is_numeric_tensor - """ from __future__ import absolute_import @@ -73,15 +66,5 @@ from __future__ import print_function import sys # pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.layers.python.framework.tensor_util import * from tensorflow.contrib.layers.python.layers import * -from tensorflow.contrib.layers.python.ops import * -from tensorflow.contrib.layers.python.ops import loss_ops from tensorflow.python.util.all_util import make_all - - -# Include loss_ops to get the symbols in - but they are not documented in main -# docs yet. -# TODO(cwhipkey): get the loss_ops documented in the main documentation and do -# this in a better way. -__all__ = make_all(__name__, [sys.modules[__name__], loss_ops]) diff --git a/tensorflow/contrib/layers/python/framework/tensor_util.py b/tensorflow/contrib/layers/python/framework/tensor_util.py deleted file mode 100644 index 1a5450630c..0000000000 --- a/tensorflow/contrib/layers/python/framework/tensor_util.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Tensor utility functions.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework.ops import Tensor - - -__all__ = ['assert_same_float_dtype', 'is_numeric_tensor', 'assert_scalar_int'] - - -NUMERIC_TYPES = frozenset([dtypes.float32, dtypes.float64, dtypes.int8, - dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.qint8, dtypes.qint32, - dtypes.quint8, dtypes.complex64]) - - -def is_numeric_tensor(tensor): - return isinstance(tensor, Tensor) and tensor.dtype in NUMERIC_TYPES - - -def _assert_same_base_type(items, expected_type=None): - """Asserts all items are of the same base type. - - Args: - items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, - `Operation`, or `IndexedSlices`). Can include `None` elements, which - will be ignored. - expected_type: Expected type. If not specified, assert all items are - of the same base type. - Returns: - Validated type, or none if neither expected_type nor items provided. - - Raises: - ValueError: If any types do not match. - """ - original_item_str = None - for item in items: - if item is not None: - item_type = item.dtype.base_dtype - if not expected_type: - expected_type = item_type - original_item_str = item.name if hasattr(item, 'name') else str(item) - elif expected_type != item_type: - raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( - item.name if hasattr(item, 'name') else str(item), - item_type, expected_type, - (' as %s' % original_item_str) if original_item_str else '')) - return expected_type - - -def assert_same_float_dtype(tensors=None, dtype=None): - """Validate and return float type based on `tensors` and `dtype`. - - For ops such as matrix multiplication, inputs and weights must be of the - same float type. This function validates that all `tensors` are the same type, - validates that type is `dtype` (if supplied), and returns the type. Type must - be `dtypes.float32` or `dtypes.float64`. If neither `tensors` nor - `dtype` is supplied, default to `dtypes.float32`. - - Args: - tensors: Tensors of input values. Can include `None` elements, which will be - ignored. - dtype: Expected type. - Returns: - Validated type. - Raises: - ValueError: if neither `tensors` nor `dtype` is supplied, or result is not - float. - """ - if tensors: - dtype = _assert_same_base_type(tensors, dtype) - if not dtype: - dtype = dtypes.float32 - elif not dtype.is_floating: - raise ValueError('Expected float, got %s.' % dtype) - return dtype - - -def assert_scalar_int(tensor): - """Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. - - Args: - tensor: Tensor to test. - Returns: - `tensor`, for chaining. - Raises: - ValueError: if `tensor` is not 0-D, of type `tf.int32` or `tf.int64`. - """ - data_type = tensor.dtype - if data_type.base_dtype not in [dtypes.int32, dtypes.int64]: - raise ValueError('Unexpected type %s for %s.' % (data_type, tensor.name)) - shape = tensor.get_shape() - if shape.ndims != 0: - raise ValueError('Unexpected shape %s for %s.' 
% (shape, tensor.name)) - return tensor diff --git a/tensorflow/contrib/layers/python/framework/tensor_util_test.py b/tensorflow/contrib/layers/python/framework/tensor_util_test.py deleted file mode 100644 index 6785ab4938..0000000000 --- a/tensorflow/contrib/layers/python/framework/tensor_util_test.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""DType tests.""" - -# pylint: disable=unused-import -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import tensorflow.python.framework - - -class FloatDTypeTest(tf.test.TestCase): - - def test_assert_same_float_dtype(self): - self.assertIs( - tf.float32, tf.contrib.layers.assert_same_float_dtype(None, None)) - self.assertIs( - tf.float32, tf.contrib.layers.assert_same_float_dtype([], None)) - self.assertIs( - tf.float32, tf.contrib.layers.assert_same_float_dtype([], tf.float32)) - self.assertIs( - tf.float32, - tf.contrib.layers.assert_same_float_dtype(None, tf.float32)) - self.assertIs( - tf.float32, - tf.contrib.layers.assert_same_float_dtype([None, None], None)) - self.assertIs( - tf.float32, - tf.contrib.layers.assert_same_float_dtype([None, None], tf.float32)) - - const_float = tf.constant(3.0, dtype=tf.float32) - self.assertIs( - tf.float32, - tf.contrib.layers.assert_same_float_dtype([const_float], tf.float32)) - self.assertRaises( - ValueError, - tf.contrib.layers.assert_same_float_dtype, [const_float], tf.int32) - - sparse_float = tf.SparseTensor( - tf.constant([[111], [232]], tf.int64), - tf.constant([23.4, -43.2], tf.float32), - tf.constant([500], tf.int64)) - self.assertIs(tf.float32, tf.contrib.layers.assert_same_float_dtype( - [sparse_float], tf.float32)) - self.assertRaises( - ValueError, - tf.contrib.layers.assert_same_float_dtype, [sparse_float], tf.int32) - self.assertRaises( - ValueError, tf.contrib.layers.assert_same_float_dtype, - [const_float, None, sparse_float], tf.float64) - - self.assertIs( - tf.float32, - tf.contrib.layers.assert_same_float_dtype([const_float, sparse_float])) - self.assertIs(tf.float32, tf.contrib.layers.assert_same_float_dtype( - [const_float, sparse_float], tf.float32)) - - const_int = tf.constant(3, dtype=tf.int32) - self.assertRaises(ValueError, tf.contrib.layers.assert_same_float_dtype, - [sparse_float, const_int]) - self.assertRaises(ValueError, tf.contrib.layers.assert_same_float_dtype, - [sparse_float, const_int], tf.int32) - self.assertRaises(ValueError, tf.contrib.layers.assert_same_float_dtype, - [sparse_float, const_int], tf.float32) - self.assertRaises( - ValueError, tf.contrib.layers.assert_same_float_dtype, [const_int]) - - def test_assert_scalar_int(self): - tf.contrib.layers.assert_scalar_int(tf.constant(3, dtype=tf.int32)) - tf.contrib.layers.assert_scalar_int(tf.constant(3, dtype=tf.int64)) - with 
self.assertRaisesRegexp(ValueError, "Unexpected type"): - tf.contrib.layers.assert_scalar_int(tf.constant(3, dtype=tf.float32)) - with self.assertRaisesRegexp(ValueError, "Unexpected shape"): - tf.contrib.layers.assert_scalar_int(tf.constant([3, 4], dtype=tf.int32)) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/layers/python/ops/__init__.py b/tensorflow/contrib/layers/python/ops/__init__.py deleted file mode 100644 index 09bfbe4dc5..0000000000 --- a/tensorflow/contrib/layers/python/ops/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A module containing TensorFlow ops whose API may change in the future.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=wildcard-import -from tensorflow.contrib.layers.python.ops.loss_ops import * diff --git a/tensorflow/contrib/layers/python/ops/loss_ops.py b/tensorflow/contrib/layers/python/ops/loss_ops.py deleted file mode 100644 index c451fc81d4..0000000000 --- a/tensorflow/contrib/layers/python/ops/loss_ops.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""## Loss operations for use in neural networks. - -The loss ops measure error for use in neural networks. These losses -can be used for measuring accuracy of a network in a regression task -or for regularization purposes (e.g., weight decay). - -These loss ops are, by design, minimal, enabling flexibility in how -their output can be used. 
- -@@reduce_batch_sum - -@@absolute_loss -@@squared_loss -@@logistic_loss - -@@sum_absolute_loss -@@sum_squared_loss -@@sum_logistic_loss - -@@scalar_absolute_loss -@@scalar_squared_loss -@@scalar_logistic_loss -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.layers.python.framework import tensor_util -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn - -__all__ = ["reduce_batch_sum", "absolute_loss", "squared_loss", "logistic_loss", - "sum_absolute_loss", "sum_squared_loss", "sum_logistic_loss", - "scalar_absolute_loss", "scalar_squared_loss", - "scalar_logistic_loss"] - - -def _reduce_batch(x, reduce_fn, name=None): - """Given a tensor `x`, calls reduce_fn to reduce it across dimensions. - - Given a tensor with number of dimensions > 1, _reduce_batch will reduce the - tensor across all dimensions except for dimension 0. As an example, given a - tensor of shape [batch_size, d1, d2], this function will reduce across - dimensions d1 and d2, returning a tensor of shape [batch_size]. - - Tensors of dimension 1 are returned as-is, while tensors of dimension 0 - raise a ValueError. - - Args: - x: A `Tensor` with dimension > 0. - reduce_fn: A math_ops reduce function that takes arguments of - `x`, `reduction_indices`, and `name`. - name: A name for the operation (optional). - - Returns: - A `Tensor` with values reduced by reduce_fn across all dimensions > 0. - - Raises: - ValueError: If `x` has dimension 0. - """ - x = ops.convert_to_tensor(x, name="x") - with ops.op_scope([x], name, "reduce_batch"): - ndims = x.get_shape().ndims - if ndims == 0: - raise ValueError("Cannot reduce a scalar into batches.") - elif ndims == 1: - return x # Don't include a useless reduction. - elif ndims: - reduction_indices = math_ops.range(1, ndims) - shape = [x.get_shape().dims[0]] - else: - reduction_indices = math_ops.range(1, array_ops.size(array_ops.shape(x))) - shape = [None] # We don't know much about the shape, but it is rank 1. - result = reduce_fn(x, reduction_indices=reduction_indices) - - # Give a shape hint in case we have extra information. - result.set_shape(shape) - return result - - -def reduce_batch_sum(x, name=None): - """Given a tensor `x`, sums across all dimensions except dimension 0. - - Given a tensor with the number of dimensions > 1, reduce_batch_sum - will sum across all dimensions except for dimension 0. This function - is useful for summing the loss (error) across all examples in a - batch when training. As an example, given a tensor of shape - [batch_size, d1, d2], this function will sum across dimensions d1 - and d2, returning a tensor of shape [batch_size]. - - Tensors of dimension 1 are returned as-is, while tensors of dimension 0 - raise a ValueError. - - Args: - x: A `Tensor` with dimension > 0. - name: A name for the operation (optional). - - Returns: - A `Tensor` with values summed across all dimensions > 0. - - Raises: - ValueError: If `x` has dimension 0. - - """ - return _reduce_batch(x, math_ops.reduce_sum, name) - - -def _validate_predicted_and_target(predicted, target): - # TODO(ptucker): Optionally add assert op for shape check, for cases when - # shape is not fully defined at graph construction time? 
- predicted.get_shape().assert_is_compatible_with(target.get_shape()) - tensor_util.assert_same_float_dtype([predicted, target]) - - -def absolute_loss(predicted, target, name=None): - """Computes and returns the per-example absolute loss. - - Computes the per-example absolute value of the difference between - the target and predicted tensors. The tensors must have the same - shape. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses. - - Raises: - ValueError: If `predicted` and `target` shapes do not match. - - """ - with ops.op_scope([predicted, target], name, "absolute_loss") as scope: - predicted = ops.convert_to_tensor(predicted, name="predicted") - target = ops.convert_to_tensor(target, name="target") - _validate_predicted_and_target(predicted, target) - return math_ops.abs(target - predicted, name=scope) - - -def squared_loss(predicted, target, name=None): - """Computes and returns the per-example squared loss, divided by 2. - - Computes the per-example squared difference between the target and - predicted tensors. The tensors must have the same shape. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses. - - Raises: - ValueError: If `predicted` and `target` shapes do not match. - - """ - with ops.op_scope([predicted, target], name, "squared_loss") as scope: - predicted = ops.convert_to_tensor(predicted, name="predicted") - target = ops.convert_to_tensor(target, name="target") - _validate_predicted_and_target(predicted, target) - return math_ops.div(math_ops.square(target - predicted), 2.0, name=scope) - - -def logistic_loss(logit, target, name=None): - """Calculates the logistic cross-entropy loss. - - **WARNING:** `logit` must be unscaled, while the `target` should be a - normalized probability prediction. See - `tf.nn.sigmoid_cross_entropy_with_logits` for more details. - - Args: - logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted logit values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A `Tensor` of the logistic cross-entropy loss. - """ - return nn.sigmoid_cross_entropy_with_logits(logit, target, name=name) - - -def _sum_loss(predicted, target, loss_fn, name="sum_loss"): - """Apply loss function, then sum across all non-batch dimensions. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - loss_fn: Loss to apply, takes 2 tensors as parameters and returns a tensor. - name: A name for the operation (optional). - - Returns: - A `[batch_size]` tensor of losses, averaged across all dimensions except - dimension 0. 
- """ - return reduce_batch_sum(loss_fn(predicted, target), name=name) - - -def sum_absolute_loss(predicted, target, name="sum_absolute_loss"): - """Calculates the sum of absolute losses across batches. - - Computes the absolute difference between the target and predicted - tensors, averaged across all dimensions except dimension 0: - - losses = reduce_batch_sum(absolute_loss(predicted, target)) - - where `losses` is a tensor with dimensions [batch_size]. - - The tensors must have the same shape. - - This loss function is a form of L1 loss. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A `[batch_size]` tensor of absolute differences, averaged across all - dimensions except dimension 0. - - Raises: - ValueError: If `predicted` and `target` shapes do not match. - - """ - return _sum_loss(predicted, target, absolute_loss, name=name) - - -def sum_squared_loss(predicted, target, name="sum_squared_loss"): - """Calculates the sum of the squared loss across batches. - - Computes the squared difference between the target and predicted - tensors, sums across all dimensions except dimension 0. - - losses = reduce_batch_sum(squared_loss(predicted, target)) - - where `losses` is a tensor with dimensions [batch_size]. - - The tensors must have the same shape. - - This function is equivalent to typical formulations of L2 loss, and - similar to TensorFlow's l2_loss function. It differs from the - l2_loss function by allowing the caller to specify both the - predicted and target tensors. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A `[batch_size]` tensor of squared losses summed across all dimensions - except dimension 0. - - Raises: - ValueError: If `predicted` and `target` shapes do not match. - - """ - return _sum_loss(predicted, target, squared_loss, name=name) - - -def sum_logistic_loss(logit, target, name="sum_logistic_loss"): - """Calculates the sum of the logistic loss across batches. - - Computes the logistic between logit and predicted tensors, summed across all - dimensions except dimension 0. - - **WARNING:** `logit` must be unscaled, while the `target` should be a - normalized probability prediction. See - `tf.nn.sigmoid_cross_entropy_with_logits` for more details. - - Args: - logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted logit values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A `[batch_size]` tensor of logistic losses summed across all dimensions - except dimension 0. - """ - return _sum_loss(logit, target, logistic_loss, name=name) - - -def _scalar_loss(predicted, target, loss_fn, name=None): - """Reduces losses to a scalar. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. 
The shape of the target tensor should match the - `predicted` tensor. - loss_fn: Loss to apply, takes 2 tensors as parameters and returns a tensor. - name: A name for the operation (optional). - - Returns: - Caculate sum of losses per example, then average across batch. - """ - with ops.op_scope([predicted, target], name, "scalar_loss") as scope: - return math_ops.reduce_mean( - _sum_loss(predicted, target, loss_fn), name=scope) - - -def scalar_absolute_loss(predicted, target, name="scalar_absolute_loss"): - """Reduces absolute losses to a scalar. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - Caculate sum of absolute losses per example, then average across batch. - """ - return _scalar_loss(predicted, target, loss_fn=absolute_loss, name=name) - - -def scalar_squared_loss(predicted, target, name="scalar_squared_loss"): - """Reduces squared losses to a scalar. - - Args: - predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - Caculate sum of squared losses per example, then average across batch. - """ - return _scalar_loss(predicted, target, loss_fn=squared_loss, name=name) - - -def scalar_logistic_loss(logit, target, name="scalar_logistic_loss"): - """Calculates the logistic cross-entropy loss, averaged across batches. - - **WARNING:** `logit` must be unscaled, while the `target` should be a - normalized probability prediction. See - `tf.nn.sigmoid_cross_entropy_with_logits` for more details. - - Args: - logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` - of predicted logit values. - target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of - target values. The shape of the target tensor should match the - `predicted` tensor. - name: A name for the operation (optional). - - Returns: - A scalar `tensor` of the logistic cross-entropy loss, averaged across - batches. - - Raises: - ValueError: If `logit` and `target` shapes do not match. - """ - return _scalar_loss(logit, target, loss_fn=logistic_loss, name=name) - diff --git a/tensorflow/contrib/layers/python/ops/loss_ops_test.py b/tensorflow/contrib/layers/python/ops/loss_ops_test.py deleted file mode 100644 index 1453af5331..0000000000 --- a/tensorflow/contrib/layers/python/ops/loss_ops_test.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Tests for contrib.layers.python.ops.loss_ops.""" -# pylint: disable=unused-import,g-bad-import-order -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.layers.python.framework import tensor_util - -pi = 3.14 -indiana_pi = 3.2 # https://en.wikipedia.org/wiki/Indiana_Pi_Bill - - -class ReduceBatchSumTest(tf.test.TestCase): - - def testDimensionNone(self): - with self.test_session(): - input_array = np.array([ - [1.0, 2.0], - [-1.0, -2.0] - ], dtype=np.float32) - placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec") - expected_result = np.array([3.0, -3.0]) - actual_result = tf.contrib.layers.reduce_batch_sum(placeholder_vec) - self.assertEqual(actual_result.get_shape().as_list(), [None]) - self.assertAllClose(expected_result, actual_result.eval(feed_dict={ - placeholder_vec: input_array - })) - - def testDimension0(self): - with self.test_session(): - input_vec = tf.constant(2.0) - with self.assertRaises(ValueError): - tf.contrib.layers.reduce_batch_sum(input_vec) - - def testDimension1(self): - with self.test_session(): - input_vec = tf.constant([1.0, 2.0]) - expected_result = np.array([1.0, 2.0]) - actual_result = tf.contrib.layers.reduce_batch_sum(input_vec) - self.assertAllClose(expected_result, actual_result.eval()) - - def testDimension2(self): - with self.test_session(): - input_vec = tf.constant([ - [1.0, 2.0], - [-1.0, -2.0] - ]) - expected_result = np.array([3.0, -3.0]) - actual_result = tf.contrib.layers.reduce_batch_sum(input_vec) - self.assertAllClose(expected_result, actual_result.eval()) - - def testReturnShape(self): - with self.test_session(): - input_vec = tf.constant([ - [1.0, 2.0], - [-1.0, -2.0] - ]) - expected_result = np.array([3.0, -3.0]) - actual_result = tf.contrib.layers.reduce_batch_sum(input_vec) - self.assertShapeEqual(expected_result, actual_result) - - def testDimensionN(self): - with self.test_session(): - input_vec = tf.constant([ - [ - [1.0, 2.0], - [3.0, 4.0] - ], - [ - [5.0, 6.0], - [7.0, 8.0] - ] - ]) - expected_result = np.array([10.0, 26.0]) - actual_result = tf.contrib.layers.reduce_batch_sum(input_vec) - self.assertAllClose(expected_result, actual_result.eval()) - - -class AbsoluteLossTest(tf.test.TestCase): - - def _getTestVectors(self): - target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target") - predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2], - name="predicted") - expected_loss = np.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2) - return target, predicted, expected_loss - - def testAbsoluteLoss(self): - with self.test_session(): - target, predicted, expected_loss = self._getTestVectors() - result = tf.contrib.layers.absolute_loss(predicted, target) - self.assertAllClose(expected_loss, result.eval()) - - def testAbsoluteLossReturnShape(self): - with self.test_session(): - target, predicted, expected_loss = self._getTestVectors() - result = tf.contrib.layers.absolute_loss(predicted, target) - self.assertShapeEqual(expected_loss, result) - - def testInvalidShapesValueError(self): - with self.test_session(): - target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target") - incompatible_shape = tf.constant([0.0, 1.1], shape=[2], - name="incompatible_shape") - with self.assertRaises(ValueError): - tf.contrib.layers.absolute_loss(incompatible_shape, target) - - -class 
SquaredLossTest(tf.test.TestCase): - - def _getTestVectors(self): - target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target") - predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2], - name="predicted") - expected_loss = np.array([0.005, 0.02, 0.045, 0.08]).reshape(2, 2) - return target, predicted, expected_loss - - def testSquaredLoss(self): - with self.test_session(): - target, predicted, expected_loss = self._getTestVectors() - result = tf.contrib.layers.squared_loss(predicted, target) - self.assertAllClose(expected_loss, result.eval()) - - def testSquaredLossReturnShape(self): - with self.test_session(): - target, predicted, expected_loss = self._getTestVectors() - result = tf.contrib.layers.squared_loss(predicted, target) - self.assertShapeEqual(expected_loss, result) - - def testInvalidShapesValueError(self): - with self.test_session(): - target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target") - incompatible_shape = tf.constant([0.0, 1.1], shape=[2], - name="incompatible_shape") - with self.assertRaises(ValueError): - tf.contrib.layers.squared_loss(incompatible_shape, target) - - -class SumSquaredLossTest(tf.test.TestCase): - - def _getTestVectors(self): - target = tf.constant([[0.0, 1.0], - [3.0, 2.0]], - shape=[2, 2], - name="target") - predicted = tf.constant([[3.0, -2.0], - [1.0, 2.0]], - shape=[2, 2], - name="predicted") - expected_loss = np.array([9.0, 2.0]) - return target, predicted, expected_loss - - def testSumSquaredLoss(self): - with self.test_session(): - target, predicted, expected_loss = self._getTestVectors() - result = tf.contrib.layers.sum_squared_loss(predicted, target) - self.assertAllClose(expected_loss, result.eval()) - - def testSumSquaredLossReturnShape(self): - with self.test_session(): - target, predicted, expected_loss = self._getTestVectors() - result = tf.contrib.layers.sum_squared_loss(predicted, target) - self.assertShapeEqual(expected_loss, result) - - def testInvalidShapesValueError(self): - with self.test_session(): - target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target") - incompatible_shape = tf.constant([0.0, 1.1], shape=[2], - name="incompatible_shape") - with self.assertRaises(ValueError): - tf.contrib.layers.sum_squared_loss(incompatible_shape, target) - - -class ScalarAbsoluteLossTest(tf.test.TestCase): - - def testScalarAbsoluteLoss(self): - with self.test_session(): - actual = tf.constant([pi], name="pi") - actual_placeholder = tf.placeholder(tf.float32) - label = tf.constant([indiana_pi], name="lbl") - label_placeholder = tf.placeholder(tf.float32, name="lbl_ph") - expected_loss = abs(indiana_pi - pi) - - # Both shapes are set. - both_shapes_loss = tf.contrib.layers.scalar_absolute_loss(actual, label) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - both_shapes_loss.eval(), expected_loss, decimal=6) - - # No shape for 'actual' - check that the loss layer can be created. - no_actual_shape_loss = tf.contrib.layers.scalar_absolute_loss( - actual_placeholder, label) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - no_actual_shape_loss.eval({actual_placeholder: [pi]}), - expected_loss, decimal=6) - - # No shape for 'label' - check that the loss layer can be created. - no_label_shape_loss = tf.contrib.layers.scalar_absolute_loss( - actual, label_placeholder) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - no_label_shape_loss.eval({label_placeholder: [indiana_pi]}), - expected_loss, decimal=6) - - # No shapes. 
- no_shape_loss = tf.contrib.layers.scalar_absolute_loss( - actual_placeholder, label_placeholder) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - no_shape_loss.eval({label_placeholder: [indiana_pi], - actual_placeholder: [pi]}), - expected_loss, decimal=6) - - # Evaluate the previous one again, but this time with different - # (matching) shapes. This should still work. - np.testing.assert_almost_equal( - no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi], - actual_placeholder: [pi, pi]}), - expected_loss, decimal=6) - - -class ScalarSquaredLossTest(tf.test.TestCase): - - def testScalarSquaredLoss(self): - with self.test_session(): - actual = tf.constant([pi], name="pi") - actual_placeholder = tf.placeholder(tf.float32) - label = tf.constant([indiana_pi], name="lbl") - label_placeholder = tf.placeholder(tf.float32, name="lbl_ph") - expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2 - - # Both shapes are set. - both_shapes_loss = tf.contrib.layers.scalar_squared_loss(actual, label) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - both_shapes_loss.eval(), expected_loss, decimal=6) - - # No shape for 'actual' - check that the loss layer can be created. - no_actual_shape_loss = tf.contrib.layers.scalar_squared_loss( - actual_placeholder, label) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - no_actual_shape_loss.eval({actual_placeholder: [pi]}), - expected_loss, decimal=6) - - # No shape for 'label' - check that the loss layer can be created. - no_label_shape_loss = tf.contrib.layers.scalar_squared_loss( - actual, label_placeholder) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - no_label_shape_loss.eval({label_placeholder: [indiana_pi]}), - expected_loss, - decimal=6) - - # No shapes. - no_shape_loss = tf.contrib.layers.scalar_squared_loss( - actual_placeholder, label_placeholder) - tf.initialize_all_variables().run() - np.testing.assert_almost_equal( - no_shape_loss.eval({label_placeholder: [indiana_pi], - actual_placeholder: [pi]}), - expected_loss, decimal=6) - - # Evaluate the previous one again, but this time with different - # (matching) shapes. This should still work. - np.testing.assert_almost_equal( - no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi], - actual_placeholder: [pi, pi]}), - expected_loss, decimal=6) - - -class ScalarLogisticLossTest(tf.test.TestCase): - - def _expected_loss(self, logit, target): - sigmoid = 1.0 / (1.0 + np.exp(-logit)) - logistic_loss = (target * -np.log(sigmoid)) - ( - (1.0 - target) * np.log(1.0 - sigmoid)) - batch_losses = np.sum(logistic_loss, 1) - - return np.sum(batch_losses) / len(batch_losses) - - def test_scalar_logistic_loss(self): - logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]]) - target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]]) - with self.test_session(): - result = tf.contrib.layers.scalar_logistic_loss( - tf.constant(logit), tf.constant(target)) - self.assertAllClose(self._expected_loss(logit, target), result.eval()) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD new file mode 100644 index 0000000000..8452132c45 --- /dev/null +++ b/tensorflow/contrib/losses/BUILD @@ -0,0 +1,42 @@ +# Description: +# contains parts of TensorFlow that are experimental or unstable and which are not supported. 
+ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +py_library( + name = "losses_py", + srcs = [ + "__init__.py", + "python/losses/__init__.py", + "python/losses/loss_ops.py", + ], + srcs_version = "PY2AND3", +) + +py_test( + name = "loss_ops_test", + srcs = glob(["python/losses/loss_ops_test.py"]), + srcs_version = "PY2AND3", + deps = [ + ":losses_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py new file mode 100644 index 0000000000..ec796a6ca7 --- /dev/null +++ b/tensorflow/contrib/losses/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for building neural network losses.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.losses.python.losses import * +from tensorflow.python.util.all_util import make_all diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py new file mode 100644 index 0000000000..d5e426e389 --- /dev/null +++ b/tensorflow/contrib/losses/python/losses/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A module containing TensorFlow ops whose API may change in the future.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.losses.python.losses.loss_ops import * diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py new file mode 100644 index 0000000000..c704222c4b --- /dev/null +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -0,0 +1,286 @@ +# Copyright 2015 Google Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""## Loss operations for use in neural networks.
+
+The loss ops measure error for use in neural networks. These losses
+can be used to measure a network's error in a regression task, or
+for regularization purposes (e.g., weight decay).
+
+These loss ops are, by design, minimal, enabling flexibility in how
+their output can be used.
+
+@@absolute
+@@squared
+@@logistic
+@@softmax
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.framework.python.framework import tensor_util
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+__all__ = ["absolute", "squared", "logistic", "softmax"]
+
+
+def _reduce_batch(x, reduce_fn, name=None):
+  """Given a tensor `x`, calls reduce_fn to reduce it across dimensions.
+
+  Given a tensor with more than one dimension, _reduce_batch will reduce the
+  tensor across all dimensions except for dimension 0. As an example, given a
+  tensor of shape [batch_size, d1, d2], this function will reduce across
+  dimensions d1 and d2, returning a tensor of shape [batch_size].
+
+  Tensors of dimension 1 are returned as-is, while tensors of dimension 0
+  raise a ValueError.
+
+  Args:
+    x: A `Tensor` with dimension > 0.
+    reduce_fn: A math_ops reduce function that takes arguments of
+      `x`, `reduction_indices`, and `name`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` with values reduced by reduce_fn across all dimensions > 0.
+
+  Raises:
+    ValueError: If `x` has dimension 0.
+  """
+  x = ops.convert_to_tensor(x, name="x")
+  with ops.op_scope([x], name, "reduce_batch"):
+    ndims = x.get_shape().ndims
+    if ndims == 0:
+      raise ValueError("Cannot reduce a scalar into batches.")
+    elif ndims == 1:
+      return x  # Don't include a useless reduction.
+    elif ndims:
+      reduction_indices = math_ops.range(1, ndims)
+      shape = [x.get_shape().dims[0]]
+    else:
+      reduction_indices = math_ops.range(1, array_ops.size(array_ops.shape(x)))
+      shape = [None]  # We don't know much about the shape, but it is rank 1.
+    result = reduce_fn(x, reduction_indices=reduction_indices)
+
+    # Give a shape hint in case we have extra information.
+    result.set_shape(shape)
+    return result
+
+
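# Worked example of the batch reduction above, as comments in place
# (hypothetical input values; not from the original change):
#
#   x: a float `Tensor` of shape [batch_size=2, 2, 2] holding
#      [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]
#   _reduce_batch(x, math_ops.reduce_sum) sums over dimensions 1 and 2,
#   yielding a shape-[2] tensor that evaluates to [10., 26.], i.e. one
#   reduced value per batch example.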
+
+  Tensors of dimension 1 are returned as-is, while tensors of dimension 0
+  raise a ValueError.
+
+  Args:
+    x: A `Tensor` with dimension > 0.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` with values summed across all dimensions > 0.
+
+  Raises:
+    ValueError: If `x` has dimension 0.
+
+  """
+  return _reduce_batch(x, math_ops.reduce_sum, name)
+
+
+def _reduce_to_scalar(x, name=None):
+  """Reduces losses to a scalar.
+
+  Given a tensor `x`, sums across all dimensions except dimension 0, then
+  averages across dimension 0.
+
+  Args:
+    x: A `Tensor` with dimension > 0.
+    name: A name for the operation (optional).
+
+  Returns:
+    A scalar: the sum of losses per example, averaged across the batch.
+  """
+  with ops.op_scope([x], name, "scalar") as scope:
+    return math_ops.reduce_mean(_reduce_batch_sum(x), name=scope)
+
+
+def _validate_predicted_and_target(predicted, target):
+  # TODO(ptucker): Optionally add assert op for shape check, for cases when
+  # shape is not fully defined at graph construction time?
+  predicted.get_shape().assert_is_compatible_with(target.get_shape())
+  tensor_util.assert_same_float_dtype([predicted, target])
+
+
+def _raw_absolute(predicted, target, name=None):
+  """Computes and returns the per-example absolute loss.
+
+  Computes the per-example absolute value of the difference between
+  the target and predicted tensors. The tensors must have the same
+  shape.
+
+  Args:
+    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+      of predicted values.
+    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+      target values. The shape of the target tensor should match the
+      `predicted` tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses.
+
+  Raises:
+    ValueError: If `predicted` and `target` shapes do not match.
+
+  """
+  with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
+    predicted = ops.convert_to_tensor(predicted, name="predicted")
+    target = ops.convert_to_tensor(target, name="target")
+    _validate_predicted_and_target(predicted, target)
+    return math_ops.abs(target - predicted, name=scope)
+
+
+def _raw_squared(predicted, target, name=None):
+  """Computes and returns the per-example squared loss, divided by 2.
+
+  Computes the per-example squared difference between the target and
+  predicted tensors. The tensors must have the same shape.
+
+  Args:
+    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+      of predicted values.
+    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+      target values. The shape of the target tensor should match the
+      `predicted` tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses.
+
+  Raises:
+    ValueError: If `predicted` and `target` shapes do not match.
+
+  """
+  with ops.op_scope([predicted, target], name, "squared_loss") as scope:
+    predicted = ops.convert_to_tensor(predicted, name="predicted")
+    target = ops.convert_to_tensor(target, name="target")
+    _validate_predicted_and_target(predicted, target)
+    return math_ops.div(math_ops.square(target - predicted), 2.0, name=scope)
+
+
+def absolute(predicted, target, name=None):
+  """Reduces absolute losses to a scalar.
+
+  Args:
+    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+      of predicted values.
+    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+      target values. The shape of the target tensor should match the
+      `predicted` tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A scalar: the sum of absolute losses per example, averaged across
+    the batch.
+  """
+  with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
+    return _reduce_to_scalar(_raw_absolute(predicted, target), name=scope)
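+
+# A minimal usage sketch for `absolute` (illustrative only; the values and
+# the interactive-evaluation context are assumptions, not part of this
+# change):
+#
+#   predicted = ops.convert_to_tensor([1.0, 2.5])
+#   target = ops.convert_to_tensor([1.5, 2.0])
+#   loss = absolute(predicted, target)
+#   # Per-example absolute losses are [0.5, 0.5]; the returned scalar is
+#   # their batch average, 0.5.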
+
+
+def squared(predicted, target, name=None):
+  """Reduces squared losses to a scalar.
+
+  Args:
+    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+      of predicted values.
+    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+      target values. The shape of the target tensor should match the
+      `predicted` tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A scalar: the sum of squared losses per example, averaged across
+    the batch.
+  """
+  with ops.op_scope([predicted, target], name, "squared_loss") as scope:
+    return _reduce_to_scalar(_raw_squared(predicted, target), name=scope)
+
+
+def logistic(logit, target, name=None):
+  """Calculates the logistic cross-entropy loss, averaged across batches.
+
+  **WARNING:** `logit` must be unscaled, while the `target` should be a
+  normalized probability prediction. See
+  `tf.nn.sigmoid_cross_entropy_with_logits` for more details.
+
+  Args:
+    logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+      of predicted logit values.
+    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+      target values. The shape of the target tensor should match the
+      `logit` tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A scalar `Tensor` of the logistic cross-entropy loss, averaged across
+    batches.
+
+  Raises:
+    ValueError: If `logit` and `target` shapes do not match.
+  """
+  with ops.op_scope([logit, target], name, "logistic_loss") as scope:
+    return _reduce_to_scalar(
+        nn.sigmoid_cross_entropy_with_logits(logit, target), name=scope)
+
+
+def softmax(logit, target, name=None):
+  """Calculates the softmax cross-entropy loss, averaged across batches.
+
+  **WARNING:** `logit` must be unscaled, while the `target` should be a
+  normalized probability prediction. See
+  `tf.nn.softmax_cross_entropy_with_logits` for more details.
+
+  Args:
+    logit: A `Tensor` of unscaled logit values. Shape must have rank 2,
+      generally `[batch_size, num_classes]`. `num_classes` must be > 1;
+      for single-class regression, use `logistic`. Type must be `tf.float32`
+      or `tf.float64`.
+    target: A `Tensor` of target values, with the same shape and type as
+      `logit`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A scalar `Tensor` of the softmax cross-entropy loss, averaged across
+    batches.
+
+  Raises:
+    ValueError: If `logit` and `target` shapes do not match, if `logit` does
+      not have rank 2, or if it has only 1 class.
+  """
+  with ops.op_scope([logit, target], name, "softmax_loss") as scope:
+    shape = logit.get_shape().with_rank(2)
+    if shape.dims[1] and shape.dims[1] < 2:
+      raise ValueError(
+          "Invalid shape %s; use logistic() instead for only 1 class." %
+          shape)
+    return _reduce_to_scalar(
+        nn.softmax_cross_entropy_with_logits(logit, target), name=scope)
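+
+
+# A minimal comparison sketch (illustrative only; the values are assumptions,
+# not part of this change). A binary `logistic` loss with logit L and target T
+# matches `softmax` on logits [L, 0] with targets [T, 1 - T]:
+#
+#   logistic_loss = logistic(ops.convert_to_tensor([2.0]),
+#                            ops.convert_to_tensor([1.0]))
+#   softmax_loss = softmax(ops.convert_to_tensor([[2.0, 0.0]]),
+#                          ops.convert_to_tensor([[1.0, 0.0]]))
+#   # Both evaluate to -log(sigmoid(2.0)), about 0.1269.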
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
new file mode 100644
index 0000000000..71e464362f
--- /dev/null
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -0,0 +1,272 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for contrib.losses.python.losses.loss_ops."""
+# pylint: disable=unused-import,g-bad-import-order
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.framework.python.framework import tensor_util
+
+pi = 3.14
+indiana_pi = 3.2  # https://en.wikipedia.org/wiki/Indiana_Pi_Bill
+
+
+class AbsoluteLossTest(tf.test.TestCase):
+
+  def testAbsoluteLoss(self):
+    with self.test_session():
+      actual = tf.constant([pi], name="pi")
+      actual_placeholder = tf.placeholder(tf.float32)
+      label = tf.constant([indiana_pi], name="lbl")
+      label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
+      expected_loss = abs(indiana_pi - pi)
+
+      # Both shapes are set.
+      both_shapes_loss = tf.contrib.losses.absolute(actual, label)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          both_shapes_loss.eval(), expected_loss, decimal=6)
+
+      # No shape for 'actual' - check that the loss layer can be created.
+      no_actual_shape_loss = tf.contrib.losses.absolute(
+          actual_placeholder, label)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          no_actual_shape_loss.eval({actual_placeholder: [pi]}),
+          expected_loss, decimal=6)
+
+      # No shape for 'label' - check that the loss layer can be created.
+      no_label_shape_loss = tf.contrib.losses.absolute(
+          actual, label_placeholder)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
+          expected_loss, decimal=6)
+
+      # No shapes.
+      no_shape_loss = tf.contrib.losses.absolute(
+          actual_placeholder, label_placeholder)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          no_shape_loss.eval({label_placeholder: [indiana_pi],
+                              actual_placeholder: [pi]}),
+          expected_loss, decimal=6)
+
+      # Evaluate the previous one again, but this time with different
+      # (matching) shapes. This should still work.
+      np.testing.assert_almost_equal(
+          no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
+                              actual_placeholder: [pi, pi]}),
+          expected_loss, decimal=6)
+
+
+class SquaredLossTest(tf.test.TestCase):
+
+  def testSquaredLoss(self):
+    with self.test_session():
+      actual = tf.constant([pi], name="pi")
+      actual_placeholder = tf.placeholder(tf.float32)
+      label = tf.constant([indiana_pi], name="lbl")
+      label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
+      expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2
+
+      # Both shapes are set.
+      both_shapes_loss = tf.contrib.losses.squared(actual, label)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          both_shapes_loss.eval(), expected_loss, decimal=6)
+
+      # No shape for 'actual' - check that the loss layer can be created.
+      no_actual_shape_loss = tf.contrib.losses.squared(
+          actual_placeholder, label)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          no_actual_shape_loss.eval({actual_placeholder: [pi]}),
+          expected_loss, decimal=6)
+
+      # No shape for 'label' - check that the loss layer can be created.
+      no_label_shape_loss = tf.contrib.losses.squared(
+          actual, label_placeholder)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
+          expected_loss,
+          decimal=6)
+
+      # No shapes.
+      no_shape_loss = tf.contrib.losses.squared(
+          actual_placeholder, label_placeholder)
+      tf.initialize_all_variables().run()
+      np.testing.assert_almost_equal(
+          no_shape_loss.eval({label_placeholder: [indiana_pi],
+                              actual_placeholder: [pi]}),
+          expected_loss, decimal=6)
+
+      # Evaluate the previous one again, but this time with different
+      # (matching) shapes. This should still work.
+      np.testing.assert_almost_equal(
+          no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
+                              actual_placeholder: [pi, pi]}),
+          expected_loss, decimal=6)
+
+
+class LogisticTest(tf.test.TestCase):
+
+  def _expected_loss(self, logit, target):
+    sigmoid = 1.0 / (1.0 + np.exp(-logit))
+    logistic_loss = (target * -np.log(sigmoid)) - (
+        (1.0 - target) * np.log(1.0 - sigmoid))
+    batch_losses = np.sum(logistic_loss, 1)
+
+    return np.sum(batch_losses) / len(batch_losses)
+
+  def testSimple(self):
+    logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]])
+    target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]])
+    with self.test_session():
+      loss = tf.contrib.losses.logistic(tf.constant(logit), tf.constant(target))
+      self.assertAllClose(self._expected_loss(logit, target), loss.eval())
+
+  def testComplex(self):
+    with self.test_session():
+      # [batch] and [batch,1] work the same.
+      loss3x0 = tf.contrib.losses.logistic(
+          tf.constant([-1.0, 3.0, -3.0]),
+          tf.constant([0.3, 0.1, 0.4]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(1.536812, loss3x0.eval())
+
+      expected_loss = 1.536812
+      actual3x1 = [[-1.0], [3.0], [-3.0]]
+      label3x1 = [[0.3], [0.1], [0.4]]
+      loss3x1 = tf.contrib.losses.logistic(
+          tf.constant(actual3x1), tf.constant(label3x1))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, loss3x1.eval())
+
+      # Batch average stays the same with repeats of the same examples.
+      loss9x1 = tf.contrib.losses.logistic(
+          tf.constant(actual3x1 * 3), tf.constant(label3x1 * 3))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, loss9x1.eval())
+
+      # Loss stays the same when adding another class with 0 loss.
+      loss3x2 = tf.contrib.losses.logistic(
+          tf.constant([[-1.0, 100.0], [3.0, -100.0], [-3.0, -100.0]]),
+          tf.constant([[0.3, 1.0], [0.1, 0.0], [0.4, 0.0]]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, loss3x2.eval())
+
+      # Loss stays the same with additional x1 dimension.
+      loss3x1x2 = tf.contrib.losses.logistic(
+          tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
+          tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[0.4, 0.0]]]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, loss3x1x2.eval())
+
+      # One target value (-0.4) is out of range. Expect no crash, since we
+      # did not set validate=True.
+      loss = tf.contrib.losses.logistic(
+          tf.constant([[[-1.0, 100.0]], [[3.0, -100.0]], [[-3.0, -100.0]]]),
+          tf.constant([[[0.3, 1.0]], [[0.1, 0.0]], [[-0.4, 0.0]]]))
+      tf.initialize_all_variables().run()
+      loss.eval()
+
+  def testLogisticVsSoftmax(self):
+    with self.test_session():
+      # Each logit = L and target = T used for logistic_loss corresponds to
+      # logits [a, b] where a - b = L and targets [T, 1 - T] for
+      # softmax_loss, since sigmoid(L) = exp(a) / (exp(a) + exp(b))
+      # whenever a - b = L.
+
+      expected_loss = (0.69314718 + 1.01326168 + 2.10692811) / 3.0
+
+      logistic_loss = tf.contrib.losses.logistic(
+          tf.constant([0.0, 1.0, 2.0]),
+          tf.constant([0.5, 0.3, 0.01]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, logistic_loss.eval())
+
+      softmax_loss = tf.contrib.losses.softmax(
+          tf.constant([[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]),
+          tf.constant([[0.5, 0.5], [0.3, 0.7], [0.01, 0.99]]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, softmax_loss.eval())
+
+
+class SoftmaxTest(tf.test.TestCase):
+
+  def testAllCorrect(self):
+    with self.test_session():
+      logits = tf.constant([[10.0, 0.0, 0.0],
+                            [0.0, 10.0, 0.0],
+                            [0.0, 0.0, 10.0]])
+      labels = tf.constant([[1.0, 0.0, 0.0],
+                            [0.0, 1.0, 0.0],
+                            [0.0, 0.0, 1.0]])
+      loss = tf.contrib.losses.softmax(logits, labels)
+      self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+  def testAllWrong(self):
+    with self.test_session():
+      logits = tf.constant([[10.0, 0.0, 0.0],
+                            [0.0, 10.0, 0.0],
+                            [0.0, 0.0, 10.0]])
+      labels = tf.constant([[0.0, 0.0, 1.0],
+                            [1.0, 0.0, 0.0],
+                            [0.0, 1.0, 0.0]])
+      loss = tf.contrib.losses.softmax(logits, labels)
+      self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+  def testSoftmax(self):
+    with self.test_session():
+      # [batch] and [batch,1] fail; softmax_loss is only for multiclass.
+      self.assertRaisesRegexp(
+          ValueError, "must have rank 2", tf.contrib.losses.softmax,
+          tf.constant([-100.0, 10.0, 0.0]),
+          tf.constant([1.0, 1.0, 1.0]))
+
+      self.assertRaisesRegexp(
+          ValueError, "only 1 class", tf.contrib.losses.softmax,
+          tf.constant([[-100.0], [10.0], [0.0]]),
+          tf.constant([[1.0], [1.0], [1.0]]))
+
+      expected_loss = 3.173363
+      loss3x2 = tf.contrib.losses.softmax(
+          tf.constant([[-1.0, 1.0], [0.0, 0.0], [10.0, -1.0]]),
+          tf.constant([[0.5, 0.5], [0.3, 0.7], [0.3, 0.7]]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, loss3x2.eval())
+
+      # Loss stays the same when adding another negative class.
+      loss3x3 = tf.contrib.losses.softmax(
+          tf.constant(
+              [[-1.0, 1.0, -100.0], [0.0, 0.0, -100.0], [10.0, -1.0, -100.0]]),
+          tf.constant([[0.5, 0.5, 0.0], [0.3, 0.7, 0.0], [0.3, 0.7, 0.0]]))
+      tf.initialize_all_variables().run()
+      self.assertAllClose(expected_loss, loss3x3.eval())
+
+      # Fails for rank > 2.
+      self.assertRaisesRegexp(
+          ValueError, "must have rank 2", tf.contrib.losses.softmax,
+          tf.constant([[[-1.0, 1.0]], [[0.0, 0.0]], [[10.0, -1.0]]]),
+          tf.constant([[[0.5, 0.5]], [[0.3, 0.7]], [[0.3, 0.7]]]))
+
+
+if __name__ == "__main__":
+  tf.test.main()
--
cgit v1.2.3