aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-13 12:49:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-13 14:04:46 -0700
commit7b8e31c58140fe6c6bdd3a0d946b978c2a216702 (patch)
treea395da22f5c9cee3f184cb5b7180620d5b62fb84 /tensorflow/core/kernels/remote_fused_graph_execute_utils.h
parent6121fe5be59bb13ca71fd6992239b4f99e43a2be (diff)
Factor out shape inference propagation to RemoteFusedGraphExecuteUtils
Change: 149984977
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_utils.h')
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
index 21484a3a74..48a128ec07 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
@@ -89,6 +91,15 @@ class RemoteFusedGraphExecuteUtils {
const std::vector<TensorShape>& shapes,
NodeDef* node_def);
+ static Status PropagateShapeInference(
+ const GraphDef& graph_def,
+ const std::vector<std::pair<string, Tensor>>& input_node_info_list,
+ Graph* graph, ShapeRefiner* shape_refiner);
+
+ static Status BuildTensorShapeMapFromGraph(const Graph& graph,
+ const ShapeRefiner& shape_refiner,
+ TensorShapeMap* tensor_shape_map);
+
private:
static ExecutorBuildRegistry* GetExecutorBuildRegistry();