aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-18 15:20:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 15:23:57 -0700
commitd5f4d9bbac520ad9eae6614fe678e9d1568435a4 (patch)
tree4d0e5115006e12127650b4f97d9b34942938abf3 /tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
parentd8672f1839df67f765baaa34a4c806ee1d433842 (diff)
Add a way to fuse a graph by remote graph executor so that users don't need to be aware of supported op types, node names, subgraph stracture etc.
PiperOrigin-RevId: 162411763
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h')
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
index a0df50162b..3fa052108e 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -59,6 +60,30 @@ class RemoteFusedGraphExecuteOpTestUtils {
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
};
+class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
+ public:
+ TestRemoteFusedGraphExecutor(const std::unordered_set<string>& fused_op_types,
+ const string& executor_name);
+
+ int GetVersion() final;
+ bool Init(const RemoteFusedGraphExecuteInfo&) final;
+ bool Finalize() final;
+ bool SetupGraph() final;
+ bool ExecuteGraph() final;
+ bool TeardownGraph() final;
+ bool FillInputNode(const string&, const Tensor&) final;
+ bool ReadOutputNode(const string&, TensorAllocatorFunc) final;
+ Status FuseRemoteGraph(const GraphDef& original_graph_def,
+ const std::vector<string>& inputs,
+ const std::vector<string>& outputs,
+ GraphDef* fused_graph_def) final;
+ bool IsEnabled() const final;
+
+ private:
+ const std::unordered_set<string> fused_op_types_;
+ const string executor_name_;
+};
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_