blob: afa07a25cdaec4ce471013b0ea292d3e999edc1f [file] [log] [blame]
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/devtools_bridge/socket_tunnel_server.h"
#include "base/bind.h"
#include "base/location.h"
#include "components/devtools_bridge/abstract_data_channel.h"
#include "components/devtools_bridge/session_dependency_factory.h"
#include "components/devtools_bridge/socket_tunnel_connection.h"
#include "components/devtools_bridge/socket_tunnel_packet_handler.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/unix_domain_client_socket_posix.h"
namespace devtools_bridge {
// One tunneled connection: bridges a single client-side connection index to a
// local Unix domain socket named |socket_name|. Data read from the socket and
// control notifications (SERVER_OPEN_ACK / SERVER_CLOSE) are forwarded to the
// client through |delegate_|.
class SocketTunnelServer::Connection : public SocketTunnelConnection {
 public:
  // Implemented by the owner of the connection table.
  class Delegate {
   public:
    // Drops the connection at |index| from the owner's table. Implementations
    // must not delete the calling Connection synchronously: the caller keeps
    // touching its own members after this returns.
    virtual void RemoveConnection(int index) = 0;
    // Sends |length| bytes at |data| to the remote client.
    virtual void SendPacket(
        const void* data, size_t length) = 0;

   protected:
    virtual ~Delegate() {}
  };

  // NOTE(review): the second socket argument presumably selects the abstract
  // socket namespace — confirm against net::UnixDomainClientSocket.
  Connection(Delegate* delegate, int index, const std::string& socket_name)
      : SocketTunnelConnection(index),
        delegate_(delegate),
        socket_(socket_name, true) {
  }

  // Starts connecting to the local socket. The result, whether synchronous or
  // asynchronous, is handled by OnConnectionComplete().
  void Connect() {
    int result = socket()->Connect(base::Bind(
        &Connection::OnConnectionComplete, base::Unretained(this)));
    if (result != net::ERR_IO_PENDING)
      OnConnectionComplete(result);
  }

  // Handles CLIENT_CLOSE. Sends SERVER_CLOSE only if the local socket was
  // connected, then removes this connection from the owner's table.
  void ClosedByClient() {
    const bool was_connected = socket()->IsConnected();
    // Disconnect unconditionally. If Connect() is still pending, this cancels
    // it without running the callback. Otherwise a late OnConnectionComplete()
    // could run on this already-removed connection before its deferred
    // deletion: it would send a spurious SERVER_OPEN_ACK, or call
    // RemoveConnection(index_) a second time and evict a newer connection
    // that reuses the same index. Disconnect() is a no-op when already
    // disconnected.
    socket()->Disconnect();
    if (was_connected)
      SendControlPacket(SERVER_CLOSE);
    delegate_->RemoveConnection(index_);
    delegate_ = NULL;
  }

 protected:
  net::StreamSocket* socket() override {
    return &socket_;
  }

  // Forwards a data packet read from the local socket to the client and
  // issues the next read.
  void OnDataPacketRead(const void* data, size_t length) override {
    delegate_->SendPacket(data, length);
    ReadNextChunk();
  }

  // A read failed: close the local socket, tell the client, and remove this
  // connection.
  void OnReadError(int error) override {
    socket()->Disconnect();
    SendControlPacket(SERVER_CLOSE);
    delegate_->RemoveConnection(index_);
    delegate_ = NULL;
  }

 private:
  // On success, acknowledges the open and starts reading. On failure, reports
  // SERVER_CLOSE and removes this connection.
  void OnConnectionComplete(int result) {
    if (result == net::OK) {
      SendControlPacket(SERVER_OPEN_ACK);
      ReadNextChunk();
    } else {
      SendControlPacket(SERVER_CLOSE);
      delegate_->RemoveConnection(index_);
      delegate_ = NULL;
    }
  }

  // Builds a control packet carrying |op_code| and sends it to the client.
  void SendControlPacket(ServerOpCode op_code) {
    char buffer[kControlPacketSizeBytes];
    BuildControlPacket(buffer, op_code);
    delegate_->SendPacket(buffer, kControlPacketSizeBytes);
  }

  // Not owned. Reset to NULL once this connection has removed itself.
  Delegate* delegate_;
  net::UnixDomainClientSocket socket_;
};
/**
 * Owns the tunnel's Connection objects, indexed by the client-chosen
 * connection index, and routes decoded control/data packets to them.
 * Outgoing packets are sent through |data_channel_|.
 *
 * Lives on the IO thread.
 */
class SocketTunnelServer::ConnectionController
    : private Connection::Delegate {
 public:
  ConnectionController(
      scoped_refptr<base::TaskRunner> io_task_runner,
      scoped_refptr<AbstractDataChannel::Proxy> data_channel,
      const std::string& socket_name)
      : io_task_runner_(io_task_runner),
        data_channel_(data_channel),
        socket_name_(socket_name) {
    DCHECK(data_channel_.get());
  }

  // Handles a control packet from the client. |connection_index| originates
  // in a decoded client message and may not be validated upstream, so it is
  // range-checked at runtime rather than only DCHECKed. An out-of-range value
  // would otherwise index past |connections_| in release builds. Invalid
  // indices and unknown op codes are treated as protocol errors.
  void HandleControlPacket(int connection_index, int op_code) {
    if (!IsValidIndex(connection_index)) {
      DLOG(ERROR) << "Invalid connection index: " << connection_index;
      HandleProtocolError();
      return;
    }
    switch (op_code) {
      case SocketTunnelConnection::CLIENT_OPEN:
        if (connections_[connection_index].get() != NULL) {
          DLOG(ERROR) << "Opening connection which already open: "
                      << connection_index;
          HandleProtocolError();
          return;
        }
        connections_[connection_index].reset(
            new Connection(this, connection_index, socket_name_));
        connections_[connection_index]->Connect();
        break;
      case SocketTunnelConnection::CLIENT_CLOSE:
        if (connections_[connection_index].get() == NULL) {
          // Ignore. Client may close the connection before receiving
          // notification from the server.
          return;
        }
        connections_[connection_index]->ClosedByClient();
        break;
      default:
        DLOG(ERROR) << "Invalid op_code: " << op_code;
        HandleProtocolError();
        return;
    }
  }

  // Writes |packet| to the connection at |connection_index|. Packets for an
  // index with no open connection (e.g. one already removed) are dropped.
  // An out-of-range index is a protocol error.
  void HandleDataPacket(int connection_index,
                        scoped_refptr<net::IOBufferWithSize> packet) {
    if (!IsValidIndex(connection_index)) {
      DLOG(ERROR) << "Invalid connection index: " << connection_index;
      HandleProtocolError();
      return;
    }
    Connection* connection = connections_[connection_index].get();
    if (connection != NULL)
      connection->Write(packet);
  }

  // Closes the data channel.
  void HandleProtocolError() {
    data_channel_->Close();
  }

  // Destroys every connection immediately.
  void CloseAllConnections() {
    for (int i = 0; i < kMaxConnectionCount; i++) {
      connections_[i].reset();
    }
  }

 private:
  // True if |connection_index| addresses a slot in |connections_|.
  static bool IsValidIndex(int connection_index) {
    return connection_index >= 0 && connection_index < kMaxConnectionCount;
  }

  // Target of the deferred-deletion task posted by RemoveConnection().
  // base::Owned() deletes the connection together with the bound callback.
  static void DeleteConnectionImpl(Connection*) {}

  // Connection::Delegate implementation
  void RemoveConnection(int connection_index) override {
    DCHECK(IsValidIndex(connection_index));
    // Remove immediately, delete later: the caller is a method of the
    // connection being removed and still uses its members after we return.
    Connection* connection = connections_[connection_index].release();
    io_task_runner_->PostTask(
        FROM_HERE, base::Bind(&ConnectionController::DeleteConnectionImpl,
                              base::Owned(connection)));
  }

  void SendPacket(const void* data, size_t length) override {
    data_channel_->SendBinaryMessage(data, length);
  }

  static const int kMaxConnectionCount =
      SocketTunnelConnection::kMaxConnectionCount;

  scoped_refptr<base::TaskRunner> io_task_runner_;
  scoped_refptr<AbstractDataChannel::Proxy> data_channel_;
  scoped_ptr<Connection> connections_[kMaxConnectionCount];
  const std::string socket_name_;
};
// Observes the data channel and decodes incoming tunnel packets. Every call
// into the ConnectionController is posted to |io_task_runner_| rather than
// made directly. This object owns the controller.
class SocketTunnelServer::DataChannelObserver
    : public AbstractDataChannel::Observer,
      private SocketTunnelPacketHandler {
 public:
  DataChannelObserver(scoped_refptr<base::TaskRunner> io_task_runner,
                      scoped_ptr<ConnectionController> controller)
      : io_task_runner_(io_task_runner), controller_(controller.Pass()) {}

  ~DataChannelObserver() override {
    // Hand the controller to a task on the IO runner instead of deleting it
    // here. Assuming the runner executes tasks in posting order, every earlier
    // task bound with base::Unretained(controller_.get()) runs before it.
    io_task_runner_->PostTask(
        FROM_HERE,
        base::Bind(&DataChannelObserver::DeleteControllerOnIOThread,
                   base::Passed(&controller_)));
  }

  // AbstractDataChannel::Observer implementation.
  void OnOpen() override {
    // Intentionally empty: a tunnel connection only ever starts in response
    // to a control packet from the client.
  }

  void OnClose() override {
    ConnectionController* const target = controller_.get();
    io_task_runner_->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionController::CloseAllConnections,
                   base::Unretained(target)));
  }

  void OnMessage(const void* data, size_t length) override {
    // Presumably dispatches to the Handle*() overrides below.
    DecodePacket(data, length);
  }

 private:
  // Receives ownership via base::Passed(); |controller| is destroyed when
  // this (empty) function returns.
  static void DeleteControllerOnIOThread(
      scoped_ptr<ConnectionController> controller) {}

  // SocketTunnelPacketHandler implementation.
  void HandleControlPacket(int index, int op_code) override {
    ConnectionController* const target = controller_.get();
    io_task_runner_->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionController::HandleControlPacket,
                   base::Unretained(target), index, op_code));
  }

  void HandleDataPacket(int index,
                        scoped_refptr<net::IOBufferWithSize> packet) override {
    ConnectionController* const target = controller_.get();
    io_task_runner_->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionController::HandleDataPacket,
                   base::Unretained(target), index, packet));
  }

  void HandleProtocolError() override {
    ConnectionController* const target = controller_.get();
    io_task_runner_->PostTask(
        FROM_HERE,
        base::Bind(&ConnectionController::HandleProtocolError,
                   base::Unretained(target)));
  }

  const scoped_refptr<base::TaskRunner> io_task_runner_;
  scoped_ptr<ConnectionController> controller_;
};
// Builds a ConnectionController for |socket_name|, wraps it in a
// DataChannelObserver, and registers that observer on |data_channel|.
// Both objects are given the factory's IO-thread task runner.
SocketTunnelServer::SocketTunnelServer(SessionDependencyFactory* factory,
                                       AbstractDataChannel* data_channel,
                                       const std::string& socket_name)
    : data_channel_(data_channel) {
  scoped_ptr<ConnectionController> connection_controller(
      new ConnectionController(factory->io_thread_task_runner(),
                               data_channel->proxy(), socket_name));
  scoped_ptr<DataChannelObserver> observer(new DataChannelObserver(
      factory->io_thread_task_runner(), connection_controller.Pass()));
  data_channel_->RegisterObserver(observer.Pass());
}
// Unregisters the observer that the constructor registered on
// |data_channel_|.
SocketTunnelServer::~SocketTunnelServer() {
data_channel_->UnregisterObserver();
}
} // namespace devtools_bridge