diff options
author | 2017-03-13 12:49:15 -0800 | |
---|---|---|
committer | 2017-03-13 14:04:46 -0700 | |
commit | 7b8e31c58140fe6c6bdd3a0d946b978c2a216702 (patch) | |
tree | a395da22f5c9cee3f184cb5b7180620d5b62fb84 /tensorflow/core/kernels/remote_fused_graph_execute_utils.h | |
parent | 6121fe5be59bb13ca71fd6992239b4f99e43a2be (diff) |
Factor out shape inference propagation to RemoteFusedGraphExecuteUtils
Change: 149984977
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_utils.h')
-rw-r--r-- | tensorflow/core/kernels/remote_fused_graph_execute_utils.h | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index 21484a3a74..48a128ec07 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" @@ -89,6 +91,15 @@ class RemoteFusedGraphExecuteUtils { const std::vector<TensorShape>& shapes, NodeDef* node_def); + static Status PropagateShapeInference( + const GraphDef& graph_def, + const std::vector<std::pair<string, Tensor>>& input_node_info_list, + Graph* graph, ShapeRefiner* shape_refiner); + + static Status BuildTensorShapeMapFromGraph(const Graph& graph, + const ShapeRefiner& shape_refiner, + TensorShapeMap* tensor_shape_map); + private: static ExecutorBuildRegistry* GetExecutorBuildRegistry(); |