blob: 4d464a76aaae4c14de1a64a91137b93ab4aa8d1f [file] [log] [blame]
// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/host/security_key/security_key_socket.h"
#include <utility>
#include "base/callback_helpers.h"
#include "base/stl_util.h"
#include "base/timer/timer.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
namespace remoting {
namespace {
const size_t kRequestSizeBytes = 4;
const size_t kMaxRequestLength = 16384;
const size_t kRequestReadBufferLength = kRequestSizeBytes + kMaxRequestLength;
// SSH Failure Code
const char kSshError[] = {0x05};
} // namespace
SecurityKeySocket::SecurityKeySocket(std::unique_ptr<net::StreamSocket> socket,
base::TimeDelta timeout,
const base::Closure& timeout_callback)
: socket_(std::move(socket)),
read_buffer_(base::MakeRefCounted<net::IOBufferWithSize>(
kRequestReadBufferLength)) {
timer_.reset(new base::OneShotTimer());
timer_->Start(FROM_HERE, timeout, timeout_callback);
}
SecurityKeySocket::~SecurityKeySocket() {
DCHECK(thread_checker_.CalledOnValidThread());
}
bool SecurityKeySocket::GetAndClearRequestData(std::string* data_out) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(!waiting_for_request_);
if (!IsRequestComplete() || IsRequestTooLarge()) {
return false;
}
// The request size is not part of the data; don't send it.
data_out->assign(request_data_.begin() + kRequestSizeBytes,
request_data_.end());
request_data_.clear();
return true;
}
void SecurityKeySocket::SendResponse(const std::string& response_data) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(!write_buffer_);
std::string response_length_string = GetResponseLengthAsBytes(response_data);
int response_len = response_length_string.size() + response_data.size();
std::unique_ptr<std::string> response(
new std::string(response_length_string + response_data));
write_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(
base::MakeRefCounted<net::StringIOBuffer>(std::move(response)),
response_len);
DCHECK(write_buffer_->BytesRemaining());
DoWrite();
}
void SecurityKeySocket::SendSshError() {
DCHECK(thread_checker_.CalledOnValidThread());
SendResponse(std::string(kSshError, base::size(kSshError)));
}
void SecurityKeySocket::StartReadingRequest(
const base::Closure& request_received_callback) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(request_received_callback_.is_null());
waiting_for_request_ = true;
request_received_callback_ = request_received_callback;
DoRead();
}
void SecurityKeySocket::OnDataWritten(int result) {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(write_buffer_);
if (result < 0) {
LOG(ERROR) << "Error sending response: " << result;
return;
}
ResetTimer();
write_buffer_->DidConsume(result);
if (!write_buffer_->BytesRemaining()) {
write_buffer_ = nullptr;
return;
}
DoWrite();
}
void SecurityKeySocket::DoWrite() {
DCHECK(thread_checker_.CalledOnValidThread());
DCHECK(write_buffer_);
net::NetworkTrafficAnnotationTag traffic_annotation =
net::DefineNetworkTrafficAnnotation("security_key_socket", R"(
semantics {
sender: "Security Key Socket"
description:
"This request performs the communication between processes when "
"handling security key (gnubby) authentication."
trigger:
"Performing an action (such as signing into a website with "
"two-factor authentication enabled) that requires a security key "
"touch."
data: "Security key protocol data."
destination: LOCAL
}
policy {
cookies_allowed: NO
setting: "This feature cannot be disabled in Settings."
chrome_policy {
RemoteAccessHostAllowGnubbyAuth {
RemoteAccessHostAllowGnubbyAuth: false
}
}
})");
int result = socket_->Write(
write_buffer_.get(), write_buffer_->BytesRemaining(),
base::Bind(&SecurityKeySocket::OnDataWritten, base::Unretained(this)),
traffic_annotation);
if (result != net::ERR_IO_PENDING) {
OnDataWritten(result);
}
}
void SecurityKeySocket::OnDataRead(int result) {
DCHECK(thread_checker_.CalledOnValidThread());
if (result <= 0) {
if (result < 0) {
LOG(ERROR) << "Error reading request: " << result;
socket_read_error_ = true;
}
waiting_for_request_ = false;
base::ResetAndReturn(&request_received_callback_).Run();
return;
}
ResetTimer();
// TODO(joedow): If there are multiple requests in a burst, it is possible
// that we could read too many bytes from the buffer (e.g. all of request #1
// and some of request #2). We should consider using the request header to
// determine the request length and only read that amount from buffer.
request_data_.insert(request_data_.end(), read_buffer_->data(),
read_buffer_->data() + result);
if (IsRequestComplete()) {
waiting_for_request_ = false;
base::ResetAndReturn(&request_received_callback_).Run();
return;
}
DoRead();
}
void SecurityKeySocket::DoRead() {
DCHECK(thread_checker_.CalledOnValidThread());
int result = socket_->Read(
read_buffer_.get(), kRequestReadBufferLength,
base::Bind(&SecurityKeySocket::OnDataRead, base::Unretained(this)));
if (result != net::ERR_IO_PENDING) {
OnDataRead(result);
}
}
bool SecurityKeySocket::IsRequestComplete() const {
DCHECK(thread_checker_.CalledOnValidThread());
if (request_data_.size() < kRequestSizeBytes) {
return false;
}
return GetRequestLength() <= request_data_.size();
}
bool SecurityKeySocket::IsRequestTooLarge() const {
DCHECK(thread_checker_.CalledOnValidThread());
if (request_data_.size() < kRequestSizeBytes) {
return false;
}
return GetRequestLength() > kMaxRequestLength;
}
size_t SecurityKeySocket::GetRequestLength() const {
DCHECK(request_data_.size() >= kRequestSizeBytes);
return ((request_data_[0] & 255) << 24) + ((request_data_[1] & 255) << 16) +
((request_data_[2] & 255) << 8) + (request_data_[3] & 255) +
kRequestSizeBytes;
}
std::string SecurityKeySocket::GetResponseLengthAsBytes(
const std::string& response) const {
std::string response_len;
response_len.reserve(kRequestSizeBytes);
int len = response.size();
response_len.push_back((len >> 24) & 255);
response_len.push_back((len >> 16) & 255);
response_len.push_back((len >> 8) & 255);
response_len.push_back(len & 255);
return response_len;
}
void SecurityKeySocket::ResetTimer() {
if (timer_->IsRunning()) {
timer_->Reset();
}
}
} // namespace remoting