-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc | 203
-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc | 42
-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc | 95
-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc | 110
-rw-r--r--  tensorflow/core/BUILD | 1
-rw-r--r--  tensorflow/core/framework/load_library.cc | 26
-rw-r--r--  tensorflow/core/framework/op.cc | 67
-rw-r--r--  tensorflow/core/framework/op.h | 47
-rw-r--r--  tensorflow/core/framework/op_registration_test.cc | 57
-rw-r--r--  tensorflow/core/ops/array_ops.cc | 10
-rw-r--r--  tensorflow/python/framework/load_library.py | 6
-rw-r--r--  tensorflow/user_ops/BUILD | 12
-rw-r--r--  tensorflow/user_ops/duplicate_op.cc | 26
-rw-r--r--  tensorflow/user_ops/duplicate_op_test.py | 39
14 files changed, 470 insertions(+), 271 deletions(-)
diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
index 913a7d3c01..0ccf75bcc6 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
@@ -120,109 +120,108 @@ void Evaluate(const Tensor& input_data, const Tensor& input_labels,
}
}
-
REGISTER_OP("CountExtremelyRandomStats")
- .Attr("num_classes: int32")
- .Attr("regression: bool = false")
- .Input("input_data: float")
- .Input("input_labels: float")
-
- .Input("tree: int32")
- .Input("tree_thresholds: float")
-
- .Input("node_to_accumulator: int32")
-
- .Input("candidate_split_features: int32")
- .Input("candidate_split_thresholds: float")
-
- .Output("pcw_node_sums_delta: float")
- .Output("pcw_node_squares_delta: float")
- .Output("pcw_splits_indices: int32")
- .Output("pcw_candidate_splits_sums_delta: float")
- .Output("pcw_candidate_splits_squares_delta: float")
- .Output("pcw_totals_indices: int32")
- .Output("pcw_totals_sums_delta: float")
- .Output("pcw_totals_squares_delta: float")
-
- .Output("leaves: int32")
- .Doc(R"doc(
- Calculates incremental statistics for a batch of training data.
-
- Each training example in `input_data` is sent through the decision tree
- represented by `tree` and `tree_thresholds`.
- The shape and contents of the outputs differ depending on whether
- `regression` is true or not.
-
- For `regression` = false (classification), `pcw_node_sums_delta[i]` is
- incremented for every node i that it passes through, and the leaf it ends up
- in is recorded in `leaves[i]`. Then, if the leaf is fertile and
- initialized, the statistics for its corresponding accumulator slot
- are updated in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`.
-
- For `regression` = true, outputs contain the sum of the input_labels
- for the appropriate nodes. In adddition, the *_squares outputs are filled
- in with the sums of the squares of the input_labels. Since outputs are
- all updated at once, the *_indicies outputs don't specify the output
- dimension to update, rather the *_delta output contains updates for all the
- outputs. For example, `pcw_totals_indices` specifies the accumulators to
- update, and `pcw_total_splits_sums_delta` contains the complete output
- updates for each of those accumulators.
-
- The attr `num_classes` is needed to appropriately size the outputs.
-
- input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
- gives the j-th feature of the i-th input.
- input_labels: The training batch's labels; `input_labels[i]` is the class
- of the i-th input.
- tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
- of the i-th node, `tree[i][0] + 1` gives the index of the right child of
- the i-th node, and `tree[i][1]` gives the index of the feature used to
- split the i-th node.
- tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
- node.
- node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
- is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1.
- candidate_split_features: `candidate_split_features[a][s]` is the
- index of the feature being considered by split s of accumulator slot a.
- candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
- threshold value being considered by split s of accumulator slot a.
- pcw_node_sums_delta: `pcw_node_sums_delta[i][c]` is the number of training
- examples in this training batch with class c that passed through node i for
- classification. For regression, it is the sum of the input_labels that
- have passed through node i.
- pcw_node_squares_delta: `pcw_node_squares_delta[i][c]` is the sum of the
- squares of the input labels that have passed through node i for
- regression. Not set for classification.
- pcw_splits_indices:= A 2-d tensor of shape (?, 3) for classification and
- (?, 2) for regression.
- `pcw_splits_indices[i]` gives the coordinates of an entry in
- candidate_split_pcw_sums and candidate_split_pcw_squares that need to be
- updated. This is meant to be passed with `pcw_candidate_splits_*_delta` to
- a scatter_add for candidate_split_pcw_*:
- training_ops.scatter_add_ndim(candidate_split_pcw_sums
- pcw_splits_indices, pcw_candidate_splits_sums_delta)
- pcw_candidate_splits_sums_delta: For classification,
- `pcw_candidate_splits_sums_delta[i]` is the
- number of training examples in this training batch that correspond to
- the i-th entry in `pcw_splits_indices` which took the *left* branch of
- candidate split. For regression, it is the same but a 2-D tensor that has
- the sum of the input_labels for each i-th entry in the indices.
- pcw_candidate_splits_squares_delta: For regression, same as
- `pcw_candidate_splits_sums_delta` but the sum of the squares. Not set
- for classification.
- pcw_totals_indices: For classification, 'pcw_totals_indices` contains the
- indices (accumulator, class) into total_pcw_sums to update with
- pcw_totals_sums_delta. For regression, it only contains the accumulator
- (not the class), because pcw_totals_*_delta will contain all the outputs.
- pcw_totals_sums_delta: For classification, `pcw_totals_sums_delta[i]` is the
- number of training examples in this batch that ended up in the fertile
- node with accumulator and class indicated by `pcw_totals_indices[i]`.
- For regression, it is the sum of the input_labels corresponding to the
- entries in `pcw_totals_indices[i]`.
- pcw_totals_squares_delta: For regression, same as
- `pcw_totals_sums_delta` but the sum of the squares. Not set
- for classification.
- leaves: `leaves[i]` is the leaf that input i ended up in.
+ .Attr("num_classes: int")
+ .Attr("regression: bool = false")
+ .Input("input_data: float")
+ .Input("input_labels: float")
+
+ .Input("tree: int32")
+ .Input("tree_thresholds: float")
+
+ .Input("node_to_accumulator: int32")
+
+ .Input("candidate_split_features: int32")
+ .Input("candidate_split_thresholds: float")
+
+ .Output("pcw_node_sums_delta: float")
+ .Output("pcw_node_squares_delta: float")
+ .Output("pcw_splits_indices: int32")
+ .Output("pcw_candidate_splits_sums_delta: float")
+ .Output("pcw_candidate_splits_squares_delta: float")
+ .Output("pcw_totals_indices: int32")
+ .Output("pcw_totals_sums_delta: float")
+ .Output("pcw_totals_squares_delta: float")
+
+ .Output("leaves: int32")
+ .Doc(R"doc(
+Calculates incremental statistics for a batch of training data.
+
+Each training example in `input_data` is sent through the decision tree
+represented by `tree` and `tree_thresholds`.
+The shape and contents of the outputs differ depending on whether
+`regression` is true or not.
+
+For `regression` = false (classification), `pcw_node_sums_delta[i]` is
+incremented for every node i that it passes through, and the leaf it ends up
+in is recorded in `leaves[i]`. Then, if the leaf is fertile and
+initialized, the statistics for its corresponding accumulator slot
+are updated in `pcw_candidate_splits_sums_delta` and `pcw_totals_sums_delta`.
+
+For `regression` = true, outputs contain the sum of the input_labels
+for the appropriate nodes. In addition, the *_squares outputs are filled
+in with the sums of the squares of the input_labels. Since outputs are
+all updated at once, the *_indices outputs don't specify the output
+dimension to update; rather, the *_delta outputs contain updates for all the
+outputs. For example, `pcw_totals_indices` specifies the accumulators to
+update, and `pcw_totals_sums_delta` contains the complete output
+updates for each of those accumulators.
+
+The attr `num_classes` is needed to appropriately size the outputs.
+
+input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+ gives the j-th feature of the i-th input.
+input_labels: The training batch's labels; `input_labels[i]` is the class
+ of the i-th input.
+tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
+ of the i-th node, `tree[i][0] + 1` gives the index of the right child of
+ the i-th node, and `tree[i][1]` gives the index of the feature used to
+ split the i-th node.
+tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
+ node.
+node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
+  is its accumulator slot. Otherwise, `node_to_accumulator[i]` is -1.
+candidate_split_features: `candidate_split_features[a][s]` is the
+ index of the feature being considered by split s of accumulator slot a.
+candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
+ threshold value being considered by split s of accumulator slot a.
+pcw_node_sums_delta: `pcw_node_sums_delta[i][c]` is the number of training
+ examples in this training batch with class c that passed through node i for
+ classification. For regression, it is the sum of the input_labels that
+ have passed through node i.
+pcw_node_squares_delta: `pcw_node_squares_delta[i][c]` is the sum of the
+ squares of the input labels that have passed through node i for
+ regression. Not set for classification.
+pcw_splits_indices:= A 2-d tensor of shape (?, 3) for classification and
+ (?, 2) for regression.
+ `pcw_splits_indices[i]` gives the coordinates of an entry in
+ candidate_split_pcw_sums and candidate_split_pcw_squares that need to be
+ updated. This is meant to be passed with `pcw_candidate_splits_*_delta` to
+ a scatter_add for candidate_split_pcw_*:
+    training_ops.scatter_add_ndim(candidate_split_pcw_sums,
+ pcw_splits_indices, pcw_candidate_splits_sums_delta)
+pcw_candidate_splits_sums_delta: For classification,
+ `pcw_candidate_splits_sums_delta[i]` is the
+ number of training examples in this training batch that correspond to
+ the i-th entry in `pcw_splits_indices` which took the *left* branch of
+  the candidate split. For regression, it is the same, but as a 2-D tensor that has
+ the sum of the input_labels for each i-th entry in the indices.
+pcw_candidate_splits_squares_delta: For regression, same as
+ `pcw_candidate_splits_sums_delta` but the sum of the squares. Not set
+ for classification.
+pcw_totals_indices: For classification, `pcw_totals_indices` contains the
+ indices (accumulator, class) into total_pcw_sums to update with
+ pcw_totals_sums_delta. For regression, it only contains the accumulator
+ (not the class), because pcw_totals_*_delta will contain all the outputs.
+pcw_totals_sums_delta: For classification, `pcw_totals_sums_delta[i]` is the
+ number of training examples in this batch that ended up in the fertile
+ node with accumulator and class indicated by `pcw_totals_indices[i]`.
+ For regression, it is the sum of the input_labels corresponding to the
+ entries in `pcw_totals_indices[i]`.
+pcw_totals_squares_delta: For regression, same as
+ `pcw_totals_sums_delta` but the sum of the squares. Not set
+ for classification.
+leaves: `leaves[i]` is the leaf that input i ended up in.
)doc");
class CountExtremelyRandomStats : public OpKernel {
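The doc string above describes how the *_indices/*_delta outputs are meant to be scattered back into the running per-class-weight accumulators. A minimal Python sketch of that update pattern, assuming the accumulator variables (`candidate_split_pcw_sums`, `total_pcw_sums`) and the generated op wrappers in `training_ops` are already available in the graph; none of those bindings are defined in this diff:

    # Sketch only: apply the deltas produced by CountExtremelyRandomStats,
    # following the pattern given in the op's doc string.
    split_sums_update = training_ops.scatter_add_ndim(
        candidate_split_pcw_sums, pcw_splits_indices,
        pcw_candidate_splits_sums_delta)
    total_sums_update = training_ops.scatter_add_ndim(
        total_pcw_sums, pcw_totals_indices, pcw_totals_sums_delta)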
diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
index 9cb84fc63d..e1369f9d8c 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
@@ -26,29 +26,28 @@ using tensorforest::CheckTensorBounds;
using tensorforest::Sum;
REGISTER_OP("FinishedNodes")
- .Attr("num_split_after_samples: int32")
- .Input("leaves: int32")
- .Input("node_to_accumulator: int32")
- .Input("accumulator_sums: float")
-
- .Output("finished: int32")
- .Doc(R"doc(
- Determines which of the given leaf nodes are done accumulating.
-
- leaves:= A 1-d int32 tensor. Lists the nodes that are currently leaves.
- node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
- is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1.
- accumulator_sums: For classification, `accumulator_sums[a][c]` records how many
- training examples have class c and have ended up in the fertile node
- associated with accumulator slot a. It has the total sum in entry 0 for
- convenience. For regression, it is the same except it contains the sum
- of the input labels that have been seen, and entry 0 contains the number
- of training examples that have been seen.
- finished:= A 1-d int32 tensor. Contains the nodes that have total split
- counts greater or equal to the num_split_after_samples attribute.
+ .Attr("num_split_after_samples: int")
+ .Input("leaves: int32")
+ .Input("node_to_accumulator: int32")
+ .Input("accumulator_sums: float")
+
+ .Output("finished: int32")
+ .Doc(R"doc(
+Determines which of the given leaf nodes are done accumulating.
+
+leaves:= A 1-d int32 tensor. Lists the nodes that are currently leaves.
+node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
+  is its accumulator slot. Otherwise, `node_to_accumulator[i]` is -1.
+accumulator_sums: For classification, `accumulator_sums[a][c]` records how many
+ training examples have class c and have ended up in the fertile node
+ associated with accumulator slot a. It has the total sum in entry 0 for
+ convenience. For regression, it is the same except it contains the sum
+ of the input labels that have been seen, and entry 0 contains the number
+ of training examples that have been seen.
+finished:= A 1-d int32 tensor. Contains the nodes that have total split
+  counts greater than or equal to the num_split_after_samples attribute.
)doc");
-
class FinishedNodes : public OpKernel {
public:
explicit FinishedNodes(OpKernelConstruction* context)
@@ -128,4 +127,3 @@ REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU),
FinishedNodes);
} // namespace tensorflow
-
diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
index 8ceb4a9f0c..182b1257b6 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
@@ -31,58 +31,57 @@ namespace tensorflow {
using tensorforest::CheckTensorBounds;
using tensorforest::IsAllInitialized;
-
REGISTER_OP("SampleInputs")
- .Attr("split_initializations_per_input: int32")
- .Attr("split_sampling_random_seed: int32")
- .Input("input_data: float")
- .Input("node_to_accumulator: int32")
- .Input("leaves: int32")
- .Input("candidate_split_features: int32")
- .Input("candidate_split_thresholds: float")
- .Output("accumulators_to_update: int32")
- .Output("new_split_feature_rows: int32")
- .Output("new_split_threshold_rows: float")
- .Doc(R"doc(
- Initializes candidate splits for newly fertile nodes.
+ .Attr("split_initializations_per_input: int")
+ .Attr("split_sampling_random_seed: int")
+ .Input("input_data: float")
+ .Input("node_to_accumulator: int32")
+ .Input("leaves: int32")
+ .Input("candidate_split_features: int32")
+ .Input("candidate_split_thresholds: float")
+ .Output("accumulators_to_update: int32")
+ .Output("new_split_feature_rows: int32")
+ .Output("new_split_threshold_rows: float")
+ .Doc(R"doc(
+Initializes candidate splits for newly fertile nodes.
- In an extremely random forest, we don't consider all possible threshold
- values for a candidate split feature, but rather only a sampling of them.
- This Op takes those samples from the training data in `input_data`. The
- feature and threshold samples are stored in tensors that are indexed by
- accumulator slot, so for each input, we must first look up which leaf
- it ended up in (using `leaves`) and then which accumulator slot if any
- that leaf maps to (using `node_to_accumulator`).
+In an extremely random forest, we don't consider all possible threshold
+values for a candidate split feature, but rather only a sampling of them.
+This Op takes those samples from the training data in `input_data`. The
+feature and threshold samples are stored in tensors that are indexed by
+accumulator slot, so for each input, we must first look up which leaf
+it ended up in (using `leaves`) and then which accumulator slot if any
+that leaf maps to (using `node_to_accumulator`).
- The attribute `split_initializations_per_input` controls how many splits
- a single training example can initialize, and the attribute
- `split_sampling_random_seed` sets the random number generator's seed
- (a value of 0 means use the current time as the seed).
+The attribute `split_initializations_per_input` controls how many splits
+a single training example can initialize, and the attribute
+`split_sampling_random_seed` sets the random number generator's seed
+(a value of 0 means use the current time as the seed).
- input_data: The features for the current batch of training data.
- `input_data[i][j]` is the j-th feature of the i-th input.
- node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the
- associated accumulator slot. For non-fertile nodes, it is -1.
- leaves: `leaves[i]` is the leaf that the i-th input landed in, as
- calculated by CountExtremelyRandomStats.
- candidate_split_features: The current features for the candidate splits;
- `candidate_split_features[a][s]` is the index of the feature being
- considered by split s in accumulator slot a.
- candidate_split_thresholds: The current thresholds for the candidate splits;
- `candidate_split_thresholds[a][s]` is the threshold value being
- considered by split s in accumulator slot a.
- accumulators_to_update: A list of the accumulators to change in the
- candidate_split_features and candidate_split_thresholds tensors.
- new_split_feature_rows: The new values for the candidate_split_features
- tensor. Intended to be used with
- `tf.scatter_update(candidate_split_features,
- accumulators_to_update,
- new_split_feature_rows)`
- new_split_threshold_rows: The new values for the candidate_split_thresholds
- tensor. Intended to be used with
- `tf.scatter_update(candidate_split_thresholds,
- accumulators_to_update,
- new_split_feature_thresholds)`
+input_data: The features for the current batch of training data.
+ `input_data[i][j]` is the j-th feature of the i-th input.
+node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the
+ associated accumulator slot. For non-fertile nodes, it is -1.
+leaves: `leaves[i]` is the leaf that the i-th input landed in, as
+ calculated by CountExtremelyRandomStats.
+candidate_split_features: The current features for the candidate splits;
+ `candidate_split_features[a][s]` is the index of the feature being
+ considered by split s in accumulator slot a.
+candidate_split_thresholds: The current thresholds for the candidate splits;
+ `candidate_split_thresholds[a][s]` is the threshold value being
+ considered by split s in accumulator slot a.
+accumulators_to_update: A list of the accumulators to change in the
+ candidate_split_features and candidate_split_thresholds tensors.
+new_split_feature_rows: The new values for the candidate_split_features
+ tensor. Intended to be used with
+ `tf.scatter_update(candidate_split_features,
+ accumulators_to_update,
+ new_split_feature_rows)`
+new_split_threshold_rows: The new values for the candidate_split_thresholds
+ tensor. Intended to be used with
+ `tf.scatter_update(candidate_split_thresholds,
+ accumulators_to_update,
+                        new_split_threshold_rows)`
)doc");
class SampleInputs : public OpKernel {
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
index 69e92c1a13..026262e47f 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
@@ -35,65 +35,63 @@ using tensorforest::CheckTensorBounds;
using tensorforest::Initialize;
using tensorforest::WeightedGiniImpurity;
-
REGISTER_OP("UpdateFertileSlots")
- .Attr("max_depth: int32")
- .Attr("regression: bool = False")
- .Input("finished: int32")
- .Input("non_fertile_leaves: int32")
- .Input("non_fertile_leaf_scores: float")
- .Input("end_of_tree: int32")
- .Input("tree_depths: int32")
- .Input("accumulator_sums: float")
- .Input("node_to_accumulator: int32")
- .Output("node_map_updates: int32")
- .Output("accumulators_cleared: int32")
- .Output("accumulators_allocated: int32")
- .Output("new_nonfertile_leaves: int32")
- .Output("new_nonfertile_leaves_scores: float")
- .Doc(R"doc(
- Updates accumulator slots to reflect finished or newly fertile nodes.
-
- Leaves at the depth of the attribute `max_depth` won't be made fertile
- (i.e., won't be given an accumulator slot.)
-
- finished:= A 1-d int32 tensor containing the indices of fertile nodes that
- are ready to decide on a split.
- non_fertile_leaves:= A 1-d int32 tensor containing the indices of all the
- currently non-fertile leaves. If there are free accumulator slots after
- deallocation, UpdateFertileSlots will consider these nodes (plus the ones
- in new_leaves) and potentially turn some of them fertile.
- non_fertile_leaf_scores: `non_fertile_leaf_scores[i]` is the splitting score
- of the non-fertile leaf `non_fertile_leaves[i]`.
- end_of_tree: The end of tree tensor from the previous training iteration, used
- with the finished input to calculate a list of new leaf indices created by
- GrowTree, which will be considered to become fertile if there are free
- slots.
- tree_depths: `tree_depths[i]` is the depth in the tree of node i.
- accumulator_sums: For classification, `accumulator_sums[a][c]` records how
- many training examples have class c and have ended up in the fertile node
- associated with accumulator slot a. It has the total sum in entry 0 for
- convenience. For regression, it is the same except it contains the sum
- of the input labels that have been seen, and entry 0 contains the number
- of training examples that have been seen.
- node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
- fertile node i, or -1 if node i isn't fertile.
- node_map_updates:= A 2-d int32 tensor describing the changes that need to
- be applied to the node_to_accumulator map. Intended to be used with
- `tf.scatter_update(node_to_accumulator,
- node_map_updates[0],
- node_map_updates[1])`.
- accumulators_cleared:= A 1-d int32 tensor containing the indices of all
- the accumulator slots that need to be cleared.
- accumulators_allocated:= A 1-d int32 tensor containing the indices of all
- the accumulator slots that need to be allocated.
- new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the
- leaves that are now non-fertile.
- new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the
- splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`.
+ .Attr("max_depth: int")
+ .Attr("regression: bool = False")
+ .Input("finished: int32")
+ .Input("non_fertile_leaves: int32")
+ .Input("non_fertile_leaf_scores: float")
+ .Input("end_of_tree: int32")
+ .Input("tree_depths: int32")
+ .Input("accumulator_sums: float")
+ .Input("node_to_accumulator: int32")
+ .Output("node_map_updates: int32")
+ .Output("accumulators_cleared: int32")
+ .Output("accumulators_allocated: int32")
+ .Output("new_nonfertile_leaves: int32")
+ .Output("new_nonfertile_leaves_scores: float")
+ .Doc(R"doc(
+Updates accumulator slots to reflect finished or newly fertile nodes.
+
+Leaves at the depth of the attribute `max_depth` won't be made fertile
+(i.e., won't be given an accumulator slot).
+
+finished:= A 1-d int32 tensor containing the indices of fertile nodes that
+ are ready to decide on a split.
+non_fertile_leaves:= A 1-d int32 tensor containing the indices of all the
+ currently non-fertile leaves. If there are free accumulator slots after
+ deallocation, UpdateFertileSlots will consider these nodes (plus the ones
+ in new_leaves) and potentially turn some of them fertile.
+non_fertile_leaf_scores: `non_fertile_leaf_scores[i]` is the splitting score
+ of the non-fertile leaf `non_fertile_leaves[i]`.
+end_of_tree: The end of tree tensor from the previous training iteration, used
+ with the finished input to calculate a list of new leaf indices created by
+ GrowTree, which will be considered to become fertile if there are free
+ slots.
+tree_depths: `tree_depths[i]` is the depth in the tree of node i.
+accumulator_sums: For classification, `accumulator_sums[a][c]` records how
+ many training examples have class c and have ended up in the fertile node
+ associated with accumulator slot a. It has the total sum in entry 0 for
+ convenience. For regression, it is the same except it contains the sum
+ of the input labels that have been seen, and entry 0 contains the number
+ of training examples that have been seen.
+node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
+ fertile node i, or -1 if node i isn't fertile.
+node_map_updates:= A 2-d int32 tensor describing the changes that need to
+ be applied to the node_to_accumulator map. Intended to be used with
+ `tf.scatter_update(node_to_accumulator,
+ node_map_updates[0],
+ node_map_updates[1])`.
+accumulators_cleared:= A 1-d int32 tensor containing the indices of all
+ the accumulator slots that need to be cleared.
+accumulators_allocated:= A 1-d int32 tensor containing the indices of all
+ the accumulator slots that need to be allocated.
+new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the
+ leaves that are now non-fertile.
+new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the
+ splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`.
)doc");
-
class UpdateFertileSlots : public OpKernel {
public:
explicit UpdateFertileSlots(OpKernelConstruction* context)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 77db4004c7..7f4c1abc1f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1367,6 +1367,7 @@ tf_cc_tests(
"framework/op_def_builder_test.cc",
"framework/op_def_util_test.cc",
"framework/op_kernel_test.cc",
+ "framework/op_registration_test.cc",
"framework/partial_tensor_shape_test.cc",
"framework/rendezvous_test.cc",
"framework/resource_mgr_test.cc",
diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc
index 6f66a21875..3f1b037b02 100644
--- a/tensorflow/core/framework/load_library.cc
+++ b/tensorflow/core/framework/load_library.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <memory>
+#include <unordered_set>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -47,11 +48,32 @@ Status LoadLibrary(const char* library_filename, void** result,
Env* env = Env::Default();
void* lib;
OpList op_list;
+ std::unordered_set<string> seen_op_names;
{
mutex_lock lock(mu);
+ OpRegistry::Global()->ProcessRegistrations();
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(
- [&op_list](const OpDef& opdef) { *op_list.add_op() = opdef; }));
- TF_RETURN_IF_ERROR(env->LoadLibrary(library_filename, &lib));
+ [&op_list, &seen_op_names](const Status& s,
+ const OpDef& opdef) -> Status {
+ if (errors::IsAlreadyExists(s)) {
+ if (seen_op_names.find(opdef.name()) == seen_op_names.end()) {
+          // Overwriting a registration of an op not in this custom op
+ // library. Treat this as not an error.
+ return Status::OK();
+ }
+ }
+ *op_list.add_op() = opdef;
+ seen_op_names.insert(opdef.name());
+ return s;
+ }));
+ OpRegistry::Global()->DeferRegistrations();
+ Status s = env->LoadLibrary(library_filename, &lib);
+ if (!s.ok()) {
+ OpRegistry::Global()->ClearDeferredRegistrations();
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
+ return s;
+ }
+ OpRegistry::Global()->ProcessRegistrations();
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
}
string str;
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index 2ea735a790..41bb07581f 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -47,17 +47,12 @@ OpRegistry::~OpRegistry() {
for (const auto& e : registry_) delete e.second;
}
-void OpRegistry::Register(std::unique_ptr<OpRegistrationData> op_reg_data) {
- OpRegistrationData* raw_ptr = op_reg_data.get();
-
+void OpRegistry::Register(OpRegistrationDataFactory op_data_factory) {
mutex_lock lock(mu_);
if (initialized_) {
- TF_QCHECK_OK(RegisterAlreadyLocked(std::move(op_reg_data)));
+ TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
} else {
- deferred_.push_back(std::move(op_reg_data));
- }
- if (watcher_) {
- watcher_(raw_ptr->op_def);
+ deferred_.push_back(op_data_factory);
}
}
@@ -133,6 +128,21 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const {
}
}
+void OpRegistry::DeferRegistrations() {
+ mutex_lock lock(mu_);
+ initialized_ = false;
+}
+
+void OpRegistry::ClearDeferredRegistrations() {
+ mutex_lock lock(mu_);
+ deferred_.clear();
+}
+
+void OpRegistry::ProcessRegistrations() const {
+ mutex_lock lock(mu_);
+ CallDeferred();
+}
+
string OpRegistry::DebugString(bool include_internal) const {
OpList op_list;
Export(include_internal, &op_list);
@@ -147,23 +157,34 @@ bool OpRegistry::CallDeferred() const {
if (initialized_) return false;
initialized_ = true;
for (int i = 0; i < deferred_.size(); ++i) {
- TF_QCHECK_OK(RegisterAlreadyLocked(std::move(deferred_[i])));
+ TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
}
deferred_.clear();
return true;
}
Status OpRegistry::RegisterAlreadyLocked(
- std::unique_ptr<OpRegistrationData> op_reg_data) const {
- TF_RETURN_IF_ERROR(ValidateOpDef(op_reg_data->op_def));
-
- if (gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
- op_reg_data.get())) {
- op_reg_data.release(); // Ownership transferred to op_registry
- return Status::OK();
+ OpRegistrationDataFactory op_data_factory) const {
+ std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
+ Status s = op_data_factory(op_reg_data.get());
+ if (s.ok()) {
+ s = ValidateOpDef(op_reg_data->op_def);
+ if (s.ok() &&
+ !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
+ op_reg_data.get())) {
+ s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
+ }
+ }
+ Status watcher_status = s;
+ if (watcher_) {
+ watcher_status = watcher_(s, op_reg_data->op_def);
+ }
+ if (s.ok()) {
+ op_reg_data.release();
} else {
- return errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
+ op_reg_data.reset();
}
+ return watcher_status;
}
// static
@@ -202,9 +223,15 @@ Status OpListOpRegistry::LookUp(const string& op_type_name,
namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(
const OpDefBuilderWrapper<true>& wrapper) {
- std::unique_ptr<OpRegistrationData> data(new OpRegistrationData);
- wrapper.builder().Finalize(data.get());
- OpRegistry::Global()->Register(std::move(data));
+ OpRegistry::Global()->Register(
+ [wrapper](OpRegistrationData* op_reg_data) -> Status {
+ wrapper.builder().Finalize(op_reg_data);
+ // TODO(keveman): Add this check back again in a separate CL.
+ // if (!s.ok()) {
+ // return s;
+ // }
+ return Status::OK();
+ });
}
} // namespace register_op
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index cd484dacd4..6cd76c26e2 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -54,22 +54,23 @@ class OpRegistryInterface {
};
// The standard implementation of OpRegistryInterface, along with a
-// global singleton used for registering OpDefs via the REGISTER
+// global singleton used for registering ops via the REGISTER
// macros below. Thread-safe.
//
// Example registration:
-// OpRegistry::Global()->Register([]()->OpDef{
-// OpDef def;
-// // Populate def here.
-// return def;
+// OpRegistry::Global()->Register(
+// [](OpRegistrationData* op_reg_data)->Status {
+// // Populate *op_reg_data here.
+// return Status::OK();
// });
class OpRegistry : public OpRegistryInterface {
public:
+ typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
+
OpRegistry();
~OpRegistry() override;
- // Calls watcher and registers the passed OpDef.
- void Register(std::unique_ptr<OpRegistrationData> op_data);
+ void Register(OpRegistrationDataFactory op_data_factory);
Status LookUp(const string& op_type_name,
const OpRegistrationData** op_reg_data) const override;
@@ -89,10 +90,12 @@ class OpRegistry : public OpRegistryInterface {
void GetRegisteredOps(std::vector<OpDef>* op_defs);
// Watcher, a function object.
- // watcher_, if not null, is called every time an op is registered via the
- // Register function. watcher_ is passed the OpDef of the op getting
- // registered.
- typedef std::function<void(const OpDef&)> Watcher;
+ // The watcher, if set by SetWatcher(), is called every time an op is
+ // registered via the Register function. The watcher is passed the Status
+ // obtained from building and adding the OpDef to the registry, and the OpDef
+ // itself if it was successfully built. A watcher returns a Status which is in
+ // turn returned as the final registration status.
+ typedef std::function<Status(const Status&, const OpDef&)> Watcher;
// An OpRegistry object has only one watcher. This interface is not thread
// safe, as different clients are free to set the watcher any time.
@@ -100,11 +103,26 @@ class OpRegistry : public OpRegistryInterface {
// operations :
// SetWatcher(a_watcher);
// Register some ops;
+ // op_registry->ProcessRegistrations();
// SetWatcher(nullptr);
// Returns a non-OK status if a non-null watcher is over-written by another
// non-null watcher.
Status SetWatcher(const Watcher& watcher);
+ // Process the current list of deferred registrations. Note that calls to
+ // Export, LookUp and DebugString would also implicitly process the deferred
+ // registrations.
+ void ProcessRegistrations() const;
+
+ // Defer the registrations until a later call to a function that processes
+  // deferred registrations is made. Normally, registrations that happen after
+ // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
+ // immediately. Call this to defer future registrations.
+ void DeferRegistrations();
+
+ // Clear the registrations that have been deferred.
+ void ClearDeferredRegistrations();
+
private:
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
@@ -114,13 +132,12 @@ class OpRegistry : public OpRegistryInterface {
// Add 'def' to the registry with additional data 'data'. On failure, or if
// there is already an OpDef with that name registered, returns a non-okay
// status.
- Status RegisterAlreadyLocked(std::unique_ptr<OpRegistrationData> op_data)
- const EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status RegisterAlreadyLocked(OpRegistrationDataFactory op_data_factory) const
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
- mutable std::vector<std::unique_ptr<OpRegistrationData>> deferred_
- GUARDED_BY(mu_);
+ mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
// Values are owned.
mutable std::unordered_map<string, const OpRegistrationData*> registry_
GUARDED_BY(mu_);
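The comments above spell out the new contract: Register() takes a factory that fills in an OpRegistrationData, and the watcher receives the registration Status plus the OpDef and returns the final Status. A minimal C++ sketch of that call sequence, modeled on the SetWatcher comment; the function name and the "MyOp" op name are illustrative assumptions, not part of this change:

    // Illustrative sketch of the registration flow described above.
    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {

    void RegisterMyOpWithWatcher() {
      OpRegistry* registry = OpRegistry::Global();
      TF_CHECK_OK(registry->SetWatcher(
          [](const Status& s, const OpDef& op_def) -> Status {
            // Returning s keeps the original outcome, e.g. AlreadyExists when
            // an op with the same name was registered earlier.
            return s;
          }));
      registry->DeferRegistrations();  // hold registrations until processed
      registry->Register([](OpRegistrationData* op_reg_data) -> Status {
        op_reg_data->op_def.set_name("MyOp");  // illustrative op name
        return Status::OK();
      });
      registry->ProcessRegistrations();  // run deferred factories; watcher fires
      TF_CHECK_OK(registry->SetWatcher(nullptr));
    }

    }  // namespace tensorflow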
diff --git a/tensorflow/core/framework/op_registration_test.cc b/tensorflow/core/framework/op_registration_test.cc
new file mode 100644
index 0000000000..9ab90df422
--- /dev/null
+++ b/tensorflow/core/framework/op_registration_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/core/framework/op.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+void Register(const string& op_name, OpRegistry* registry) {
+ registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status {
+ op_reg_data->op_def.set_name(op_name);
+ return Status::OK();
+ });
+}
+
+} // namespace
+
+TEST(OpRegistrationTest, TestBasic) {
+ std::unique_ptr<OpRegistry> registry(new OpRegistry);
+ Register("Foo", registry.get());
+ OpList op_list;
+ registry->Export(true, &op_list);
+ EXPECT_EQ(op_list.op().size(), 1);
+ EXPECT_EQ(op_list.op(0).name(), "Foo");
+}
+
+TEST(OpRegistrationTest, TestDuplicate) {
+ std::unique_ptr<OpRegistry> registry(new OpRegistry);
+ Register("Foo", registry.get());
+ registry->ProcessRegistrations();
+
+ registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status {
+ EXPECT_TRUE(errors::IsAlreadyExists(s));
+ return Status::OK();
+ });
+ Register("Foo", registry.get());
+ registry->ProcessRegistrations();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index a0602d6c06..63da598276 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -983,11 +983,11 @@ REGISTER_OP("StridedSlice")
.Output("output: T")
.Attr("T: type")
.Attr("Index: {int32, int64}")
- .Attr("begin_mask: int32 = 0")
- .Attr("end_mask: int32 = 0")
- .Attr("ellipse_mask: int32 = 0")
- .Attr("new_axis_mask: int32 = 0")
- .Attr("shrink_axis_mask: int32 = 0")
+ .Attr("begin_mask: int = 0")
+ .Attr("end_mask: int = 0")
+ .Attr("ellipse_mask: int = 0")
+ .Attr("new_axis_mask: int = 0")
+ .Attr("shrink_axis_mask: int = 0")
.Doc(R"doc(
Return a strided slice from `input`.
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 8c1112c16c..546cb36706 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -42,7 +42,11 @@ def load_op_library(library_filename):
Pass "library_filename" to a platform-specific mechanism for dynamically
loading a library. The rules for determining the exact location of the
- library are platform-specific and are not documented here.
+ library are platform-specific and are not documented here. When the
+ library is loaded, ops and kernels registered in the library via the
+ REGISTER_* macros are made available in the TensorFlow process. Note
+ that ops with the same name as an existing op are rejected and not
+ registered with the process.
Args:
library_filename: Path to the plugin.
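A small usage sketch of the behaviour described in this docstring, assuming a hypothetical compiled custom-op library named 'my_ops.so' next to the script:

    # Sketch only: 'my_ops.so' is a hypothetical library path.
    import os.path
    import tensorflow as tf

    my_ops = tf.load_op_library(
        os.path.join(os.path.dirname(__file__), 'my_ops.so'))
    # OP_LIST contains only the ops actually registered from this library;
    # ops whose names collide with already-registered ops are rejected, so a
    # duplicate-only library yields an empty OP_LIST (see duplicate_op_test.py
    # below).
    print(len(my_ops.OP_LIST.op))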
diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD
index 6d8773aebf..682df75a68 100644
--- a/tensorflow/user_ops/BUILD
+++ b/tensorflow/user_ops/BUILD
@@ -32,6 +32,18 @@ py_tests(
data = [":ackermann_op.so"],
)
+tf_custom_op_library(
+ name = "duplicate_op.so",
+ srcs = ["duplicate_op.cc"],
+)
+
+py_tests(
+ name = "duplicate_op_test",
+ size = "small",
+ srcs = ["duplicate_op_test.py"],
+ data = [":duplicate_op.so"],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/user_ops/duplicate_op.cc b/tensorflow/user_ops/duplicate_op.cc
new file mode 100644
index 0000000000..9f622e4db5
--- /dev/null
+++ b/tensorflow/user_ops/duplicate_op.cc
@@ -0,0 +1,26 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Add").Doc(R"doc(
+An op to test that duplicate registrations don't override previously
+registered ops.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/user_ops/duplicate_op_test.py b/tensorflow/user_ops/duplicate_op_test.py
new file mode 100644
index 0000000000..b61e68d75e
--- /dev/null
+++ b/tensorflow/user_ops/duplicate_op_test.py
@@ -0,0 +1,39 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for custom user ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+
+class DuplicateOpTest(tf.test.TestCase):
+
+ def testBasic(self):
+ library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
+ 'duplicate_op.so')
+ duplicate = tf.load_op_library(library_filename)
+
+ self.assertEqual(len(duplicate.OP_LIST.op), 0)
+
+ with self.test_session():
+ self.assertEqual(tf.add(1, 41).eval(), 42)
+
+
+if __name__ == '__main__':
+ tf.test.main()