diff options
-rw-r--r-- | tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc | 203 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc | 42 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc | 95 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc | 110 | ||||
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/load_library.cc | 26 | ||||
-rw-r--r-- | tensorflow/core/framework/op.cc | 67 | ||||
-rw-r--r-- | tensorflow/core/framework/op.h | 47 | ||||
-rw-r--r-- | tensorflow/core/framework/op_registration_test.cc | 57 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 10 | ||||
-rw-r--r-- | tensorflow/python/framework/load_library.py | 6 | ||||
-rw-r--r-- | tensorflow/user_ops/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/user_ops/duplicate_op.cc | 26 | ||||
-rw-r--r-- | tensorflow/user_ops/duplicate_op_test.py | 39 |
14 files changed, 470 insertions, 271 deletions
diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc index 913a7d3c01..0ccf75bcc6 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc @@ -120,109 +120,108 @@ void Evaluate(const Tensor& input_data, const Tensor& input_labels, } } - REGISTER_OP("CountExtremelyRandomStats") - .Attr("num_classes: int32") - .Attr("regression: bool = false") - .Input("input_data: float") - .Input("input_labels: float") - - .Input("tree: int32") - .Input("tree_thresholds: float") - - .Input("node_to_accumulator: int32") - - .Input("candidate_split_features: int32") - .Input("candidate_split_thresholds: float") - - .Output("pcw_node_sums_delta: float") - .Output("pcw_node_squares_delta: float") - .Output("pcw_splits_indices: int32") - .Output("pcw_candidate_splits_sums_delta: float") - .Output("pcw_candidate_splits_squares_delta: float") - .Output("pcw_totals_indices: int32") - .Output("pcw_totals_sums_delta: float") - .Output("pcw_totals_squares_delta: float") - - .Output("leaves: int32") - .Doc(R"doc( - Calculates incremental statistics for a batch of training data. - - Each training example in `input_data` is sent through the decision tree - represented by `tree` and `tree_thresholds`. - The shape and contents of the outputs differ depending on whether - `regression` is true or not. - - For `regression` = false (classification), `pcw_node_sums_delta[i]` is - incremented for every node i that it passes through, and the leaf it ends up - in is recorded in `leaves[i]`. Then, if the leaf is fertile and - initialized, the statistics for its corresponding accumulator slot - are updated in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`. - - For `regression` = true, outputs contain the sum of the input_labels - for the appropriate nodes. In adddition, the *_squares outputs are filled - in with the sums of the squares of the input_labels. Since outputs are - all updated at once, the *_indicies outputs don't specify the output - dimension to update, rather the *_delta output contains updates for all the - outputs. For example, `pcw_totals_indices` specifies the accumulators to - update, and `pcw_total_splits_sums_delta` contains the complete output - updates for each of those accumulators. - - The attr `num_classes` is needed to appropriately size the outputs. - - input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` - gives the j-th feature of the i-th input. - input_labels: The training batch's labels; `input_labels[i]` is the class - of the i-th input. - tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child - of the i-th node, `tree[i][0] + 1` gives the index of the right child of - the i-th node, and `tree[i][1]` gives the index of the feature used to - split the i-th node. - tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th - node. - node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]` - is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1. - candidate_split_features: `candidate_split_features[a][s]` is the - index of the feature being considered by split s of accumulator slot a. - candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the - threshold value being considered by split s of accumulator slot a. - pcw_node_sums_delta: `pcw_node_sums_delta[i][c]` is the number of training - examples in this training batch with class c that passed through node i for - classification. For regression, it is the sum of the input_labels that - have passed through node i. - pcw_node_squares_delta: `pcw_node_squares_delta[i][c]` is the sum of the - squares of the input labels that have passed through node i for - regression. Not set for classification. - pcw_splits_indices:= A 2-d tensor of shape (?, 3) for classification and - (?, 2) for regression. - `pcw_splits_indices[i]` gives the coordinates of an entry in - candidate_split_pcw_sums and candidate_split_pcw_squares that need to be - updated. This is meant to be passed with `pcw_candidate_splits_*_delta` to - a scatter_add for candidate_split_pcw_*: - training_ops.scatter_add_ndim(candidate_split_pcw_sums - pcw_splits_indices, pcw_candidate_splits_sums_delta) - pcw_candidate_splits_sums_delta: For classification, - `pcw_candidate_splits_sums_delta[i]` is the - number of training examples in this training batch that correspond to - the i-th entry in `pcw_splits_indices` which took the *left* branch of - candidate split. For regression, it is the same but a 2-D tensor that has - the sum of the input_labels for each i-th entry in the indices. - pcw_candidate_splits_squares_delta: For regression, same as - `pcw_candidate_splits_sums_delta` but the sum of the squares. Not set - for classification. - pcw_totals_indices: For classification, 'pcw_totals_indices` contains the - indices (accumulator, class) into total_pcw_sums to update with - pcw_totals_sums_delta. For regression, it only contains the accumulator - (not the class), because pcw_totals_*_delta will contain all the outputs. - pcw_totals_sums_delta: For classification, `pcw_totals_sums_delta[i]` is the - number of training examples in this batch that ended up in the fertile - node with accumulator and class indicated by `pcw_totals_indices[i]`. - For regression, it is the sum of the input_labels corresponding to the - entries in `pcw_totals_indices[i]`. - pcw_totals_squares_delta: For regression, same as - `pcw_totals_sums_delta` but the sum of the squares. Not set - for classification. - leaves: `leaves[i]` is the leaf that input i ended up in. + .Attr("num_classes: int") + .Attr("regression: bool = false") + .Input("input_data: float") + .Input("input_labels: float") + + .Input("tree: int32") + .Input("tree_thresholds: float") + + .Input("node_to_accumulator: int32") + + .Input("candidate_split_features: int32") + .Input("candidate_split_thresholds: float") + + .Output("pcw_node_sums_delta: float") + .Output("pcw_node_squares_delta: float") + .Output("pcw_splits_indices: int32") + .Output("pcw_candidate_splits_sums_delta: float") + .Output("pcw_candidate_splits_squares_delta: float") + .Output("pcw_totals_indices: int32") + .Output("pcw_totals_sums_delta: float") + .Output("pcw_totals_squares_delta: float") + + .Output("leaves: int32") + .Doc(R"doc( +Calculates incremental statistics for a batch of training data. + +Each training example in `input_data` is sent through the decision tree +represented by `tree` and `tree_thresholds`. +The shape and contents of the outputs differ depending on whether +`regression` is true or not. + +For `regression` = false (classification), `pcw_node_sums_delta[i]` is +incremented for every node i that it passes through, and the leaf it ends up +in is recorded in `leaves[i]`. Then, if the leaf is fertile and +initialized, the statistics for its corresponding accumulator slot +are updated in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`. + +For `regression` = true, outputs contain the sum of the input_labels +for the appropriate nodes. In adddition, the *_squares outputs are filled +in with the sums of the squares of the input_labels. Since outputs are +all updated at once, the *_indicies outputs don't specify the output +dimension to update, rather the *_delta output contains updates for all the +outputs. For example, `pcw_totals_indices` specifies the accumulators to +update, and `pcw_total_splits_sums_delta` contains the complete output +updates for each of those accumulators. + +The attr `num_classes` is needed to appropriately size the outputs. + +input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` + gives the j-th feature of the i-th input. +input_labels: The training batch's labels; `input_labels[i]` is the class + of the i-th input. +tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child + of the i-th node, `tree[i][0] + 1` gives the index of the right child of + the i-th node, and `tree[i][1]` gives the index of the feature used to + split the i-th node. +tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th + node. +node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]` + is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1. +candidate_split_features: `candidate_split_features[a][s]` is the + index of the feature being considered by split s of accumulator slot a. +candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the + threshold value being considered by split s of accumulator slot a. +pcw_node_sums_delta: `pcw_node_sums_delta[i][c]` is the number of training + examples in this training batch with class c that passed through node i for + classification. For regression, it is the sum of the input_labels that + have passed through node i. +pcw_node_squares_delta: `pcw_node_squares_delta[i][c]` is the sum of the + squares of the input labels that have passed through node i for + regression. Not set for classification. +pcw_splits_indices:= A 2-d tensor of shape (?, 3) for classification and + (?, 2) for regression. + `pcw_splits_indices[i]` gives the coordinates of an entry in + candidate_split_pcw_sums and candidate_split_pcw_squares that need to be + updated. This is meant to be passed with `pcw_candidate_splits_*_delta` to + a scatter_add for candidate_split_pcw_*: + training_ops.scatter_add_ndim(candidate_split_pcw_sums + pcw_splits_indices, pcw_candidate_splits_sums_delta) +pcw_candidate_splits_sums_delta: For classification, + `pcw_candidate_splits_sums_delta[i]` is the + number of training examples in this training batch that correspond to + the i-th entry in `pcw_splits_indices` which took the *left* branch of + candidate split. For regression, it is the same but a 2-D tensor that has + the sum of the input_labels for each i-th entry in the indices. +pcw_candidate_splits_squares_delta: For regression, same as + `pcw_candidate_splits_sums_delta` but the sum of the squares. Not set + for classification. +pcw_totals_indices: For classification, 'pcw_totals_indices` contains the + indices (accumulator, class) into total_pcw_sums to update with + pcw_totals_sums_delta. For regression, it only contains the accumulator + (not the class), because pcw_totals_*_delta will contain all the outputs. +pcw_totals_sums_delta: For classification, `pcw_totals_sums_delta[i]` is the + number of training examples in this batch that ended up in the fertile + node with accumulator and class indicated by `pcw_totals_indices[i]`. + For regression, it is the sum of the input_labels corresponding to the + entries in `pcw_totals_indices[i]`. +pcw_totals_squares_delta: For regression, same as + `pcw_totals_sums_delta` but the sum of the squares. Not set + for classification. +leaves: `leaves[i]` is the leaf that input i ended up in. )doc"); class CountExtremelyRandomStats : public OpKernel { diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc index 9cb84fc63d..e1369f9d8c 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc @@ -26,29 +26,28 @@ using tensorforest::CheckTensorBounds; using tensorforest::Sum; REGISTER_OP("FinishedNodes") - .Attr("num_split_after_samples: int32") - .Input("leaves: int32") - .Input("node_to_accumulator: int32") - .Input("accumulator_sums: float") - - .Output("finished: int32") - .Doc(R"doc( - Determines which of the given leaf nodes are done accumulating. - - leaves:= A 1-d int32 tensor. Lists the nodes that are currently leaves. - node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]` - is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1. - accumulator_sums: For classification, `accumulator_sums[a][c]` records how many - training examples have class c and have ended up in the fertile node - associated with accumulator slot a. It has the total sum in entry 0 for - convenience. For regression, it is the same except it contains the sum - of the input labels that have been seen, and entry 0 contains the number - of training examples that have been seen. - finished:= A 1-d int32 tensor. Contains the nodes that have total split - counts greater or equal to the num_split_after_samples attribute. + .Attr("num_split_after_samples: int") + .Input("leaves: int32") + .Input("node_to_accumulator: int32") + .Input("accumulator_sums: float") + + .Output("finished: int32") + .Doc(R"doc( +Determines which of the given leaf nodes are done accumulating. + +leaves:= A 1-d int32 tensor. Lists the nodes that are currently leaves. +node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]` + is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1. +accumulator_sums: For classification, `accumulator_sums[a][c]` records how many + training examples have class c and have ended up in the fertile node + associated with accumulator slot a. It has the total sum in entry 0 for + convenience. For regression, it is the same except it contains the sum + of the input labels that have been seen, and entry 0 contains the number + of training examples that have been seen. +finished:= A 1-d int32 tensor. Contains the nodes that have total split + counts greater or equal to the num_split_after_samples attribute. )doc"); - class FinishedNodes : public OpKernel { public: explicit FinishedNodes(OpKernelConstruction* context) @@ -128,4 +127,3 @@ REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU), FinishedNodes); } // namespace tensorflow - diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc index 8ceb4a9f0c..182b1257b6 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc @@ -31,58 +31,57 @@ namespace tensorflow { using tensorforest::CheckTensorBounds; using tensorforest::IsAllInitialized; - REGISTER_OP("SampleInputs") - .Attr("split_initializations_per_input: int32") - .Attr("split_sampling_random_seed: int32") - .Input("input_data: float") - .Input("node_to_accumulator: int32") - .Input("leaves: int32") - .Input("candidate_split_features: int32") - .Input("candidate_split_thresholds: float") - .Output("accumulators_to_update: int32") - .Output("new_split_feature_rows: int32") - .Output("new_split_threshold_rows: float") - .Doc(R"doc( - Initializes candidate splits for newly fertile nodes. + .Attr("split_initializations_per_input: int") + .Attr("split_sampling_random_seed: int") + .Input("input_data: float") + .Input("node_to_accumulator: int32") + .Input("leaves: int32") + .Input("candidate_split_features: int32") + .Input("candidate_split_thresholds: float") + .Output("accumulators_to_update: int32") + .Output("new_split_feature_rows: int32") + .Output("new_split_threshold_rows: float") + .Doc(R"doc( +Initializes candidate splits for newly fertile nodes. - In an extremely random forest, we don't consider all possible threshold - values for a candidate split feature, but rather only a sampling of them. - This Op takes those samples from the training data in `input_data`. The - feature and threshold samples are stored in tensors that are indexed by - accumulator slot, so for each input, we must first look up which leaf - it ended up in (using `leaves`) and then which accumulator slot if any - that leaf maps to (using `node_to_accumulator`). +In an extremely random forest, we don't consider all possible threshold +values for a candidate split feature, but rather only a sampling of them. +This Op takes those samples from the training data in `input_data`. The +feature and threshold samples are stored in tensors that are indexed by +accumulator slot, so for each input, we must first look up which leaf +it ended up in (using `leaves`) and then which accumulator slot if any +that leaf maps to (using `node_to_accumulator`). - The attribute `split_initializations_per_input` controls how many splits - a single training example can initialize, and the attribute - `split_sampling_random_seed` sets the random number generator's seed - (a value of 0 means use the current time as the seed). +The attribute `split_initializations_per_input` controls how many splits +a single training example can initialize, and the attribute +`split_sampling_random_seed` sets the random number generator's seed +(a value of 0 means use the current time as the seed). - input_data: The features for the current batch of training data. - `input_data[i][j]` is the j-th feature of the i-th input. - node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the - associated accumulator slot. For non-fertile nodes, it is -1. - leaves: `leaves[i]` is the leaf that the i-th input landed in, as - calculated by CountExtremelyRandomStats. - candidate_split_features: The current features for the candidate splits; - `candidate_split_features[a][s]` is the index of the feature being - considered by split s in accumulator slot a. - candidate_split_thresholds: The current thresholds for the candidate splits; - `candidate_split_thresholds[a][s]` is the threshold value being - considered by split s in accumulator slot a. - accumulators_to_update: A list of the accumulators to change in the - candidate_split_features and candidate_split_thresholds tensors. - new_split_feature_rows: The new values for the candidate_split_features - tensor. Intended to be used with - `tf.scatter_update(candidate_split_features, - accumulators_to_update, - new_split_feature_rows)` - new_split_threshold_rows: The new values for the candidate_split_thresholds - tensor. Intended to be used with - `tf.scatter_update(candidate_split_thresholds, - accumulators_to_update, - new_split_feature_thresholds)` +input_data: The features for the current batch of training data. + `input_data[i][j]` is the j-th feature of the i-th input. +node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the + associated accumulator slot. For non-fertile nodes, it is -1. +leaves: `leaves[i]` is the leaf that the i-th input landed in, as + calculated by CountExtremelyRandomStats. +candidate_split_features: The current features for the candidate splits; + `candidate_split_features[a][s]` is the index of the feature being + considered by split s in accumulator slot a. +candidate_split_thresholds: The current thresholds for the candidate splits; + `candidate_split_thresholds[a][s]` is the threshold value being + considered by split s in accumulator slot a. +accumulators_to_update: A list of the accumulators to change in the + candidate_split_features and candidate_split_thresholds tensors. +new_split_feature_rows: The new values for the candidate_split_features + tensor. Intended to be used with + `tf.scatter_update(candidate_split_features, + accumulators_to_update, + new_split_feature_rows)` +new_split_threshold_rows: The new values for the candidate_split_thresholds + tensor. Intended to be used with + `tf.scatter_update(candidate_split_thresholds, + accumulators_to_update, + new_split_feature_thresholds)` )doc"); class SampleInputs : public OpKernel { diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc index 69e92c1a13..026262e47f 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc @@ -35,65 +35,63 @@ using tensorforest::CheckTensorBounds; using tensorforest::Initialize; using tensorforest::WeightedGiniImpurity; - REGISTER_OP("UpdateFertileSlots") - .Attr("max_depth: int32") - .Attr("regression: bool = False") - .Input("finished: int32") - .Input("non_fertile_leaves: int32") - .Input("non_fertile_leaf_scores: float") - .Input("end_of_tree: int32") - .Input("tree_depths: int32") - .Input("accumulator_sums: float") - .Input("node_to_accumulator: int32") - .Output("node_map_updates: int32") - .Output("accumulators_cleared: int32") - .Output("accumulators_allocated: int32") - .Output("new_nonfertile_leaves: int32") - .Output("new_nonfertile_leaves_scores: float") - .Doc(R"doc( - Updates accumulator slots to reflect finished or newly fertile nodes. - - Leaves at the depth of the attribute `max_depth` won't be made fertile - (i.e., won't be given an accumulator slot.) - - finished:= A 1-d int32 tensor containing the indices of fertile nodes that - are ready to decide on a split. - non_fertile_leaves:= A 1-d int32 tensor containing the indices of all the - currently non-fertile leaves. If there are free accumulator slots after - deallocation, UpdateFertileSlots will consider these nodes (plus the ones - in new_leaves) and potentially turn some of them fertile. - non_fertile_leaf_scores: `non_fertile_leaf_scores[i]` is the splitting score - of the non-fertile leaf `non_fertile_leaves[i]`. - end_of_tree: The end of tree tensor from the previous training iteration, used - with the finished input to calculate a list of new leaf indices created by - GrowTree, which will be considered to become fertile if there are free - slots. - tree_depths: `tree_depths[i]` is the depth in the tree of node i. - accumulator_sums: For classification, `accumulator_sums[a][c]` records how - many training examples have class c and have ended up in the fertile node - associated with accumulator slot a. It has the total sum in entry 0 for - convenience. For regression, it is the same except it contains the sum - of the input labels that have been seen, and entry 0 contains the number - of training examples that have been seen. - node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by - fertile node i, or -1 if node i isn't fertile. - node_map_updates:= A 2-d int32 tensor describing the changes that need to - be applied to the node_to_accumulator map. Intended to be used with - `tf.scatter_update(node_to_accumulator, - node_map_updates[0], - node_map_updates[1])`. - accumulators_cleared:= A 1-d int32 tensor containing the indices of all - the accumulator slots that need to be cleared. - accumulators_allocated:= A 1-d int32 tensor containing the indices of all - the accumulator slots that need to be allocated. - new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the - leaves that are now non-fertile. - new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the - splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`. + .Attr("max_depth: int") + .Attr("regression: bool = False") + .Input("finished: int32") + .Input("non_fertile_leaves: int32") + .Input("non_fertile_leaf_scores: float") + .Input("end_of_tree: int32") + .Input("tree_depths: int32") + .Input("accumulator_sums: float") + .Input("node_to_accumulator: int32") + .Output("node_map_updates: int32") + .Output("accumulators_cleared: int32") + .Output("accumulators_allocated: int32") + .Output("new_nonfertile_leaves: int32") + .Output("new_nonfertile_leaves_scores: float") + .Doc(R"doc( +Updates accumulator slots to reflect finished or newly fertile nodes. + +Leaves at the depth of the attribute `max_depth` won't be made fertile +(i.e., won't be given an accumulator slot.) + +finished:= A 1-d int32 tensor containing the indices of fertile nodes that + are ready to decide on a split. +non_fertile_leaves:= A 1-d int32 tensor containing the indices of all the + currently non-fertile leaves. If there are free accumulator slots after + deallocation, UpdateFertileSlots will consider these nodes (plus the ones + in new_leaves) and potentially turn some of them fertile. +non_fertile_leaf_scores: `non_fertile_leaf_scores[i]` is the splitting score + of the non-fertile leaf `non_fertile_leaves[i]`. +end_of_tree: The end of tree tensor from the previous training iteration, used + with the finished input to calculate a list of new leaf indices created by + GrowTree, which will be considered to become fertile if there are free + slots. +tree_depths: `tree_depths[i]` is the depth in the tree of node i. +accumulator_sums: For classification, `accumulator_sums[a][c]` records how + many training examples have class c and have ended up in the fertile node + associated with accumulator slot a. It has the total sum in entry 0 for + convenience. For regression, it is the same except it contains the sum + of the input labels that have been seen, and entry 0 contains the number + of training examples that have been seen. +node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by + fertile node i, or -1 if node i isn't fertile. +node_map_updates:= A 2-d int32 tensor describing the changes that need to + be applied to the node_to_accumulator map. Intended to be used with + `tf.scatter_update(node_to_accumulator, + node_map_updates[0], + node_map_updates[1])`. +accumulators_cleared:= A 1-d int32 tensor containing the indices of all + the accumulator slots that need to be cleared. +accumulators_allocated:= A 1-d int32 tensor containing the indices of all + the accumulator slots that need to be allocated. +new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the + leaves that are now non-fertile. +new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the + splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`. )doc"); - class UpdateFertileSlots : public OpKernel { public: explicit UpdateFertileSlots(OpKernelConstruction* context) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 77db4004c7..7f4c1abc1f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1367,6 +1367,7 @@ tf_cc_tests( "framework/op_def_builder_test.cc", "framework/op_def_util_test.cc", "framework/op_kernel_test.cc", + "framework/op_registration_test.cc", "framework/partial_tensor_shape_test.cc", "framework/rendezvous_test.cc", "framework/resource_mgr_test.cc", diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index 6f66a21875..3f1b037b02 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <memory> +#include <unordered_set> #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -47,11 +48,32 @@ Status LoadLibrary(const char* library_filename, void** result, Env* env = Env::Default(); void* lib; OpList op_list; + std::unordered_set<string> seen_op_names; { mutex_lock lock(mu); + OpRegistry::Global()->ProcessRegistrations(); TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher( - [&op_list](const OpDef& opdef) { *op_list.add_op() = opdef; })); - TF_RETURN_IF_ERROR(env->LoadLibrary(library_filename, &lib)); + [&op_list, &seen_op_names](const Status& s, + const OpDef& opdef) -> Status { + if (errors::IsAlreadyExists(s)) { + if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { + // Over writing a registration of an op not in this custom op + // library. Treat this as not an error. + return Status::OK(); + } + } + *op_list.add_op() = opdef; + seen_op_names.insert(opdef.name()); + return s; + })); + OpRegistry::Global()->DeferRegistrations(); + Status s = env->LoadLibrary(library_filename, &lib); + if (!s.ok()) { + OpRegistry::Global()->ClearDeferredRegistrations(); + TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); + return s; + } + OpRegistry::Global()->ProcessRegistrations(); TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); } string str; diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 2ea735a790..41bb07581f 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -47,17 +47,12 @@ OpRegistry::~OpRegistry() { for (const auto& e : registry_) delete e.second; } -void OpRegistry::Register(std::unique_ptr<OpRegistrationData> op_reg_data) { - OpRegistrationData* raw_ptr = op_reg_data.get(); - +void OpRegistry::Register(OpRegistrationDataFactory op_data_factory) { mutex_lock lock(mu_); if (initialized_) { - TF_QCHECK_OK(RegisterAlreadyLocked(std::move(op_reg_data))); + TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); } else { - deferred_.push_back(std::move(op_reg_data)); - } - if (watcher_) { - watcher_(raw_ptr->op_def); + deferred_.push_back(op_data_factory); } } @@ -133,6 +128,21 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const { } } +void OpRegistry::DeferRegistrations() { + mutex_lock lock(mu_); + initialized_ = false; +} + +void OpRegistry::ClearDeferredRegistrations() { + mutex_lock lock(mu_); + deferred_.clear(); +} + +void OpRegistry::ProcessRegistrations() const { + mutex_lock lock(mu_); + CallDeferred(); +} + string OpRegistry::DebugString(bool include_internal) const { OpList op_list; Export(include_internal, &op_list); @@ -147,23 +157,34 @@ bool OpRegistry::CallDeferred() const { if (initialized_) return false; initialized_ = true; for (int i = 0; i < deferred_.size(); ++i) { - TF_QCHECK_OK(RegisterAlreadyLocked(std::move(deferred_[i]))); + TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i])); } deferred_.clear(); return true; } Status OpRegistry::RegisterAlreadyLocked( - std::unique_ptr<OpRegistrationData> op_reg_data) const { - TF_RETURN_IF_ERROR(ValidateOpDef(op_reg_data->op_def)); - - if (gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), - op_reg_data.get())) { - op_reg_data.release(); // Ownership transferred to op_registry - return Status::OK(); + OpRegistrationDataFactory op_data_factory) const { + std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData); + Status s = op_data_factory(op_reg_data.get()); + if (s.ok()) { + s = ValidateOpDef(op_reg_data->op_def); + if (s.ok() && + !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), + op_reg_data.get())) { + s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); + } + } + Status watcher_status = s; + if (watcher_) { + watcher_status = watcher_(s, op_reg_data->op_def); + } + if (s.ok()) { + op_reg_data.release(); } else { - return errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); + op_reg_data.reset(); } + return watcher_status; } // static @@ -202,9 +223,15 @@ Status OpListOpRegistry::LookUp(const string& op_type_name, namespace register_op { OpDefBuilderReceiver::OpDefBuilderReceiver( const OpDefBuilderWrapper<true>& wrapper) { - std::unique_ptr<OpRegistrationData> data(new OpRegistrationData); - wrapper.builder().Finalize(data.get()); - OpRegistry::Global()->Register(std::move(data)); + OpRegistry::Global()->Register( + [wrapper](OpRegistrationData* op_reg_data) -> Status { + wrapper.builder().Finalize(op_reg_data); + // TODO(keveman): Add this check back again in a separate CL. + // if (!s.ok()) { + // return s; + // } + return Status::OK(); + }); } } // namespace register_op diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index cd484dacd4..6cd76c26e2 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -54,22 +54,23 @@ class OpRegistryInterface { }; // The standard implementation of OpRegistryInterface, along with a -// global singleton used for registering OpDefs via the REGISTER +// global singleton used for registering ops via the REGISTER // macros below. Thread-safe. // // Example registration: -// OpRegistry::Global()->Register([]()->OpDef{ -// OpDef def; -// // Populate def here. -// return def; +// OpRegistry::Global()->Register( +// [](OpRegistrationData* op_reg_data)->Status { +// // Populate *op_reg_data here. +// return Status::OK(); // }); class OpRegistry : public OpRegistryInterface { public: + typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; + OpRegistry(); ~OpRegistry() override; - // Calls watcher and registers the passed OpDef. - void Register(std::unique_ptr<OpRegistrationData> op_data); + void Register(OpRegistrationDataFactory op_data_factory); Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const override; @@ -89,10 +90,12 @@ class OpRegistry : public OpRegistryInterface { void GetRegisteredOps(std::vector<OpDef>* op_defs); // Watcher, a function object. - // watcher_, if not null, is called every time an op is registered via the - // Register function. watcher_ is passed the OpDef of the op getting - // registered. - typedef std::function<void(const OpDef&)> Watcher; + // The watcher, if set by SetWatcher(), is called every time an op is + // registered via the Register function. The watcher is passed the Status + // obtained from building and adding the OpDef to the registry, and the OpDef + // itself if it was successfully built. A watcher returns a Status which is in + // turn returned as the final registration status. + typedef std::function<Status(const Status&, const OpDef&)> Watcher; // An OpRegistry object has only one watcher. This interface is not thread // safe, as different clients are free to set the watcher any time. @@ -100,11 +103,26 @@ class OpRegistry : public OpRegistryInterface { // operations : // SetWatcher(a_watcher); // Register some ops; + // op_registry->ProcessRegistrations(); // SetWatcher(nullptr); // Returns a non-OK status if a non-null watcher is over-written by another // non-null watcher. Status SetWatcher(const Watcher& watcher); + // Process the current list of deferred registrations. Note that calls to + // Export, LookUp and DebugString would also implicitly process the deferred + // registrations. + void ProcessRegistrations() const; + + // Defer the registrations until a later call to a function that processes + // deferred registrations are made. Normally, registrations that happen after + // calls to Export, LookUp, ProcessRegistrations and DebugString are processed + // immediately. Call this to defer future registrations. + void DeferRegistrations(); + + // Clear the registrations that have been deferred. + void ClearDeferredRegistrations(); + private: // Ensures that all the functions in deferred_ get called, their OpDef's // registered, and returns with deferred_ empty. Returns true the first @@ -114,13 +132,12 @@ class OpRegistry : public OpRegistryInterface { // Add 'def' to the registry with additional data 'data'. On failure, or if // there is already an OpDef with that name registered, returns a non-okay // status. - Status RegisterAlreadyLocked(std::unique_ptr<OpRegistrationData> op_data) - const EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status RegisterAlreadyLocked(OpRegistrationDataFactory op_data_factory) const + EXCLUSIVE_LOCKS_REQUIRED(mu_); mutable mutex mu_; // Functions in deferred_ may only be called with mu_ held. - mutable std::vector<std::unique_ptr<OpRegistrationData>> deferred_ - GUARDED_BY(mu_); + mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_); // Values are owned. mutable std::unordered_map<string, const OpRegistrationData*> registry_ GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/op_registration_test.cc b/tensorflow/core/framework/op_registration_test.cc new file mode 100644 index 0000000000..9ab90df422 --- /dev/null +++ b/tensorflow/core/framework/op_registration_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> + +#include "tensorflow/core/framework/op.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +void Register(const string& op_name, OpRegistry* registry) { + registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { + op_reg_data->op_def.set_name(op_name); + return Status::OK(); + }); +} + +} // namespace + +TEST(OpRegistrationTest, TestBasic) { + std::unique_ptr<OpRegistry> registry(new OpRegistry); + Register("Foo", registry.get()); + OpList op_list; + registry->Export(true, &op_list); + EXPECT_EQ(op_list.op().size(), 1); + EXPECT_EQ(op_list.op(0).name(), "Foo"); +} + +TEST(OpRegistrationTest, TestDuplicate) { + std::unique_ptr<OpRegistry> registry(new OpRegistry); + Register("Foo", registry.get()); + registry->ProcessRegistrations(); + + registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status { + EXPECT_TRUE(errors::IsAlreadyExists(s)); + return Status::OK(); + }); + Register("Foo", registry.get()); + registry->ProcessRegistrations(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index a0602d6c06..63da598276 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -983,11 +983,11 @@ REGISTER_OP("StridedSlice") .Output("output: T") .Attr("T: type") .Attr("Index: {int32, int64}") - .Attr("begin_mask: int32 = 0") - .Attr("end_mask: int32 = 0") - .Attr("ellipse_mask: int32 = 0") - .Attr("new_axis_mask: int32 = 0") - .Attr("shrink_axis_mask: int32 = 0") + .Attr("begin_mask: int = 0") + .Attr("end_mask: int = 0") + .Attr("ellipse_mask: int = 0") + .Attr("new_axis_mask: int = 0") + .Attr("shrink_axis_mask: int = 0") .Doc(R"doc( Return a strided slice from `input`. diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index 8c1112c16c..546cb36706 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -42,7 +42,11 @@ def load_op_library(library_filename): Pass "library_filename" to a platform-specific mechanism for dynamically loading a library. The rules for determining the exact location of the - library are platform-specific and are not documented here. + library are platform-specific and are not documented here. When the + library is loaded, ops and kernels registered in the library via the + REGISTER_* macros are made available in the TensorFlow process. Note + that ops with the same name as an existing op are rejected and not + registered with the process. Args: library_filename: Path to the plugin. diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD index 6d8773aebf..682df75a68 100644 --- a/tensorflow/user_ops/BUILD +++ b/tensorflow/user_ops/BUILD @@ -32,6 +32,18 @@ py_tests( data = [":ackermann_op.so"], ) +tf_custom_op_library( + name = "duplicate_op.so", + srcs = ["duplicate_op.cc"], +) + +py_tests( + name = "duplicate_op_test", + size = "small", + srcs = ["duplicate_op_test.py"], + data = [":duplicate_op.so"], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/user_ops/duplicate_op.cc b/tensorflow/user_ops/duplicate_op.cc new file mode 100644 index 0000000000..9f622e4db5 --- /dev/null +++ b/tensorflow/user_ops/duplicate_op.cc @@ -0,0 +1,26 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +REGISTER_OP("Add").Doc(R"doc( +An op to test that duplicate registrations don't override previously +registered ops. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/user_ops/duplicate_op_test.py b/tensorflow/user_ops/duplicate_op_test.py new file mode 100644 index 0000000000..b61e68d75e --- /dev/null +++ b/tensorflow/user_ops/duplicate_op_test.py @@ -0,0 +1,39 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for custom user ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import tensorflow as tf + + +class DuplicateOpTest(tf.test.TestCase): + + def testBasic(self): + library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + 'duplicate_op.so') + duplicate = tf.load_op_library(library_filename) + + self.assertEqual(len(duplicate.OP_LIST.op), 0) + + with self.test_session(): + self.assertEqual(tf.add(1, 41).eval(), 42) + + +if __name__ == '__main__': + tf.test.main() |