diff options
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass.h')
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass.h | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e9acbfb19e..f1137af3c1 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; - // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass - // unconditionally, call RunImpl() directly. - // is_compilable_fn, if set, is a predicate that must be true for a node to - // be compiled. + private: Status RunImpl(const GraphOptimizationPassOptions& options, const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn = {}); + + friend class MarkForCompilationPassTestHelper; }; // Returns true iff 'ndef' is a call to a function that is compilable. A // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ |