diff options
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/boosted_trees_ops.py | 160 | ||||
-rw-r--r-- | tensorflow/python/ops/distributions/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/init_ops.py | 9 | ||||
-rw-r--r-- | tensorflow/python/ops/linalg/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/losses/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_test.py | 51 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/script_ops.py | 6 |
10 files changed, 237 insertions, 43 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9106461c60..207866610b 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -957,6 +957,11 @@ def _autopacking_helper(list_or_tuple, dtype, name): Returns: A `tf.Tensor` with value equivalent to `list_or_tuple`. """ + if context.executing_eagerly(): + # NOTE: Fast path when all the items are tensors, this doesn't do any type + # checking. + if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple): + return gen_array_ops.pack(list_or_tuple, name=name) must_pack = False converted_elems = [] with ops.name_scope(name) as scope: diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py new file mode 100644 index 0000000000..174d00987f --- /dev/null +++ b/tensorflow/python/ops/boosted_trees_ops.py @@ -0,0 +1,160 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for boosted_trees.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_boosted_trees_ops +from tensorflow.python.ops import resources + +# Re-exporting ops used by other modules. +# pylint: disable=unused-import +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble +# pylint: enable=unused-import + +from tensorflow.python.training import saver + + +class PruningMode(object): + NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3) + + +class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for TreeEnsemble.""" + + def __init__(self, resource_handle, create_op, name): + """Creates a _TreeEnsembleSavable object. + + Args: + resource_handle: handle to the decision tree ensemble variable. + create_op: the op to initialize the variable. + name: the name to save the tree ensemble variable under. + """ + stamp_token, serialized = ( + gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle)) + # slice_spec is useful for saving a slice from a variable. + # It's not meaningful the tree ensemble variable. So we just pass an empty + # value. + slice_spec = '' + specs = [ + saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, + name + '_stamp'), + saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec, + name + '_serialized'), + ] + super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name) + self._resource_handle = resource_handle + self._create_op = create_op + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree ensemble from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree ensemble variable. + """ + with ops.control_dependencies([self._create_op]): + return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble( + self._resource_handle, + stamp_token=restored_tensors[0], + tree_ensemble_serialized=restored_tensors[1]) + + +class TreeEnsemble(object): + """Creates TreeEnsemble resource.""" + + def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''): + with ops.name_scope(name, 'TreeEnsemble') as name: + self._resource_handle = ( + gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op( + container='', shared_name=name, name=name)) + create_op = gen_boosted_trees_ops.boosted_trees_create_ensemble( + self.resource_handle, + stamp_token, + tree_ensemble_serialized=serialized_proto) + is_initialized_op = ( + gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized( + self._resource_handle)) + # Adds the variable to the savable list. + if not is_local: + saveable = _TreeEnsembleSavable(self.resource_handle, create_op, + self.resource_handle.name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + resources.register_resource( + self.resource_handle, + create_op, + is_initialized_op, + is_shared=not is_local) + + @property + def resource_handle(self): + return self._resource_handle + + def get_stamp_token(self): + """Returns the current stamp token of the resource.""" + stamp_token, _, _, _ = ( + gen_boosted_trees_ops.boosted_trees_get_ensemble_states( + self.resource_handle)) + return stamp_token + + def get_states(self): + """Returns states of the tree ensemble. + + Returns: + stamp_token, num_trees, num_finalized_trees, num_attempted_layers. + """ + stamp_token, num_trees, num_finalized_trees, num_attempted_layers = ( + gen_boosted_trees_ops.boosted_trees_get_ensemble_states( + self.resource_handle)) + # Use identity to give names. + return (array_ops.identity(stamp_token, name='stamp_token'), + array_ops.identity(num_trees, name='num_trees'), + array_ops.identity(num_finalized_trees, name='num_finalized_trees'), + array_ops.identity( + num_attempted_layers, name='num_attempted_layers')) + + def serialize(self): + """Serializes the ensemble into proto and returns the serialized proto. + + Returns: + stamp_token: int64 scalar Tensor to denote the stamp of the resource. + serialized_proto: string scalar Tensor of the serialized proto. + """ + return gen_boosted_trees_ops.boosted_trees_serialize_ensemble( + self.resource_handle) + + def deserialize(self, stamp_token, serialized_proto): + """Deserialize the input proto and resets the ensemble from it. + + Args: + stamp_token: int64 scalar Tensor to denote the stamp of the resource. + serialized_proto: string scalar Tensor of the serialized proto. + + Returns: + Operation (for dependencies). + """ + return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble( + self.resource_handle, stamp_token, serialized_proto) diff --git a/tensorflow/python/ops/distributions/BUILD b/tensorflow/python/ops/distributions/BUILD index 50b956a267..9d9ede7ad7 100644 --- a/tensorflow/python/ops/distributions/BUILD +++ b/tensorflow/python/ops/distributions/BUILD @@ -26,15 +26,3 @@ py_library( "@six_archive//:six", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 40ab22951b..9dfe5ffbf4 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -532,8 +532,7 @@ class Orthogonal(Initializer): q, r = linalg_ops.qr(a, full_matrices=False) # Make Q uniform d = array_ops.diag_part(r) - ph = d / math_ops.abs(d) - q *= ph + q *= math_ops.sign(d) if num_rows < num_cols: q = array_ops.matrix_transpose(q) return self.gain * array_ops.reshape(q, shape) @@ -579,7 +578,11 @@ class ConvolutionDeltaOrthogonal(Initializer): a = random_ops.random_normal([shape[-1], shape[-1]], dtype=dtype, seed=self.seed) # Compute the qr factorization - q, _ = linalg_ops.qr(a, full_matrices=False) + q, r = linalg_ops.qr(a, full_matrices=False) + # Make Q uniform + d = array_ops.diag_part(r) + # ph = d / math_ops.abs(d) + q *= math_ops.sign(d) q = q[:shape[-2], :] q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype)) if len(shape) == 3: diff --git a/tensorflow/python/ops/linalg/BUILD b/tensorflow/python/ops/linalg/BUILD index ce8c1580fe..07659ef44c 100644 --- a/tensorflow/python/ops/linalg/BUILD +++ b/tensorflow/python/ops/linalg/BUILD @@ -34,15 +34,3 @@ py_library( "//tensorflow/python:special_math_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD index 07741e0c3c..4aea0265a7 100644 --- a/tensorflow/python/ops/losses/BUILD +++ b/tensorflow/python/ops/losses/BUILD @@ -43,15 +43,3 @@ py_test( "//tensorflow/python:framework_for_generated_wrappers", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index a74de39eab..0c55386241 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1836,8 +1836,10 @@ def softmax_cross_entropy_with_logits_v2( [logits, labels]) as name: logits = ops.convert_to_tensor(logits, name="logits") labels = ops.convert_to_tensor(labels, name="labels") + convert_to_float32 = ( + logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16) precise_logits = math_ops.cast( - logits, dtypes.float32) if (logits.dtype == dtypes.float16) else logits + logits, dtypes.float32) if convert_to_float32 else logits # labels and logits must be of the same type labels = math_ops.cast(labels, precise_logits.dtype) input_rank = array_ops.rank(precise_logits) @@ -1883,8 +1885,8 @@ def softmax_cross_entropy_with_logits_v2( del shape[dim] cost.set_shape(shape) - if logits.dtype == dtypes.float16: - return math_ops.cast(cost, dtypes.float16) + if convert_to_float32: + return math_ops.cast(cost, logits.dtype) else: return cost diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index af9dae2aa6..da86d5f6ca 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -852,6 +852,57 @@ class ComputeSampledLogitsTest(test_lib.TestCase): self.assertAllClose(exp_sampled_softmax_loss, got_sampled_softmax_loss.eval(), 1e-4) + def testSampledSoftmaxLossBf16(self): + # A simple test to verify the numerics for bfloat16. + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + sampled = [1, 0, 2, 3] + (weights, biases, hidden_acts, _, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=sampled, + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + with self.test_session(): + true_exp_bf16 = np.full( + [batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_exp_bf16 = np.full( + [len(sampled)], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) + + got_sampled_softmax_loss = math_ops.cast( + nn_impl.sampled_softmax_loss( + weights=constant_op.constant(weights, dtype=dtypes.bfloat16), + biases=constant_op.constant(biases, dtype=dtypes.bfloat16), + labels=constant_op.constant( + labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), + inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals_bf16, + remove_accidental_hits=False, + partition_strategy="div"), dtypes.float32) + + self.assertAllClose(exp_sampled_softmax_loss, + got_sampled_softmax_loss.eval(), 1e-1) + class CReluTest(test_lib.TestCase): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index df873da98e..2f39ea2e7d 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -1087,6 +1087,11 @@ ops.register_proto_function( proto_type=variable_pb2.VariableDef, to_proto=_to_proto_fn, from_proto=_from_proto_fn) +ops.register_proto_function( + ops.GraphKeys.GLOBAL_STEP, + proto_type=variable_pb2.VariableDef, + to_proto=_to_proto_fn, + from_proto=_from_proto_fn) def is_resource_variable(var): diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 1b4111bca6..96fb024715 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -334,7 +334,11 @@ def py_func(func, inp, Tout, stateful=True, name=None): result = func(*[x.numpy() for x in inp]) result = nest.flatten(result) - return [x if x is None else ops.convert_to_tensor(x) for x in result] + result = [x if x is None else ops.convert_to_tensor(x) for x in result] + if len(result) == 1: + # Mimic the automatic unwrapping in graph-mode py_func + result, = result + return result return _internal_py_func( func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name) |