aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/cc/client/client_session.cc18
-rw-r--r--tensorflow/cc/client/client_session.h28
-rw-r--r--tensorflow/cc/client/client_session_test.cc21
-rw-r--r--tensorflow/cc/framework/gradient_checker.cc2
-rw-r--r--tensorflow/cc/gradients/image_grad_test.cc14
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc6
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc85
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc22
-rw-r--r--tensorflow/cc/saved_model/loader.cc95
10 files changed, 229 insertions, 63 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index d686ccfe29..588a45ea43 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -349,6 +349,7 @@ tf_cc_test(
srcs = ["gradients/image_grad_test.cc"],
deps = [
":cc_ops",
+ ":client_session",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc
index ba056a8f3a..0e61089a59 100644
--- a/tensorflow/cc/client/client_session.cc
+++ b/tensorflow/cc/client/client_session.cc
@@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
+Status ClientSession::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
+ return impl()->session_->MakeCallable(callable_options, out_handle);
+}
+
+Status ClientSession::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status ClientSession::ReleaseCallable(CallableHandle handle) {
+ return impl()->session_->ReleaseCallable(handle);
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h
index 5fb4109f7d..7dd653eec4 100644
--- a/tensorflow/cc/client/client_session.h
+++ b/tensorflow/cc/client/client_session.h
@@ -87,7 +87,33 @@ class ClientSession {
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
- // TODO(keveman): Add support for partial run.
+ /// \brief A handle to a subgraph, created with
+ /// `ClientSession::MakeCallable()`.
+ typedef int64 CallableHandle;
+
+ /// \brief Creates a `handle` for invoking the subgraph defined by
+ /// `callable_options`.
+ /// NOTE: This API is still experimental and may change.
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle);
+
+ /// \brief Invokes the subgraph named by `handle` with the given options and
+ /// input tensors.
+ ///
+ /// The order of tensors in `feed_tensors` must match the order of names in
+ /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
+ /// match the order of names in `CallableOptions::fetch()` when this subgraph
+ /// was created.
+ /// NOTE: This API is still experimental and may change.
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata);
+
+ /// \brief Releases resources associated with the given `handle` in this
+ /// session.
+ /// NOTE: This API is still experimental and may change.
+ Status ReleaseCallable(CallableHandle handle);
private:
class Impl;
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
index ea5cf5a1f1..559ffea7e8 100644
--- a/tensorflow/cc/client/client_session_test.cc
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
+TEST(ClientSessionTest, Callable) {
+ Scope root = Scope::NewRootScope();
+ auto a = Placeholder(root, DT_INT32);
+ auto b = Placeholder(root, DT_INT32);
+ auto c = Add(root, a, b);
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ CallableOptions options;
+ options.add_feed(a.node()->name());
+ options.add_feed(b.node()->name());
+ options.add_fetch(c.node()->name());
+ ClientSession::CallableHandle callable;
+ TF_CHECK_OK(session.MakeCallable(options, &callable));
+ TF_EXPECT_OK(session.RunCallable(
+ callable, {test::AsTensor<int>({1}, {}), test::AsTensor<int>({41}, {})},
+ &outputs, nullptr));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {}));
+ TF_EXPECT_OK(session.ReleaseCallable(callable));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index 695180c23b..a1eb0d9d08 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
const int64 y_size = y_shapes[y_idx].num_elements();
- const Y_T scale = Y_T{2 * delta};
+ const Y_T scale = 2 * delta;
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
for (int c = 0; c < y_size; ++c) {
SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc
index b9271522ed..2e55c7561b 100644
--- a/tensorflow/cc/gradients/image_grad_test.cc
+++ b/tensorflow/cc/gradients/image_grad_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
-#include "tensorflow/cc/ops/image_ops_internal.h"
+#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,8 +27,8 @@ namespace tensorflow {
namespace {
using ops::Const;
-using ops::ResizeBilinear;
using ops::ResizeBicubic;
+using ops::ResizeBilinear;
using ops::ResizeNearestNeighbor;
class ImageGradTest : public ::testing::Test {
@@ -38,7 +38,7 @@ class ImageGradTest : public ::testing::Test {
enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC };
template <typename T>
- Tensor MakeData(TensorShape& data_shape) {
+ Tensor MakeData(const TensorShape& data_shape) {
DataType data_type = DataTypeToEnum<T>::v();
Tensor data(data_type, data_shape);
auto data_flat = data.flat<T>();
@@ -57,15 +57,15 @@ class ImageGradTest : public ::testing::Test {
*y = ResizeNearestNeighbor(
scope_, *x, y_shape,
ResizeNearestNeighbor::AlignCorners(align_corners));
- break;
+ return;
case RESIZE_BILINEAR:
*y = ResizeBilinear(scope_, *x, y_shape,
ResizeBilinear::AlignCorners(align_corners));
- break;
+ return;
case RESIZE_BICUBIC:
*y = ResizeBicubic(scope_, *x, y_shape,
ResizeBicubic::AlignCorners(align_corners));
- break;
+ return;
}
assert(false);
}
@@ -79,7 +79,7 @@ class ImageGradTest : public ::testing::Test {
ClientSession session(scope_);
std::vector<Tensor> outputs;
- TF_ASSERT_OK(session.Run({}, {y}, &outputs));
+ TF_ASSERT_OK(session.Run({y}, &outputs));
EXPECT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1}));
}
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index fd7b6fe662..1c9bdff5e1 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -475,11 +475,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
- // TODO(kbsriram)
- // Enable when tan kernel supports complex inputs
- if (false) {
- TestCWiseGrad<complex64, complex64>(TAN, x_fn);
- }
+ TestCWiseGrad<complex64, complex64>(TAN, x_fn);
}
TEST_F(CWiseUnaryGradTest, Atan) {
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index c73482d5f4..588e96cb19 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -47,6 +47,72 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
+bool IsZero(const Scope& scope, const Output& grad) {
+ string op_type_name = grad.op().node()->type_string();
+ if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
+ return true;
+ }
+ // The Operation we were provided is not named something obvious so
+ // we need to actually look at its contents.
+ // The original python code did this by calling a utility function called
+ // tensor_util.constant_value.
+ // There is no C++ equivalent to tensor_util.constant_value so we do nothing
+ // for the moment.
+ return false;
+}
+
+// Multiply after broadcasting vec to match dimensions of mat.
+// Args:
+// vec: A 1-D tensor of dimension [D0]
+// mat: A 2-D tensor of dimesnion [D0, D1]
+//
+// Returns:
+// A tensor of dimension [D0, D1], the result fo vec * mat.
+Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
+ auto reshaped = ExpandDims(scope, vec, -1);
+ return Multiply(scope, reshaped, mat);
+}
+
+Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
+ const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // Softmax gradient with cross entropy logits function.
+ // We multiply the backprop for cost with the gradients - op.output[1].
+ // There is no gradient for labels.
+
+ // The outputs of the network are at input index 0.
+ auto logits = op.input(0);
+ // The "truth" labels are at index 1.
+ auto softmax_grad = op.output(1);
+
+ // The loss is the output at index 0, and backprop is the output at index 1.
+ auto grad_loss = grad_inputs[0];
+ auto grad_grad = grad_inputs[1];
+
+ auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
+ if (!IsZero(scope, grad_grad)) {
+ std::vector<int> axis;
+ auto logits_softmax = Softmax(scope, logits);
+
+ auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
+ auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
+ auto matmul_result =
+ BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
+ axis.push_back(1);
+ auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
+ auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
+ auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
+ grad = Add(scope, grad, multiply_result);
+ }
+ auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
+ grad_outputs->push_back(grad);
+ grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
+ SoftmaxCrossEntropyWithLogitsGrad);
+
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
@@ -195,9 +261,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
MaxPool3DGrad::Attrs grad_attrs;
- auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx =
+ MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize,
+ strides, padding, grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -216,10 +282,9 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
internal::AvgPoolGrad::Attrs grad_attrs;
- auto dx =
- internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)),
+ grad_inputs[0], ksize, strides, padding,
+ grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -238,9 +303,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
AvgPool3DGrad::Attrs grad_attrs;
- auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx =
+ AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize,
+ strides, padding, grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index b4d457a9d1..aa72cf7ba2 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -25,6 +25,8 @@ limitations under the License.
namespace tensorflow {
namespace {
+using ops::AvgPool;
+using ops::AvgPool3D;
using ops::BiasAdd;
using ops::Conv2D;
using ops::Elu;
@@ -33,11 +35,9 @@ using ops::FractionalMaxPool;
using ops::L2Loss;
using ops::LogSoftmax;
using ops::LRN;
-using ops::AvgPool;
-using ops::AvgPool3D;
using ops::MaxPool;
-using ops::MaxPoolV2;
using ops::MaxPool3D;
+using ops::MaxPoolV2;
using ops::Placeholder;
using ops::Relu;
using ops::Relu6;
@@ -111,6 +111,20 @@ TEST_F(NNGradTest, SoftmaxGrad) {
RunTest(x, shape, y, shape);
}
+TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) {
+ TensorShape logits_shape({5, 3});
+ TensorShape loss_shape({5});
+
+ auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto y =
+ tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels);
+ // Note the reversal of the backprop and loss orders. Issue #18734 has been
+ // opened for this.
+ RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss},
+ {logits_shape, loss_shape});
+}
+
TEST_F(NNGradTest, LogSoftmaxGrad) {
TensorShape shape({5, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
@@ -253,7 +267,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
-TEST_F(NNGradTest, LRN){
+TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = LRN(scope_, x);
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index 07807ed2f3..98be66a6ad 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -74,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir,
}
}
+// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid
+// leaving behind non-GC'ed state.
+//
+// Detailed motivation behind this approach, from ashankar@:
+//
+// Each call to Session::Run() that identifies a new subgraph (based on feeds
+// and fetches) creates some datastructures that live as long as the session
+// (the partitioned graph, associated executors etc.).
+//
+// A pathological case of this would be if say the initialization op
+// (main_op/legacy_init_op) involves the use of a large constant. Then we
+// allocate memory for that large constant that will just stick around till the
+// session dies. With this Callable mechanism, that memory will be released
+// right after ReleaseCallable returns.
+//
+// However, the resource manager state remains.
+Status RunOnce(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata,
+ Session* session) {
+ CallableOptions callable_options;
+ std::vector<Tensor> feed_tensors;
+ *callable_options.mutable_run_options() = run_options;
+ for (const auto& input : inputs) {
+ const string& name = input.first;
+ const Tensor& tensor = input.second;
+ callable_options.add_feed(name);
+ feed_tensors.push_back(tensor);
+ }
+ for (const string& output_tensor_name : output_tensor_names) {
+ callable_options.add_fetch(output_tensor_name);
+ }
+ for (const string& target_node_name : target_node_names) {
+ callable_options.add_target(target_node_name);
+ }
+
+ Session::CallableHandle callable_handle;
+ TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle));
+ const Status run_status = session->RunCallable(callable_handle, feed_tensors,
+ outputs, run_metadata);
+ // Be sure to call ReleaseCallable() regardless of the outcome of
+ // RunCallable().
+ session->ReleaseCallable(callable_handle).IgnoreError();
+ return run_status;
+}
+
bool HasMainOp(const MetaGraphDef& meta_graph_def) {
const auto& collection_def_map = meta_graph_def.collection_def();
if (collection_def_map.find(kSavedModelMainOpKey) !=
@@ -86,10 +134,11 @@ bool HasMainOp(const MetaGraphDef& meta_graph_def) {
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
const MetaGraphDef& meta_graph_def,
const std::vector<AssetFileDef>& asset_file_defs,
- Session* session) {
- LOG(INFO) << "Running MainOp on SavedModel bundle.";
+ Session* session, const string& main_op_key) {
+ LOG(INFO) << "Running MainOp with key " << main_op_key
+ << " on SavedModel bundle.";
const auto& collection_def_map = meta_graph_def.collection_def();
- const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey);
+ const auto main_op_it = collection_def_map.find(main_op_key);
if (main_op_it != collection_def_map.end()) {
if (main_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
@@ -99,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir,
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
- return session->Run(run_options, inputs, {}, {main_op_name.ToString()},
- nullptr /* outputs */, &run_metadata);
+ return RunOnce(run_options, inputs, {}, {main_op_name.ToString()},
+ nullptr /* outputs */, &run_metadata, session);
}
return Status::OK();
}
@@ -137,32 +186,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
- return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
- nullptr /* outputs */, &run_metadata);
-}
-
-Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir,
- const MetaGraphDef& meta_graph_def,
- const std::vector<AssetFileDef>& asset_file_defs,
- Session* session) {
- LOG(INFO) << "Running LegacyInitOp on SavedModel bundle.";
- const auto& collection_def_map = meta_graph_def.collection_def();
- const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey);
- if (init_op_it != collection_def_map.end()) {
- if (init_op_it->second.node_list().value_size() != 1) {
- return errors::FailedPrecondition(strings::StrCat(
- "Expected exactly one serving init op in : ", export_dir));
- }
- std::vector<std::pair<string, Tensor>> inputs;
- AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
- RunMetadata run_metadata;
- const StringPiece legacy_init_op_name =
- init_op_it->second.node_list().value(0);
- return session->Run(run_options, inputs, {},
- {legacy_init_op_name.ToString()}, nullptr /* outputs */,
- &run_metadata);
- }
- return Status::OK();
+ return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()},
+ nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
@@ -204,11 +229,11 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
if (HasMainOp(bundle->meta_graph_def)) {
TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
bundle->meta_graph_def, asset_file_defs,
- bundle->session.get()));
+ bundle->session.get(), kSavedModelMainOpKey));
} else {
- TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
- bundle->meta_graph_def, asset_file_defs,
- bundle->session.get()));
+ TF_RETURN_IF_ERROR(RunMainOp(
+ run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
+ bundle->session.get(), kSavedModelLegacyInitOpKey));
}
return Status::OK();
}