about summary refs log tree commit diff homepage
path: root/tensorflow/core/common_runtime/shape_refiner.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/shape_refiner.cc')
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner.cc  235
1 files changed, 159 insertions, 76 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 5135355a94..8eb383a14f 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -88,10 +88,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
}
// This needs to be filled in with real data in a second pass.
- std::vector<const Tensor*> input_tensors(node->num_inputs());
- std::vector<Tensor> real_tensors(node->num_inputs());
- std::vector<bool> attempted_materialization(node->num_inputs());
- std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
+ std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
@@ -104,78 +101,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
}
// Run the shape inference function, and return if there was an error.
- if (op_reg_data->shape_inference_fn) {
- TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
- } else {
- TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
- }
-
- // We must run the shape function repeatedly, in case users write
- // shape functions where they only conditionally call input_tensor()
- // based on the values of another input tensor.
- bool rerun_shape_fn;
- do {
- // If the result of running shape inference would have benefitted
- // from knowing the values of input tensors, try to materialize
- // the results of those tensors, and then run the shape inference
- // function again using those known tensors.
- rerun_shape_fn = false;
-
- // NOTE: It is possible to batch the extraction and
- // materialization of inputs, instead of materializing one input
- // at a time like we do below. If input-at-a-time computation
- // becomes a bottleneck, we could separate ExtractConstantSubgraph
- // into two functions: one that returns true if an input is
- // derivable from constants, and another function that extracts
- // the subgraph for multiple target nodes and executes the whole
- // subgraph once.
-
- for (int i = 0; i < c->num_inputs(); ++i) {
- if (!c->requested_input_tensor(i)) {
- continue;
- }
- // Check if we have not already filled in the requested input,
- // and if not, try to materialize the tensors.
- if (!attempted_materialization[i]) {
- attempted_materialization[i] = true;
-
- Tensor result;
- bool evaluated = false;
- TF_RETURN_IF_ERROR(
- EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
- if (evaluated) {
- real_tensors[i] = result;
- input_tensors[i] = &real_tensors[i];
- // We have more concrete information about a shape,
- // so re-run shape inference.
- rerun_shape_fn = true;
- }
- }
- if (c->requested_input_tensor_as_partial_shape(i) &&
- !attempted_tensor_as_shape_conversion[i]) {
- attempted_tensor_as_shape_conversion[i] = true;
- if (i >= input_tensors_as_shapes.size()) {
- input_tensors_as_shapes.resize(i + 1);
- }
- ShapeHandle s;
- TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s));
- input_tensors_as_shapes[i] = s;
- rerun_shape_fn = true;
- }
- }
-
- if (rerun_shape_fn) {
- // We have more information about the shapes on this pass,
- // so re-run shape inference.
- c->set_input_tensors(input_tensors);
- c->set_input_tensors_as_shapes(input_tensors_as_shapes);
- if (op_reg_data->shape_inference_fn) {
- TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get()));
- } else {
- TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c.get()));
- }
- }
- } while (rerun_shape_fn);
+ TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, c.get()));
// Store the resulting InferenceContext object in the map.
node_to_context_[node].swap(c);
@@ -211,6 +137,74 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
return Status::OK();
}
+Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
+ auto it = node_to_context_.find(node);
+ if (it == node_to_context_.end()) {
+ *refined = true;
+ return AddNode(node);
+ }
+ InferenceContext* node_context = it->second.get();
+
+ // Give up if the context wasn't successfully built by the AddNode() method.
+ TF_RETURN_IF_ERROR(node_context->construction_status());
+
+ // Check if the shapes of the nodes in the fan-in of this node have changed,
+ // and if they have update the node input shapes.
+ for (const Edge* e : node->in_edges()) {
+ if (e->IsControlEdge()) continue;
+
+ Node* input = e->src();
+ auto iter = node_to_context_.find(input);
+ if (iter == node_to_context_.end()) {
+ return errors::FailedPrecondition(
+ "Input ", e->dst_input(), " ('", input->name(), "') for '",
+ node->name(), "' was not previously added to ShapeRefiner.");
+ }
+
+ InferenceContext* c = iter->second.get();
+ DCHECK_GE(e->dst_input(), 0);
+ if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) {
+ *refined = true;
+ }
+
+ // Also propagate handle shape and dtype of edges which are carrying
+ // resource handles.
+ if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
+ if (node_context->set_input_handle_dtype(
+ e->dst_input(), c->output_handle_dtype(e->src_output()))) {
+ *refined = true;
+ }
+ if (node_context->set_input_handle_shape(
+ e->dst_input(), c->output_handle_shape(e->src_output()))) {
+ *refined = true;
+ }
+ }
+ }
+
+ if (!*refined) {
+ // No input shape has changed, we're done
+ return Status::OK();
+ }
+
+ // Get and run the shape function for this node to update the shapes of the
+ // outputs.
+ const OpRegistrationData* op_reg_data;
+ TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
+ if (op_reg_data->shape_inference_fn == nullptr &&
+ require_shape_inference_fns_) {
+ return errors::InvalidArgument(
+ "No shape inference function exists for op '", node->type_string(),
+ "', did you forget to define it?");
+ }
+
+ if (!op_reg_data->shape_inference_fn) {
+ // There is nothing more we can infer
+ return Status::OK();
+ }
+
+ return RunShapeFn(node, op_reg_data, node_context);
+}
+
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
int dst_idx, bool* evaluated,
Tensor* result) {
@@ -463,4 +457,93 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
return Status::OK();
}
+Status ShapeRefiner::RunShapeFn(const Node* node,
+ const OpRegistrationData* op_reg_data,
+ shape_inference::InferenceContext* c) {
+ // This will be filled in with real data in a second pass.
+ std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
+ std::vector<Tensor> real_tensors(node->num_inputs());
+ std::vector<bool> attempted_materialization(node->num_inputs());
+ std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
+ std::vector<ShapeHandle> input_tensors_as_shapes;
+
+ // Run the shape inference function, and return if there was an error.
+ c->set_input_tensors(input_tensors);
+ c->set_input_tensors_as_shapes(input_tensors_as_shapes);
+ if (op_reg_data->shape_inference_fn) {
+ TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
+ } else {
+ TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
+ }
+
+ // We must run the shape function repeatedly, in case users write
+ // shape functions where they only conditionally call input_tensor()
+ // based on the values of another input tensor.
+ bool rerun_shape_fn;
+ do {
+ // If the result of running shape inference would have benefitted
+ // from knowing the values of input tensors, try to materialize
+ // the results of those tensors, and then run the shape inference
+ // function again using those known tensors.
+ rerun_shape_fn = false;
+
+ // NOTE: It is possible to batch the extraction and
+ // materialization of inputs, instead of materializing one input
+ // at a time like we do below. If input-at-a-time computation
+ // becomes a bottleneck, we could separate ExtractConstantSubgraph
+ // into two functions: one that returns true if an input is
+ // derivable from constants, and another function that extracts
+ // the subgraph for multiple target nodes and executes the whole
+ // subgraph once.
+
+ for (int i = 0; i < c->num_inputs(); ++i) {
+ if (!c->requested_input_tensor(i)) {
+ continue;
+ }
+ // Check if we have not already filled in the requested input,
+ // and if not, try to materialize the tensors.
+ if (!attempted_materialization[i]) {
+ attempted_materialization[i] = true;
+
+ Tensor result;
+ bool evaluated = false;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
+ if (evaluated) {
+ real_tensors[i] = result;
+ input_tensors[i] = &real_tensors[i];
+ // We have more concrete information about a shape,
+ // so re-run shape inference.
+ rerun_shape_fn = true;
+ }
+ }
+ if (c->requested_input_tensor_as_partial_shape(i) &&
+ !attempted_tensor_as_shape_conversion[i]) {
+ attempted_tensor_as_shape_conversion[i] = true;
+ if (i >= input_tensors_as_shapes.size()) {
+ input_tensors_as_shapes.resize(i + 1);
+ }
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
+ input_tensors_as_shapes[i] = s;
+ rerun_shape_fn = true;
+ }
+ }
+
+ if (rerun_shape_fn) {
+ // We have more information about the shapes on this pass,
+ // so re-run shape inference.
+ c->set_input_tensors(input_tensors);
+ c->set_input_tensors_as_shapes(input_tensors_as_shapes);
+ if (op_reg_data->shape_inference_fn) {
+ TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c));
+ } else {
+ TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c));
+ }
+ }
+ } while (rerun_shape_fn);
+
+ return Status::OK();
+}
+
} // namespace tensorflow