aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc40
1 files changed, 22 insertions, 18 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index adcdefbaf3..7094d5d32d 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1474,19 +1474,17 @@ class CApiGradientsTest : public ::testing::Test {
TF_DeleteStatus(s_);
}
- void TestGradientsSuccess(bool grad_inputs_provided,
- const char* prefix = nullptr) {
+ void TestGradientsSuccess(bool grad_inputs_provided) {
TF_Output inputs[2];
TF_Output outputs[1];
TF_Output grad_outputs[2];
TF_Output expected_grad_outputs[2];
BuildSuccessGraph(inputs, outputs);
- BuildExpectedGraph(grad_inputs_provided, prefix, expected_grad_outputs);
+ BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
- AddGradients(grad_inputs_provided, prefix, inputs, 2, outputs, 1,
+ AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
grad_outputs);
-
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Compare that the graphs match.
@@ -1604,7 +1602,6 @@ class CApiGradientsTest : public ::testing::Test {
}
void BuildExpectedGraph(bool grad_inputs_provided,
- const char* grad_prefix,
TF_Output* expected_grad_outputs) {
// The expected graph looks like this if grad_inputs_provided.
// If grad_inputs_provided is false, Const_0 will be a OnesLike op.
@@ -1633,10 +1630,6 @@ class CApiGradientsTest : public ::testing::Test {
//
const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
- const char* prefix = grad_prefix;
- if (prefix == nullptr) {
- prefix = "gradients";
- }
TF_Operation* const0 =
FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
TF_Operation* const1 =
@@ -1649,14 +1642,13 @@ class CApiGradientsTest : public ::testing::Test {
const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
} else {
- const3 = OnesLike(expected_graph_, s_, matmul,
- strings::StrCat(prefix, "/OnesLike").c_str());
+ const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
}
TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
- strings::StrCat(prefix, "/MatMul").c_str(), false, true);
+ "gradients/MatMul", false, true);
TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
- strings::StrCat(prefix, "/MatMul_1").c_str(), true, false);
+ "gradients/MatMul_1", true, false);
expected_grad_outputs[0] = {matmul1, 0};
expected_grad_outputs[1] = {matmul2, 0};
}
@@ -1727,10 +1719,6 @@ TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
TestGradientsSuccess(false);
}
-TEST_F(CApiGradientsTest, Gradients_NoGradInputsWithScopeName) {
- TestGradientsSuccess(false, "gradscope");
-}
-
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
TestGradientsError(true);
}
@@ -1739,6 +1727,22 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
+TEST_F(CApiGradientsTest, Gradients_WithPrefix) {
+ TF_Output inputs[2];
+ TF_Output outputs[1];
+ TF_Output grad_outputs[2];
+
+ BuildSuccessGraph(inputs, outputs);
+ AddGradients(false, "mygrads", inputs, 2, outputs, 1, grad_outputs);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ AddGradients(false, "mygrads_1", inputs, 2, outputs, 1, grad_outputs);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ AddGradients(false, "mygrads_1", inputs, 2, outputs, 1, grad_outputs);
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
ASSERT_TRUE(t != nullptr);
ASSERT_EQ(TF_FLOAT, TF_TensorType(t));