aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-08-21 14:48:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 15:00:13 -0700
commit0f02f05913e03889bbcb85e71a6d005a8519bfb9 (patch)
treec5b2bacb1b96d260b67cb56c208ce8f1b1025dae /tensorflow/contrib
parent3f24f93c2a32b2eae8951e5b272c3b647c5b9611 (diff)
Merged commit includes the following changes:
209663919 by yifeif<yifeif@google.com>: Internal change. -- 209663914 by amitpatankar<amitpatankar@google.com>: Fix the topk_op_test for numpy>1.15. -- 209660476 by jdduke<jdduke@google.com>: Fix model lifetime for TensorFlow Lite C# bindings Ensure the model's existence for the duration of the interpreter, as per API requirements. -- 209655960 by scottzhu<scottzhu@google.com>: Unify RNN Cell interface between TF and Keras. -- 209655731 by A. Unique TensorFlower<gardener@tensorflow.org>: Added tests for PredictionOps and PartitionExamplesOps -- 209655291 by nolivia<nolivia@google.com>: adding rate class so that we can save global_step/sec using tf.contrib.summary. The function takes the rate in relation to any tensors provided that the numerator and denominator are broadcastable and have dtypes that can be cast to float64 -- 209654655 by kramerb<kramerb@google.com>: [XLA] Switch from tensorflow::gtl::InlinedVector to absl::InlinedVector This one comes with extra goodies like a move constructor. -- 209653851 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal build specification change -- PiperOrigin-RevId: 209663919
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py185
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs31
-rw-r--r--tensorflow/contrib/rate/BUILD48
-rw-r--r--tensorflow/contrib/rate/rate.py151
-rw-r--r--tensorflow/contrib/rate/rate_test.py97
5 files changed, 466 insertions, 46 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index cf55759aaa..bef42fdf7f 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -96,6 +96,20 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None):
split.dimension_id = feature_dim_id
+def _set_float_oblivious_split(split, feat_col, thresh):
+ """Helper method for building tree float splits.
+
+ Sets split feature column and threshold.
+
+ Args:
+ split: split node to update.
+ feat_col: feature column for the split.
+ thresh: threshold to split on forming rule x <= thresh.
+ """
+ split.feature_column = feat_col
+ split.threshold = thresh
+
+
def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id):
"""Helper method for building tree categorical id splits.
@@ -119,15 +133,17 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
"""Sets up the prediction tests.
- Create a batch of two examples having one dense float, two sparse float
+ Creates, a batch of two examples having three dense float, two sparse float
single valued, one sparse float multidimensional and one sparse int
features. The data looks like the following:
- | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM
- | 0 | 7 | -3 | | 9,1 | __, 5.0
- | 1 | -2 | | 4 | | 3, ___
+ |Instance |Dense0 |Dense1 |Dense2 |SparseF0 |SparseF1 |SparseI0 |SparseM
+ | 0 | 7 | 1 | 2 | -3 | | 9,1 | __, 5.0
+ | 1 | -2 | 2 | 0.5 | | 4 | | 3, ___
"""
super(PredictionOpsTest, self).setUp()
- self._dense_float_tensor = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor1 = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor2 = np.array([[1.0], [2.0]])
+ self._dense_float_tensor3 = np.array([[2.0], [0.5]])
self._sparse_float_indices1 = np.array([[0, 0]])
self._sparse_float_values1 = np.array([-3.0])
self._sparse_float_shape1 = np.array([2, 1])
@@ -153,7 +169,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
reduce_dim=False):
return prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- self._seed, [self._dense_float_tensor],
+ self._seed, [self._dense_float_tensor1],
[self._sparse_float_indices1, self._sparse_float_indices2],
[self._sparse_float_values1, self._sparse_float_values2],
[self._sparse_float_shape1, self._sparse_float_shape2],
@@ -165,6 +181,25 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
center_bias=center_bias,
reduce_dim=reduce_dim)
+ def _get_predictions_oblivious_case(self,
+ tree_ensemble_handle,
+ learner_config,
+ apply_dropout=False,
+ apply_averaging=False,
+ center_bias=False,
+ reduce_dim=False):
+ return prediction_ops.gradient_trees_prediction(
+ tree_ensemble_handle,
+ self._seed, [
+ self._dense_float_tensor1, self._dense_float_tensor2,
+ self._dense_float_tensor3
+ ], [], [], [], [], [], [],
+ learner_config=learner_config,
+ apply_dropout=apply_dropout,
+ apply_averaging=apply_averaging,
+ center_bias=center_bias,
+ reduce_dim=reduce_dim)
+
def testEmptyEnsemble(self):
with self.test_session():
# Empty tree ensenble.
@@ -295,6 +330,53 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Empty dropout.
self.assertAllEqual([[], []], dropout_info.eval())
+ def testObliviousEnsemble(self):
+ with self.test_session():
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ # Bias tree.
+ tree1 = tree_ensemble_config.trees.add()
+ tree_ensemble_config.tree_metadata.add().is_finalized = True
+ _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4)
+
+ # Depth 3 tree.
+ tree2 = tree_ensemble_config.trees.add()
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 0, 5.0)
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 1, 3.0)
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 2, 1.0)
+ for i in range(1, 9):
+ _append_to_leaf(tree2.nodes.add().leaf, 0, i / 10.0)
+
+ tree_ensemble_config.tree_weights.append(1.0)
+ tree_ensemble_config.tree_weights.append(1.0)
+
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="full_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+
+ result, dropout_info = self._get_predictions_oblivious_case(
+ tree_ensemble_handle,
+ learner_config=learner_config.SerializeToString(),
+ reduce_dim=True)
+
+ # The first example will get bias -0.4 from first tree and 0.6 from
+ # the 5th leaf of the second tree corresponding to node_id = 8, hence a
+ # prediction of 0.2.
+ # The second example will get bias -0.4 and 0.1 from the 0th leaf of the
+ # second tree corresponding to node_id = 3, hence a prediction of -0.3
+ self.assertAllClose([[0.2], [-0.3]], result.eval())
+
+ # Empty dropout.
+ self.assertAllEqual([[], []], dropout_info.eval())
+
def testFullEnsembleWithMultidimensionalSparseSingleClass(self):
with self.test_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
@@ -358,7 +440,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
result, dropout_info = prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- self._seed, [self._dense_float_tensor], [
+ self._seed, [self._dense_float_tensor1], [
self._sparse_float_indices1, self._sparse_float_indices2,
self._sparse_float_indices_m
], [
@@ -917,7 +999,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Different seed.
_, dropout_info_3 = prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- 112314, [self._dense_float_tensor],
+ 112314, [self._dense_float_tensor1],
[self._sparse_float_indices1, self._sparse_float_indices2],
[self._sparse_float_values1, self._sparse_float_values2],
[self._sparse_float_shape1, self._sparse_float_shape2],
@@ -1204,15 +1286,18 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
"""Sets up the prediction tests.
- Create a batch of two examples having one dense float, two sparse float and
- one sparse int features.
+ Create a batch of two examples having three dense float, two sparse float
+ and one sparse int features.
The data looks like the following:
- | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 |
- | 0 | 7 | -3 | | 9,1 |
- | 1 | -2 | | 4 | |
+ |Instance |Dense0 |Dense1 |Dense2 |SparseF0 |SparseF1 |SparseI0 |
+ | 0 | 7 | 1 | 2 | -3 | | 9,1 |
+ | 1 | -2 | 2 | 0.5 | | 4 | |
+
"""
super(PartitionExamplesOpsTest, self).setUp()
- self._dense_float_tensor = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor1 = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor2 = np.array([[1.0], [2.0]])
+ self._dense_float_tensor3 = np.array([[2.0], [0.5]])
self._sparse_float_indices1 = np.array([[0, 0]])
self._sparse_float_values1 = np.array([-3.0])
self._sparse_float_shape1 = np.array([2, 1])
@@ -1234,12 +1319,12 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([0, 0], result.eval())
@@ -1269,12 +1354,12 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([5, 3], result.eval())
@@ -1304,15 +1389,51 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([0, 0], result.eval())
+ def testObliviousTreeNonFinalized(self):
+ with self.test_session():
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ # Depth 3 tree.
+ tree1 = tree_ensemble_config.trees.add()
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 0, 5.0)
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 1, 3.0)
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 2, 1.0)
+ for i in range(1, 9):
+ _append_to_leaf(tree1.nodes.add().leaf, 0, i / 10.0)
+ tree_ensemble_config.tree_weights.append(1.0)
+ tree_ensemble_config.tree_metadata.add().is_finalized = False
+
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="full_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ result = prediction_ops.gradient_trees_partition_examples(
+ tree_ensemble_handle, [
+ self._dense_float_tensor1,
+ self._dense_float_tensor2,
+ self._dense_float_tensor3
+ ], [], [], [], [], [], [])
+
+ # The first example goes right, left, right in the tree and the second
+ # example goes lef, left, left. Since the depth of the tree is 3, the
+ # partition id's are as follows:
+ # First example: 3 + 5 = 8
+ # Second exampel: 3 + 0 = 3
+ self.assertAllEqual([8, 3], result.eval())
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
index b6905b5fbf..676783063d 100644
--- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
@@ -29,15 +29,16 @@ namespace TensorFlowLite
{
private const string TensorFlowLibrary = "tensorflowlite_c";
- private TFL_Interpreter handle;
+ private TFL_Model model;
+ private TFL_Interpreter interpreter;
public Interpreter(byte[] modelData) {
GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
- TFL_Model model = TFL_NewModel(modelDataPtr, modelData.Length);
- handle = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero);
- TFL_DeleteModel(model);
- if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
+ model = TFL_NewModel(modelDataPtr, modelData.Length);
+ if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");
+ interpreter = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero);
+ if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
}
~Interpreter() {
@@ -45,43 +46,45 @@ namespace TensorFlowLite
}
public void Dispose() {
- if (handle != IntPtr.Zero) TFL_DeleteInterpreter(handle);
- handle = IntPtr.Zero;
+ if (interpreter != IntPtr.Zero) TFL_DeleteInterpreter(interpreter);
+ interpreter = IntPtr.Zero;
+ if (model != IntPtr.Zero) TFL_DeleteModel(model);
+ model = IntPtr.Zero;
}
public void Invoke() {
- ThrowIfError(TFL_InterpreterInvoke(handle));
+ ThrowIfError(TFL_InterpreterInvoke(interpreter));
}
public int GetInputTensorCount() {
- return TFL_InterpreterGetInputTensorCount(handle);
+ return TFL_InterpreterGetInputTensorCount(interpreter);
}
public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
- TFL_Tensor tensor = TFL_InterpreterGetInputTensor(handle, inputTensorIndex);
+ TFL_Tensor tensor = TFL_InterpreterGetInputTensor(interpreter, inputTensorIndex);
ThrowIfError(TFL_TensorCopyFromBuffer(
tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData)));
}
public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
ThrowIfError(TFL_InterpreterResizeInputTensor(
- handle, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
+ interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
}
public void AllocateTensors() {
- ThrowIfError(TFL_InterpreterAllocateTensors(handle));
+ ThrowIfError(TFL_InterpreterAllocateTensors(interpreter));
}
public int GetOutputTensorCount() {
- return TFL_InterpreterGetOutputTensorCount(handle);
+ return TFL_InterpreterGetOutputTensorCount(interpreter);
}
public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
- TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(handle, outputTensorIndex);
+ TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(interpreter, outputTensorIndex);
ThrowIfError(TFL_TensorCopyToBuffer(
tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
}
diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD
new file mode 100644
index 0000000000..c461a7145e
--- /dev/null
+++ b/tensorflow/contrib/rate/BUILD
@@ -0,0 +1,48 @@
+# Description:
+# contains parts of TensorFlow that are experimental or unstable and which are not supported.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "rate",
+ srcs = [
+ "rate.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "rate_test",
+ size = "small",
+ srcs = ["rate_test.py"],
+ deps = [
+ ":rate",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:test",
+ ],
+)
diff --git a/tensorflow/contrib/rate/rate.py b/tensorflow/contrib/rate/rate.py
new file mode 100644
index 0000000000..24d586479a
--- /dev/null
+++ b/tensorflow/contrib/rate/rate.py
@@ -0,0 +1,151 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of tf.contrib.rate module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+
+_to_replace = re.compile("[^A-Za-z0-9.]")
+
+
+class Rate(object):
+ """Computes the rate of change since the last rate call."""
+
+ def __init__(self, name=None):
+ self._built = False
+ self._vars = []
+ self._initial_values = {}
+ name = name or self.__class__.__name__
+ # Replace things like spaces in name to create a valid scope name.
+ scope_name = _to_replace.sub("_", name)
+ # We create the variable scope now to get the unique name that will
+ # be used as a variable prefix when build() calls _add_variable().
+ with variable_scope.variable_scope(
+ scope_name, use_resource=True, reuse=False) as scope:
+ pos = scope.name.rfind(scope_name)
+ self._name = name + scope.name[pos + len(scope_name):]
+ self._scope = scope
+
+ # Ensures that if the user calls build directly we still set self._built to
+ # True to prevent variables from being recreated.
+ self._build = self.build
+ if context.executing_eagerly():
+ self._construction_scope = context.eager_mode
+ else:
+ # We make self.call() into a graph callable here, so that we can
+ # return a single op that performs all of the variable updates.
+ self._construction_scope = ops.get_default_graph().as_default
+ self.call = function.defun(self.call)
+
+ def build(self, values, denominator):
+ """Method to create variables.
+
+ Called by `__call__()` before `call()` for the first time.
+
+ Args:
+ values: The numerator for rate.
+ denominator: Value to which the rate is taken with respect.
+ """
+ self.numer = self._add_variable(
+ name="numer", shape=values.get_shape(), dtype=dtypes.float64)
+ self.denom = self._add_variable(
+ name="denom", shape=denominator.get_shape(), dtype=dtypes.float64)
+ self.prev_values = self._add_variable(
+ name="prev_values", shape=values.get_shape(), dtype=dtypes.float64)
+ self.prev_denominator = self._add_variable(
+ name="prev_denominator",
+ shape=denominator.get_shape(),
+ dtype=dtypes.float64)
+ self._built = True
+
+ def __call__(self, *args, **kwargs):
+ """Returns op to execute to update.
+
+ Returns None if eager execution is enabled.
+ Returns a graph-mode function if graph execution is enabled.
+
+ Args:
+ *args:
+ **kwargs: A mini-batch of inputs to Rate, passed on to `call()`.
+ """
+ if not self._built:
+ with variable_scope.variable_scope(
+ self._scope), self._construction_scope():
+ self.build(*args, **kwargs)
+ self._built = True
+ return self.call(*args, **kwargs)
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def variables(self):
+ return self._vars
+
+ def _safe_div(self, numerator, denominator, name):
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero, name=name)
+
+ def _add_variable(self, name, shape=None, dtype=None):
+ """Private method for adding variables to the graph."""
+ if self._built:
+ raise RuntimeError("Can't call add_variable() except in build().")
+ v = resource_variable_ops.ResourceVariable(
+ lambda: array_ops.zeros(shape, dtype),
+ trainable=False,
+ validate_shape=True,
+ name=name,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ return v
+
+ def call(self, values, denominator):
+ """Computes the rate since the last call.
+
+ Args:
+ values: Tensor with the per-example value.
+ denominator: Measure to take the rate with respect to.
+
+ Returns:
+ The rate or 0 if denominator is unchanged since last call.
+ """
+ if denominator.dtype != dtypes.float64:
+ denominator = math_ops.cast(denominator, dtypes.float64)
+ if values.dtype != dtypes.float64:
+ values = math_ops.cast(values, dtypes.float64)
+
+ state_ops.assign(self.numer, math_ops.subtract(values, self.prev_values))
+ state_ops.assign(self.denom,
+ math_ops.subtract(denominator, self.prev_denominator))
+ state_ops.assign(self.prev_values, values)
+ state_ops.assign(self.prev_denominator, denominator)
+
+ return self._safe_div(self.numer, self.denom, name="safe_rate")
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
new file mode 100644
index 0000000000..08908104f4
--- /dev/null
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rate import rate
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RateTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBuildRate(self):
+ m = rate.Rate()
+ m.build(
+ constant_op.constant([1], dtype=dtypes.float32),
+ constant_op.constant([2], dtype=dtypes.float32))
+ old_numer = m.numer
+ m(
+ constant_op.constant([2], dtype=dtypes.float32),
+ constant_op.constant([2], dtype=dtypes.float32))
+ self.assertTrue(old_numer is m.numer)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBasic(self):
+ with self.test_session():
+ r_ = rate.Rate()
+ a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.local_variables_initializer())
+ self.assertEqual([[1]], self.evaluate(a))
+ b = r_(constant_op.constant([2]), denominator=constant_op.constant([2]))
+ self.assertEqual([[1]], self.evaluate(b))
+ c = r_(constant_op.constant([4]), denominator=constant_op.constant([3]))
+ self.assertEqual([[2]], self.evaluate(c))
+ d = r_(constant_op.constant([16]), denominator=constant_op.constant([3]))
+ self.assertEqual([[0]], self.evaluate(d)) # divide by 0
+
+ def testNamesWithSpaces(self):
+ m1 = rate.Rate(name="has space")
+ m1(array_ops.ones([1]), array_ops.ones([1]))
+ self.assertEqual(m1.name, "has space")
+ self.assertEqual(m1.prev_values.name, "has_space_1/prev_values:0")
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testWhileLoop(self):
+ with self.test_session():
+ r_ = rate.Rate()
+
+ def body(value, denom, i, ret_rate):
+ i += 1
+ ret_rate = r_(value, denom)
+ with ops.control_dependencies([ret_rate]):
+ value = math_ops.add(value, 2)
+ denom = math_ops.add(denom, 1)
+ return [value, denom, i, ret_rate]
+
+ def condition(v, d, i, r):
+ del v, d, r # unused vars by condition
+ return math_ops.less(i, 100)
+
+ i = constant_op.constant(0)
+ value = constant_op.constant([1], dtype=dtypes.float64)
+ denom = constant_op.constant([1], dtype=dtypes.float64)
+ ret_rate = r_(value, denom)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.local_variables_initializer())
+ loop = control_flow_ops.while_loop(condition, body,
+ [value, denom, i, ret_rate])
+ self.assertEqual([[2]], self.evaluate(loop[3]))
+
+
+if __name__ == "__main__":
+ test.main()