#region Copyright notice and license
// Copyright 2015 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Internal;
using Grpc.Core.Utils;
using NUnit.Framework;
namespace Grpc.Core.Tests
{
/// <summary>
/// Tests for response headers support.
/// </summary>
public class ResponseHeadersTest
{
    MockServiceHelper helper;
    Server server;
    Channel channel;
    Metadata headers;

    /// <summary>
    /// Creates a fresh helper, starts its server, obtains a channel and prepares
    /// the single-entry metadata that the server handlers send as response headers.
    /// </summary>
    [SetUp]
    public void Init()
    {
        helper = new MockServiceHelper();
        server = helper.GetServer();
        server.Start();
        channel = helper.GetChannel();
        headers = new Metadata { { "ascii-header", "abcdefg" } };
    }

    /// <summary>
    /// Shuts down the channel and then the server created by <see cref="Init"/>.
    /// </summary>
    [TearDown]
    public void Cleanup()
    {
        try
        {
            channel.ShutdownAsync().Wait();
        }
        finally
        {
            // Shut the server down even when the channel shutdown throws,
            // so a failing teardown does not leave a running server behind.
            server.ShutdownAsync().Wait();
        }
    }

    /// <summary>
    /// Response headers written by a unary handler are observable on the client
    /// through <c>ResponseHeadersAsync</c>, and the call still completes normally.
    /// </summary>
    [Test]
    public async Task ResponseHeadersAsync_UnaryCall()
    {
        helper.UnaryHandler = async (request, context) =>
        {
            await context.WriteResponseHeadersAsync(headers);
            return "PASS";
        };

        var call = Calls.AsyncUnaryCall(helper.CreateUnaryCall(), "");
        var responseHeaders = await call.ResponseHeadersAsync;

        AssertResponseHeaders(responseHeaders);
        Assert.AreEqual("PASS", await call.ResponseAsync);
    }

    /// <summary>
    /// Response headers written by a client-streaming handler are observable on the client.
    /// </summary>
    [Test]
    public async Task ResponseHeadersAsync_ClientStreamingCall()
    {
        helper.ClientStreamingHandler = async (requestStream, context) =>
        {
            await context.WriteResponseHeadersAsync(headers);
            return "PASS";
        };

        var call = Calls.AsyncClientStreamingCall(helper.CreateClientStreamingCall());
        await call.RequestStream.CompleteAsync();
        var responseHeaders = await call.ResponseHeadersAsync;

        AssertResponseHeaders(responseHeaders);
        Assert.AreEqual("PASS", await call.ResponseAsync);
    }

    /// <summary>
    /// Response headers written by a server-streaming handler are observable on the client
    /// before the response stream is consumed.
    /// </summary>
    [Test]
    public async Task ResponseHeadersAsync_ServerStreamingCall()
    {
        helper.ServerStreamingHandler = async (request, responseStream, context) =>
        {
            await context.WriteResponseHeadersAsync(headers);
            await responseStream.WriteAsync("PASS");
        };

        var call = Calls.AsyncServerStreamingCall(helper.CreateServerStreamingCall(), "");
        var responseHeaders = await call.ResponseHeadersAsync;

        AssertResponseHeaders(responseHeaders);
        CollectionAssert.AreEqual(new[] { "PASS" }, await call.ResponseStream.ToListAsync());
    }

    /// <summary>
    /// Response headers written by a duplex-streaming handler are observable on the client
    /// before any request message has been sent; the handler then echoes every request.
    /// </summary>
    [Test]
    public async Task ResponseHeadersAsync_DuplexStreamingCall()
    {
        helper.DuplexStreamingHandler = async (requestStream, responseStream, context) =>
        {
            await context.WriteResponseHeadersAsync(headers);
            while (await requestStream.MoveNext())
            {
                await responseStream.WriteAsync(requestStream.Current);
            }
        };

        var call = Calls.AsyncDuplexStreamingCall(helper.CreateDuplexStreamingCall());
        // Headers are awaited before the first request is written.
        var responseHeaders = await call.ResponseHeadersAsync;

        var messages = new[] { "PASS" };
        await call.RequestStream.WriteAllAsync(messages);

        AssertResponseHeaders(responseHeaders);
        CollectionAssert.AreEqual(messages, await call.ResponseStream.ToListAsync());
    }

    /// <summary>
    /// Passing <c>null</c> to <c>WriteResponseHeadersAsync</c> must throw
    /// <see cref="ArgumentNullException"/>.
    /// </summary>
    [Test]
    public void WriteResponseHeaders_NullNotAllowed()
    {
        helper.UnaryHandler = (request, context) =>
        {
            // If this assertion fails, the handler throws instead of returning "PASS",
            // which the client-side assertion below then detects.
            Assert.ThrowsAsync<ArgumentNullException>(async () => await context.WriteResponseHeadersAsync(null));
            return Task.FromResult("PASS");
        };

        Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), ""));
    }

    /// <summary>
    /// A second call to <c>WriteResponseHeadersAsync</c> within the same call must throw
    /// <see cref="InvalidOperationException"/>.
    /// </summary>
    [Test]
    public void WriteResponseHeaders_AllowedOnlyOnce()
    {
        helper.UnaryHandler = async (request, context) =>
        {
            await context.WriteResponseHeadersAsync(headers);
            try
            {
                await context.WriteResponseHeadersAsync(headers);
                // Reached only when the second write did not throw; the resulting failure is
                // not caught below, so the handler never returns "PASS".
                Assert.Fail("Writing response headers a second time should throw InvalidOperationException.");
            }
            catch (InvalidOperationException)
            {
                // Expected: headers may be written only once.
            }
            return "PASS";
        };

        Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), ""));
    }

    /// <summary>
    /// Calling <c>WriteResponseHeadersAsync</c> after a response message has been written
    /// must throw <see cref="InvalidOperationException"/>, and the stream stays usable afterwards.
    /// </summary>
    [Test]
    public async Task WriteResponseHeaders_NotAllowedAfterWrite()
    {
        helper.ServerStreamingHandler = async (request, responseStream, context) =>
        {
            await responseStream.WriteAsync("A");
            try
            {
                await context.WriteResponseHeadersAsync(headers);
                // Reached only when the late header write did not throw.
                Assert.Fail("Writing response headers after a response message should throw InvalidOperationException.");
            }
            catch (InvalidOperationException)
            {
                // Expected: headers cannot be written once a message has been sent.
            }
            await responseStream.WriteAsync("B");
        };

        var call = Calls.AsyncServerStreamingCall(helper.CreateServerStreamingCall(), "");
        var responses = await call.ResponseStream.ToListAsync();

        CollectionAssert.AreEqual(new[] { "A", "B" }, responses);
    }

    /// <summary>
    /// Asserts that the headers received by the client match the metadata prepared in
    /// <see cref="Init"/>: same number of entries, and the first entry carries the expected
    /// key and value.
    /// </summary>
    /// <param name="responseHeaders">Headers obtained from the call's <c>ResponseHeadersAsync</c>.</param>
    private void AssertResponseHeaders(Metadata responseHeaders)
    {
        Assert.AreEqual(headers.Count, responseHeaders.Count);
        Assert.AreEqual("ascii-header", responseHeaders[0].Key);
        Assert.AreEqual("abcdefg", responseHeaders[0].Value);
    }
}
}