aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/shape_refiner.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/shape_refiner.h')
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.h8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index 2d04ea1505..9709bd0302 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -55,6 +55,11 @@ class ShapeRefiner {
Status SetShape(const Node* node, int output_port,
shape_inference::ShapeHandle shape);
+ // Update the input shapes of node in case the shapes of the fan-ins of 'node'
+ // have themselves been modified (For example, in case of incremental shape
+ // refinement). Sets refined to true if any of the node shape has changed.
+ Status UpdateNode(const Node* node, bool* refined);
+
// Returns the InferenceContext for 'node', if present.
shape_inference::InferenceContext* GetContext(const Node* node) const {
auto it = node_to_context_.find(node);
@@ -108,6 +113,9 @@ class ShapeRefiner {
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
+ Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
+ shape_inference::InferenceContext* c);
+
int32 graph_def_version_;
const OpRegistryInterface* const ops_registry_;