aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/array_ops.py5
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py160
-rw-r--r--tensorflow/python/ops/distributions/BUILD12
-rw-r--r--tensorflow/python/ops/init_ops.py9
-rw-r--r--tensorflow/python/ops/linalg/BUILD12
-rw-r--r--tensorflow/python/ops/losses/BUILD12
-rw-r--r--tensorflow/python/ops/nn_ops.py8
-rw-r--r--tensorflow/python/ops/nn_test.py51
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py5
-rw-r--r--tensorflow/python/ops/script_ops.py6
10 files changed, 237 insertions, 43 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9106461c60..207866610b 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -957,6 +957,11 @@ def _autopacking_helper(list_or_tuple, dtype, name):
Returns:
A `tf.Tensor` with value equivalent to `list_or_tuple`.
"""
+ if context.executing_eagerly():
+ # NOTE: Fast path when all the items are tensors, this doesn't do any type
+ # checking.
+ if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple):
+ return gen_array_ops.pack(list_or_tuple, name=name)
must_pack = False
converted_elems = []
with ops.name_scope(name) as scope:
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
new file mode 100644
index 0000000000..174d00987f
--- /dev/null
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -0,0 +1,160 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for boosted_trees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import resources
+
+# Re-exporting ops used by other modules.
+# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
+# pylint: enable=unused-import
+
+from tensorflow.python.training import saver
+
+
+class PruningMode(object):
+ NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)
+
+
+class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for TreeEnsemble."""
+
+ def __init__(self, resource_handle, create_op, name):
+ """Creates a _TreeEnsembleSavable object.
+
+ Args:
+ resource_handle: handle to the decision tree ensemble variable.
+ create_op: the op to initialize the variable.
+ name: the name to save the tree ensemble variable under.
+ """
+ stamp_token, serialized = (
+ gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
+ # slice_spec is useful for saving a slice from a variable.
+ # It's not meaningful the tree ensemble variable. So we just pass an empty
+ # value.
+ slice_spec = ''
+ specs = [
+ saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
+ name + '_stamp'),
+ saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
+ name + '_serialized'),
+ ]
+ super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
+ self._resource_handle = resource_handle
+ self._create_op = create_op
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ """Restores the associated tree ensemble from 'restored_tensors'.
+
+ Args:
+ restored_tensors: the tensors that were loaded from a checkpoint.
+ unused_restored_shapes: the shapes this object should conform to after
+ restore. Not meaningful for trees.
+
+ Returns:
+ The operation that restores the state of the tree ensemble variable.
+ """
+ with ops.control_dependencies([self._create_op]):
+ return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
+ self._resource_handle,
+ stamp_token=restored_tensors[0],
+ tree_ensemble_serialized=restored_tensors[1])
+
+
+class TreeEnsemble(object):
+ """Creates TreeEnsemble resource."""
+
+ def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
+ with ops.name_scope(name, 'TreeEnsemble') as name:
+ self._resource_handle = (
+ gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
+ container='', shared_name=name, name=name))
+ create_op = gen_boosted_trees_ops.boosted_trees_create_ensemble(
+ self.resource_handle,
+ stamp_token,
+ tree_ensemble_serialized=serialized_proto)
+ is_initialized_op = (
+ gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
+ self._resource_handle))
+ # Adds the variable to the savable list.
+ if not is_local:
+ saveable = _TreeEnsembleSavable(self.resource_handle, create_op,
+ self.resource_handle.name)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ resources.register_resource(
+ self.resource_handle,
+ create_op,
+ is_initialized_op,
+ is_shared=not is_local)
+
+ @property
+ def resource_handle(self):
+ return self._resource_handle
+
+ def get_stamp_token(self):
+ """Returns the current stamp token of the resource."""
+ stamp_token, _, _, _ = (
+ gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
+ self.resource_handle))
+ return stamp_token
+
+ def get_states(self):
+ """Returns states of the tree ensemble.
+
+ Returns:
+ stamp_token, num_trees, num_finalized_trees, num_attempted_layers.
+ """
+ stamp_token, num_trees, num_finalized_trees, num_attempted_layers = (
+ gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
+ self.resource_handle))
+ # Use identity to give names.
+ return (array_ops.identity(stamp_token, name='stamp_token'),
+ array_ops.identity(num_trees, name='num_trees'),
+ array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
+ array_ops.identity(
+ num_attempted_layers, name='num_attempted_layers'))
+
+ def serialize(self):
+ """Serializes the ensemble into proto and returns the serialized proto.
+
+ Returns:
+ stamp_token: int64 scalar Tensor to denote the stamp of the resource.
+ serialized_proto: string scalar Tensor of the serialized proto.
+ """
+ return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
+ self.resource_handle)
+
+ def deserialize(self, stamp_token, serialized_proto):
+ """Deserialize the input proto and resets the ensemble from it.
+
+ Args:
+ stamp_token: int64 scalar Tensor to denote the stamp of the resource.
+ serialized_proto: string scalar Tensor of the serialized proto.
+
+ Returns:
+ Operation (for dependencies).
+ """
+ return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
+ self.resource_handle, stamp_token, serialized_proto)
diff --git a/tensorflow/python/ops/distributions/BUILD b/tensorflow/python/ops/distributions/BUILD
index 50b956a267..9d9ede7ad7 100644
--- a/tensorflow/python/ops/distributions/BUILD
+++ b/tensorflow/python/ops/distributions/BUILD
@@ -26,15 +26,3 @@ py_library(
"@six_archive//:six",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 40ab22951b..9dfe5ffbf4 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -532,8 +532,7 @@ class Orthogonal(Initializer):
q, r = linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
- ph = d / math_ops.abs(d)
- q *= ph
+ q *= math_ops.sign(d)
if num_rows < num_cols:
q = array_ops.matrix_transpose(q)
return self.gain * array_ops.reshape(q, shape)
@@ -579,7 +578,11 @@ class ConvolutionDeltaOrthogonal(Initializer):
a = random_ops.random_normal([shape[-1], shape[-1]],
dtype=dtype, seed=self.seed)
# Compute the qr factorization
- q, _ = linalg_ops.qr(a, full_matrices=False)
+ q, r = linalg_ops.qr(a, full_matrices=False)
+ # Make Q uniform
+ d = array_ops.diag_part(r)
+ # ph = d / math_ops.abs(d)
+ q *= math_ops.sign(d)
q = q[:shape[-2], :]
q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
if len(shape) == 3:
diff --git a/tensorflow/python/ops/linalg/BUILD b/tensorflow/python/ops/linalg/BUILD
index ce8c1580fe..07659ef44c 100644
--- a/tensorflow/python/ops/linalg/BUILD
+++ b/tensorflow/python/ops/linalg/BUILD
@@ -34,15 +34,3 @@ py_library(
"//tensorflow/python:special_math_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD
index 07741e0c3c..4aea0265a7 100644
--- a/tensorflow/python/ops/losses/BUILD
+++ b/tensorflow/python/ops/losses/BUILD
@@ -43,15 +43,3 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a74de39eab..0c55386241 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1836,8 +1836,10 @@ def softmax_cross_entropy_with_logits_v2(
[logits, labels]) as name:
logits = ops.convert_to_tensor(logits, name="logits")
labels = ops.convert_to_tensor(labels, name="labels")
+ convert_to_float32 = (
+ logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16)
precise_logits = math_ops.cast(
- logits, dtypes.float32) if (logits.dtype == dtypes.float16) else logits
+ logits, dtypes.float32) if convert_to_float32 else logits
# labels and logits must be of the same type
labels = math_ops.cast(labels, precise_logits.dtype)
input_rank = array_ops.rank(precise_logits)
@@ -1883,8 +1885,8 @@ def softmax_cross_entropy_with_logits_v2(
del shape[dim]
cost.set_shape(shape)
- if logits.dtype == dtypes.float16:
- return math_ops.cast(cost, dtypes.float16)
+ if convert_to_float32:
+ return math_ops.cast(cost, logits.dtype)
else:
return cost
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index af9dae2aa6..da86d5f6ca 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -852,6 +852,57 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
self.assertAllClose(exp_sampled_softmax_loss,
got_sampled_softmax_loss.eval(), 1e-4)
+ def testSampledSoftmaxLossBf16(self):
+ # A simple test to verify the numerics for bfloat16.
+ def _SoftmaxCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ stable_exp_logits = np.exp(
+ logits - np.amax(logits, axis=1, keepdims=True))
+ pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+ return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+ labels = [0, 1, 2]
+ sampled = [1, 0, 2, 3]
+ (weights, biases, hidden_acts, _, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=1,
+ labels=labels,
+ sampled=sampled,
+ subtract_log_q=True)
+ exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
+ exp_logits, exp_labels)
+
+ with self.test_session():
+ true_exp_bf16 = np.full(
+ [batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_exp_bf16 = np.full(
+ [len(sampled)], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16)
+
+ got_sampled_softmax_loss = math_ops.cast(
+ nn_impl.sampled_softmax_loss(
+ weights=constant_op.constant(weights, dtype=dtypes.bfloat16),
+ biases=constant_op.constant(biases, dtype=dtypes.bfloat16),
+ labels=constant_op.constant(
+ labels, shape=(batch_size, 1), dtype=dtypes.bfloat16),
+ inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals_bf16,
+ remove_accidental_hits=False,
+ partition_strategy="div"), dtypes.float32)
+
+ self.assertAllClose(exp_sampled_softmax_loss,
+ got_sampled_softmax_loss.eval(), 1e-1)
+
class CReluTest(test_lib.TestCase):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index df873da98e..2f39ea2e7d 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1087,6 +1087,11 @@ ops.register_proto_function(
proto_type=variable_pb2.VariableDef,
to_proto=_to_proto_fn,
from_proto=_from_proto_fn)
+ops.register_proto_function(
+ ops.GraphKeys.GLOBAL_STEP,
+ proto_type=variable_pb2.VariableDef,
+ to_proto=_to_proto_fn,
+ from_proto=_from_proto_fn)
def is_resource_variable(var):
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 1b4111bca6..96fb024715 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -334,7 +334,11 @@ def py_func(func, inp, Tout, stateful=True, name=None):
result = func(*[x.numpy() for x in inp])
result = nest.flatten(result)
- return [x if x is None else ops.convert_to_tensor(x) for x in result]
+ result = [x if x is None else ops.convert_to_tensor(x) for x in result]
+ if len(result) == 1:
+ # Mimic the automatic unwrapping in graph-mode py_func
+ result, = result
+ return result
return _internal_py_func(
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)