diff options
Diffstat (limited to 'tensorflow/core/common_runtime/shape_refiner.cc')
-rw-r--r-- | tensorflow/core/common_runtime/shape_refiner.cc | 235 |
1 files changed, 159 insertions, 76 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 5135355a94..8eb383a14f 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -88,10 +88,7 @@ Status ShapeRefiner::AddNode(const Node* node) { } // This needs to be filled in with real data in a second pass. - std::vector<const Tensor*> input_tensors(node->num_inputs()); - std::vector<Tensor> real_tensors(node->num_inputs()); - std::vector<bool> attempted_materialization(node->num_inputs()); - std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs()); + std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr); std::vector<ShapeHandle> input_tensors_as_shapes; // Create the inference context for this node with the existing input shapes. @@ -104,78 +101,7 @@ Status ShapeRefiner::AddNode(const Node* node) { } // Run the shape inference function, and return if there was an error. - if (op_reg_data->shape_inference_fn) { - TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); - } else { - TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); - } - - // We must run the shape function repeatedly, in case users write - // shape functions where they only conditionally call input_tensor() - // based on the values of another input tensor. - bool rerun_shape_fn; - do { - // If the result of running shape inference would have benefitted - // from knowing the values of input tensors, try to materialize - // the results of those tensors, and then run the shape inference - // function again using those known tensors. - rerun_shape_fn = false; - - // NOTE: It is possible to batch the extraction and - // materialization of inputs, instead of materializing one input - // at a time like we do below. If input-at-a-time computation - // becomes a bottleneck, we could separate ExtractConstantSubgraph - // into two functions: one that returns true if an input is - // derivable from constants, and another function that extracts - // the subgraph for multiple target nodes and executes the whole - // subgraph once. - - for (int i = 0; i < c->num_inputs(); ++i) { - if (!c->requested_input_tensor(i)) { - continue; - } - // Check if we have not already filled in the requested input, - // and if not, try to materialize the tensors. - if (!attempted_materialization[i]) { - attempted_materialization[i] = true; - - Tensor result; - bool evaluated = false; - TF_RETURN_IF_ERROR( - EvaluateConstantTensorForEdge(node, i, &evaluated, &result)); - if (evaluated) { - real_tensors[i] = result; - input_tensors[i] = &real_tensors[i]; - // We have more concrete information about a shape, - // so re-run shape inference. - rerun_shape_fn = true; - } - } - if (c->requested_input_tensor_as_partial_shape(i) && - !attempted_tensor_as_shape_conversion[i]) { - attempted_tensor_as_shape_conversion[i] = true; - if (i >= input_tensors_as_shapes.size()) { - input_tensors_as_shapes.resize(i + 1); - } - ShapeHandle s; - TF_RETURN_IF_ERROR(ConstantPartialShape(c.get(), node, i, &s)); - input_tensors_as_shapes[i] = s; - rerun_shape_fn = true; - } - } - - if (rerun_shape_fn) { - // We have more information about the shapes on this pass, - // so re-run shape inference. - c->set_input_tensors(input_tensors); - c->set_input_tensors_as_shapes(input_tensors_as_shapes); - if (op_reg_data->shape_inference_fn) { - TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get())); - } else { - TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c.get())); - } - } - } while (rerun_shape_fn); + TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, c.get())); // Store the resulting InferenceContext object in the map. node_to_context_[node].swap(c); @@ -211,6 +137,74 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port, return Status::OK(); } +Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) { + auto it = node_to_context_.find(node); + if (it == node_to_context_.end()) { + *refined = true; + return AddNode(node); + } + InferenceContext* node_context = it->second.get(); + + // Give up if the context wasn't successfully built by the AddNode() method. + TF_RETURN_IF_ERROR(node_context->construction_status()); + + // Check if the shapes of the nodes in the fan-in of this node have changed, + // and if they have update the node input shapes. + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) continue; + + Node* input = e->src(); + auto iter = node_to_context_.find(input); + if (iter == node_to_context_.end()) { + return errors::FailedPrecondition( + "Input ", e->dst_input(), " ('", input->name(), "') for '", + node->name(), "' was not previously added to ShapeRefiner."); + } + + InferenceContext* c = iter->second.get(); + DCHECK_GE(e->dst_input(), 0); + if (node_context->set_input(e->dst_input(), c->output(e->src_output()))) { + *refined = true; + } + + // Also propagate handle shape and dtype of edges which are carrying + // resource handles. + if (e->src()->output_type(e->src_output()) == DT_RESOURCE) { + if (node_context->set_input_handle_dtype( + e->dst_input(), c->output_handle_dtype(e->src_output()))) { + *refined = true; + } + if (node_context->set_input_handle_shape( + e->dst_input(), c->output_handle_shape(e->src_output()))) { + *refined = true; + } + } + } + + if (!*refined) { + // No input shape has changed, we're done + return Status::OK(); + } + + // Get and run the shape function for this node to update the shapes of the + // outputs. + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data)); + if (op_reg_data->shape_inference_fn == nullptr && + require_shape_inference_fns_) { + return errors::InvalidArgument( + "No shape inference function exists for op '", node->type_string(), + "', did you forget to define it?"); + } + + if (!op_reg_data->shape_inference_fn) { + // There is nothing more we can infer + return Status::OK(); + } + + return RunShapeFn(node, op_reg_data, node_context); +} + Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, int dst_idx, bool* evaluated, Tensor* result) { @@ -463,4 +457,93 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, return Status::OK(); } +Status ShapeRefiner::RunShapeFn(const Node* node, + const OpRegistrationData* op_reg_data, + shape_inference::InferenceContext* c) { + // This will be filled in with real data in a second pass. + std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr); + std::vector<Tensor> real_tensors(node->num_inputs()); + std::vector<bool> attempted_materialization(node->num_inputs()); + std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs()); + std::vector<ShapeHandle> input_tensors_as_shapes; + + // Run the shape inference function, and return if there was an error. + c->set_input_tensors(input_tensors); + c->set_input_tensors_as_shapes(input_tensors_as_shapes); + if (op_reg_data->shape_inference_fn) { + TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); + } else { + TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); + } + + // We must run the shape function repeatedly, in case users write + // shape functions where they only conditionally call input_tensor() + // based on the values of another input tensor. + bool rerun_shape_fn; + do { + // If the result of running shape inference would have benefitted + // from knowing the values of input tensors, try to materialize + // the results of those tensors, and then run the shape inference + // function again using those known tensors. + rerun_shape_fn = false; + + // NOTE: It is possible to batch the extraction and + // materialization of inputs, instead of materializing one input + // at a time like we do below. If input-at-a-time computation + // becomes a bottleneck, we could separate ExtractConstantSubgraph + // into two functions: one that returns true if an input is + // derivable from constants, and another function that extracts + // the subgraph for multiple target nodes and executes the whole + // subgraph once. + + for (int i = 0; i < c->num_inputs(); ++i) { + if (!c->requested_input_tensor(i)) { + continue; + } + // Check if we have not already filled in the requested input, + // and if not, try to materialize the tensors. + if (!attempted_materialization[i]) { + attempted_materialization[i] = true; + + Tensor result; + bool evaluated = false; + TF_RETURN_IF_ERROR( + EvaluateConstantTensorForEdge(node, i, &evaluated, &result)); + if (evaluated) { + real_tensors[i] = result; + input_tensors[i] = &real_tensors[i]; + // We have more concrete information about a shape, + // so re-run shape inference. + rerun_shape_fn = true; + } + } + if (c->requested_input_tensor_as_partial_shape(i) && + !attempted_tensor_as_shape_conversion[i]) { + attempted_tensor_as_shape_conversion[i] = true; + if (i >= input_tensors_as_shapes.size()) { + input_tensors_as_shapes.resize(i + 1); + } + ShapeHandle s; + TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s)); + input_tensors_as_shapes[i] = s; + rerun_shape_fn = true; + } + } + + if (rerun_shape_fn) { + // We have more information about the shapes on this pass, + // so re-run shape inference. + c->set_input_tensors(input_tensors); + c->set_input_tensors_as_shapes(input_tensors_as_shapes); + if (op_reg_data->shape_inference_fn) { + TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c)); + } else { + TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c)); + } + } + } while (rerun_shape_fn); + + return Status::OK(); +} + } // namespace tensorflow |