diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_proto_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util.cc | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index b9c0b0c4ee..026a0e8fba 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include <string> @@ -36,6 +37,17 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, + HloModule::CreateFromProto(proto, module_config)); + TF_RETURN_IF_ERROR( + HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status()); + return std::move(module); +} + StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { |