aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_proto_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_proto_util.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_proto_util.cc12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index b9c0b0c4ee..026a0e8fba 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include <string>
@@ -36,6 +37,17 @@ HloProto MakeHloProto(const HloModule& module) {
return proto;
}
+StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+ const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(proto, module_config));
+ TF_RETURN_IF_ERROR(
+ HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status());
+ return std::move(module);
+}
+
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
const HloProto& hlo_proto) {
if (!hlo_proto.has_hlo_module()) {