blob: 964443a9cdb039ce2a06a20c3507256503f55eaa [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_basic_stream_adapters.h"
#include <utility>
#include "base/memory/scoped_refptr.h"
#include "base/strings/string_piece.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/mock_host_resolver.h"
#include "net/http/http_network_session.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_manager_impl.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_client_socket_pool.h"
#include "net/socket/transport_client_socket_pool.h"
#include "net/ssl/ssl_config.h"
#include "net/test/gtest_util.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
namespace test {
// Socket pool group name used by every handle these tests create; it names
// the same host and port as the fixture's |host_port_pair_|.
constexpr char kGroupName[] = "ssl/www.example.org:443";
// Fixture that wires a MockClientSocketFactory into a ClientSocketPoolManagerImpl
// so that tests can initialize SSL ClientSocketHandles from mock data providers
// and then wrap them in a WebSocketClientSocketHandleAdapter.
class WebSocketClientSocketHandleAdapterTest : public testing::Test {
 protected:
  WebSocketClientSocketHandleAdapterTest()
      : host_port_pair_("www.example.org", 443),
        socket_pool_manager_(std::make_unique<ClientSocketPoolManagerImpl>(
            &net_log_,
            &socket_factory_,
            nullptr,
            nullptr,
            &host_resolver_,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            nullptr,
            "test_shard",
            nullptr,
            HttpNetworkSession::NORMAL_SOCKET_POOL)),
        transport_params_(base::MakeRefCounted<TransportSocketParams>(
            host_port_pair_,
            false,
            OnHostResolutionCallback(),
            TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)),
        ssl_params_(base::MakeRefCounted<SSLSocketParams>(transport_params_,
                                                          nullptr,
                                                          nullptr,
                                                          host_port_pair_,
                                                          SSLConfig(),
                                                          PRIVACY_MODE_DISABLED,
                                                          0,
                                                          false)) {}

  // Initializes |connection| from the SSL socket pool using |ssl_params_| and
  // group |kGroupName|. callback.GetResult() resolves an asynchronous Init()
  // to its final result. Returns true iff that result is OK.
  bool InitClientSocketHandle(ClientSocketHandle* connection) {
    TestCompletionCallback callback;
    const int rv = callback.GetResult(connection->Init(
        kGroupName, ssl_params_, MEDIUM, SocketTag(),
        ClientSocketPool::RespectLimits::ENABLED, callback.callback(),
        socket_pool_manager_->GetSSLSocketPool(), NetLogWithSource()));
    return rv == OK;
  }

  // Declaration order matters; do not reorder these members.
  // - ssl_params_ is built from transport_params_, so it is declared after it.
  // - net_log_, socket_factory_ and host_resolver_ are handed to
  //   socket_pool_manager_ by address. Declaring them first means they are
  //   destroyed after the pool manager.
  const HostPortPair host_port_pair_;
  NetLog net_log_;
  MockClientSocketFactory socket_factory_;
  MockHostResolver host_resolver_;
  std::unique_ptr<ClientSocketPoolManagerImpl> socket_pool_manager_;
  scoped_refptr<TransportSocketParams> transport_params_;
  scoped_refptr<SSLSocketParams> ssl_params_;
};
// An adapter wrapping a handle that was never initialized reports itself as
// uninitialized.
TEST_F(WebSocketClientSocketHandleAdapterTest, Uninitialized) {
  WebSocketClientSocketHandleAdapter adapter(
      std::make_unique<ClientSocketHandle>());
  EXPECT_FALSE(adapter.is_initialized());
}
// The adapter reflects initialization that happens to the handle after the
// adapter has already taken ownership of it.
TEST_F(WebSocketClientSocketHandleAdapterTest, IsInitialized) {
  StaticSocketDataProvider socket_data(nullptr, 0, nullptr, 0);
  socket_factory_.AddSocketDataProvider(&socket_data);
  SSLSocketDataProvider ssl_data(ASYNC, OK);
  socket_factory_.AddSSLSocketDataProvider(&ssl_data);

  auto owned_handle = std::make_unique<ClientSocketHandle>();
  // Keep a non-owning pointer so the handle can still be initialized once
  // ownership has moved into the adapter.
  ClientSocketHandle* const handle = owned_handle.get();
  WebSocketClientSocketHandleAdapter adapter(std::move(owned_handle));

  EXPECT_FALSE(adapter.is_initialized());
  EXPECT_TRUE(InitClientSocketHandle(handle));
  EXPECT_TRUE(adapter.is_initialized());
}
// Disconnect() on the adapter disconnects the underlying socket.
TEST_F(WebSocketClientSocketHandleAdapterTest, Disconnect) {
  StaticSocketDataProvider data(nullptr, 0, nullptr, 0);
  socket_factory_.AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Use fatal assertions. If Init() failed, |socket| could be null, and
  // dereferencing it below would crash the test binary instead of failing
  // this test.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));
  StreamSocket* const socket = connection->socket();
  ASSERT_TRUE(socket);

  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());
  EXPECT_TRUE(socket->IsConnected());

  adapter.Disconnect();
  EXPECT_FALSE(socket->IsConnected());
}
// Read() is forwarded to the underlying socket, both for a read that
// completes synchronously and for one that completes asynchronously.
TEST_F(WebSocketClientSocketHandleAdapterTest, Read) {
  MockRead reads[] = {MockRead(SYNCHRONOUS, "foo"), MockRead("bar")};
  StaticSocketDataProvider data(reads, arraysize(reads), nullptr, 0);
  socket_factory_.AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Fatal: doing I/O through an adapter whose handle failed to initialize
  // would presumably touch a null socket and crash rather than fail.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));
  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  // Buffer larger than each MockRead.
  const int kReadBufSize = 1024;
  auto read_buf = base::MakeRefCounted<IOBuffer>(kReadBufSize);

  // The first MockRead is SYNCHRONOUS, so no completion callback is needed.
  int rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionCallback());
  ASSERT_EQ(3, rv);
  EXPECT_EQ("foo", base::StringPiece(read_buf->data(), rv));

  // The second MockRead is asynchronous.
  TestCompletionCallback callback;
  rv = adapter.Read(read_buf.get(), kReadBufSize, callback.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  rv = callback.WaitForResult();
  ASSERT_EQ(3, rv);
  EXPECT_EQ("bar", base::StringPiece(read_buf->data(), rv));

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
// When the caller's buffer is smaller than a MockRead, the data is delivered
// across several Read() calls. After the first chunk of an asynchronous
// MockRead, the remainder is returned synchronously.
TEST_F(WebSocketClientSocketHandleAdapterTest, ReadIntoSmallBuffer) {
  MockRead reads[] = {MockRead(SYNCHRONOUS, "foo"), MockRead("bar")};
  StaticSocketDataProvider data(reads, arraysize(reads), nullptr, 0);
  socket_factory_.AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Fatal: doing I/O through an adapter whose handle failed to initialize
  // would presumably touch a null socket and crash rather than fail.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));
  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  // Buffer smaller than each MockRead.
  const int kReadBufSize = 2;
  auto read_buf = base::MakeRefCounted<IOBuffer>(kReadBufSize);

  int rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionCallback());
  ASSERT_EQ(2, rv);
  EXPECT_EQ("fo", base::StringPiece(read_buf->data(), rv));

  rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionCallback());
  ASSERT_EQ(1, rv);
  EXPECT_EQ("o", base::StringPiece(read_buf->data(), rv));

  TestCompletionCallback callback1;
  rv = adapter.Read(read_buf.get(), kReadBufSize, callback1.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  rv = callback1.WaitForResult();
  ASSERT_EQ(2, rv);
  EXPECT_EQ("ba", base::StringPiece(read_buf->data(), rv));

  rv = adapter.Read(read_buf.get(), kReadBufSize, CompletionCallback());
  ASSERT_EQ(1, rv);
  EXPECT_EQ("r", base::StringPiece(read_buf->data(), rv));

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
// Write() is forwarded to the underlying socket, both for a write that
// completes synchronously and for one that completes asynchronously.
TEST_F(WebSocketClientSocketHandleAdapterTest, Write) {
  MockWrite writes[] = {MockWrite(SYNCHRONOUS, "foo"), MockWrite("bar")};
  StaticSocketDataProvider data(nullptr, 0, writes, arraysize(writes));
  socket_factory_.AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Fatal: doing I/O through an adapter whose handle failed to initialize
  // would presumably touch a null socket and crash rather than fail.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));
  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  // The first MockWrite is SYNCHRONOUS, so no completion callback is needed.
  auto write_buf1 = base::MakeRefCounted<StringIOBuffer>("foo");
  int rv = adapter.Write(write_buf1.get(), write_buf1->size(),
                         CompletionCallback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  ASSERT_EQ(3, rv);

  // The second MockWrite is asynchronous.
  auto write_buf2 = base::MakeRefCounted<StringIOBuffer>("bar");
  TestCompletionCallback callback;
  rv = adapter.Write(write_buf2.get(), write_buf2->size(), callback.callback(),
                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  rv = callback.WaitForResult();
  ASSERT_EQ(3, rv);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
// Test that if both Read() and Write() return asynchronously,
// the two callbacks are handled correctly.
TEST_F(WebSocketClientSocketHandleAdapterTest, AsyncReadAndWrite) {
  MockRead reads[] = {MockRead("foobar")};
  MockWrite writes[] = {MockWrite("baz")};
  StaticSocketDataProvider data(reads, arraysize(reads), writes,
                                arraysize(writes));
  socket_factory_.AddSocketDataProvider(&data);
  SSLSocketDataProvider ssl_socket_data(ASYNC, OK);
  socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data);

  auto connection = std::make_unique<ClientSocketHandle>();
  // Fatal: doing I/O through an adapter whose handle failed to initialize
  // would presumably touch a null socket and crash rather than fail.
  ASSERT_TRUE(InitClientSocketHandle(connection.get()));
  WebSocketClientSocketHandleAdapter adapter(std::move(connection));
  EXPECT_TRUE(adapter.is_initialized());

  // Start a read and a write; both are expected to be pending at once.
  const int kReadBufSize = 1024;
  auto read_buf = base::MakeRefCounted<IOBuffer>(kReadBufSize);
  TestCompletionCallback read_callback;
  int rv = adapter.Read(read_buf.get(), kReadBufSize, read_callback.callback());
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  auto write_buf = base::MakeRefCounted<StringIOBuffer>("baz");
  TestCompletionCallback write_callback;
  rv = adapter.Write(write_buf.get(), write_buf->size(),
                     write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  // Each callback receives its own operation's result.
  rv = read_callback.WaitForResult();
  ASSERT_EQ(6, rv);
  EXPECT_EQ("foobar", base::StringPiece(read_buf->data(), rv));

  rv = write_callback.WaitForResult();
  ASSERT_EQ(3, rv);

  EXPECT_TRUE(data.AllReadDataConsumed());
  EXPECT_TRUE(data.AllWriteDataConsumed());
}
} // namespace test
} // namespace net