diff options
Diffstat (limited to 'tensorflow/core/common_runtime/shape_refiner.h')
-rw-r--r-- | tensorflow/core/common_runtime/shape_refiner.h | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 2d04ea1505..9709bd0302 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -55,6 +55,11 @@ class ShapeRefiner { Status SetShape(const Node* node, int output_port, shape_inference::ShapeHandle shape); + // Update the input shapes of node in case the shapes of the fan-ins of 'node' + // have themselves been modified (For example, in case of incremental shape + // refinement). Sets refined to true if any of the node shape has changed. + Status UpdateNode(const Node* node, bool* refined); + // Returns the InferenceContext for 'node', if present. shape_inference::InferenceContext* GetContext(const Node* node) const { auto it = node_to_context_.find(node); @@ -108,6 +113,9 @@ class ShapeRefiner { const Node* node, int dst_idx, shape_inference::ShapeHandle* result); + Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, + shape_inference::InferenceContext* c); + int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; |