blob: 27e1f7b22335c7ae1b39bce9f4646408f106a9af [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "u2f_hid_device.h"
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/command_line.h"
#include "base/threading/thread_task_runner_handle.h"
#include "crypto/random.h"
#include "device/base/device_client.h"
#include "device/hid/hid_connection.h"
#include "u2f_apdu_command.h"
#include "u2f_message.h"
namespace device {
namespace switches {
// Command-line switch that opts in to the U2F HID tests; checked by
// U2fHidDevice::IsTestEnabled(). A namespace-scope constexpr object already
// has internal linkage, so no `static` is needed.
constexpr char kEnableU2fHidTest[] = "enable-u2f-hid-tests";
}  // namespace switches
// Creates a U2F wrapper around the HID device described by |device_info|.
// The device starts in State::INIT. The HID connection is only opened once the
// first transaction reaches Transition().
U2fHidDevice::U2fHidDevice(scoped_refptr<HidDeviceInfo> device_info)
    : U2fDevice(),
      state_(State::INIT),
      // |device_info| is a by-value sink parameter: move it rather than copy it
      // to avoid a redundant refcount increment/decrement pair.
      device_info_(std::move(device_info)),
      weak_factory_(this) {}
U2fHidDevice::~U2fHidDevice() {
  // Nothing to release if no connection was ever opened or it is already
  // closed; otherwise close it explicitly.
  if (!connection_ || connection_->closed())
    return;
  connection_->Close();
}
// Entry point for a single APDU exchange. Transition() decides, based on the
// current state, whether to start, queue, or immediately fail |command|.
void U2fHidDevice::DeviceTransact(std::unique_ptr<U2fApduCommand> command,
                                  const DeviceCallback& callback) {
  std::unique_ptr<U2fApduCommand> request = std::move(command);
  Transition(std::move(request), callback);
}
// Drives the per-device state machine for one transaction:
//   INIT         -> open the HID connection; OnConnect() re-enters here.
//   CONNECTED    -> allocate a channel; OnAllocateChannel() re-enters here.
//   IDLE         -> write |command| as a CMD_MSG on the allocated channel; the
//                   reply is delivered to MessageReceived().
//   BUSY         -> queue |command| until the in-flight transaction finishes.
//   DEVICE_ERROR -> fail |callback| and every queued transaction.
// Each path that starts device I/O first marks the device BUSY and arms a
// timeout for |callback|.
void U2fHidDevice::Transition(std::unique_ptr<U2fApduCommand> command,
                              const DeviceCallback& callback) {
  switch (state_) {
    case State::INIT:
      state_ = State::BUSY;
      ArmTimeout(callback);
      // |command| travels inside the bound callback until the connection is
      // established.
      Connect(base::Bind(&U2fHidDevice::OnConnect, weak_factory_.GetWeakPtr(),
                         base::Passed(&command), callback));
      break;
    case State::CONNECTED:
      state_ = State::BUSY;
      ArmTimeout(callback);
      AllocateChannel(std::move(command), callback);
      break;
    case State::IDLE: {
      state_ = State::BUSY;
      std::unique_ptr<U2fMessage> msg = U2fMessage::Create(
          channel_id_, U2fMessage::Type::CMD_MSG, command->GetEncodedCommand());
      ArmTimeout(callback);
      // Write message to the device
      WriteMessage(std::move(msg), true,
                   base::Bind(&U2fHidDevice::MessageReceived,
                              weak_factory_.GetWeakPtr(), callback));
      break;
    }
    case State::BUSY:
      // Another transaction is in flight; MessageReceived() drains this queue.
      pending_transactions_.push_back({std::move(command), callback});
      break;
    case State::DEVICE_ERROR:
    default:
      // |command| (null from internal error paths) is dropped; the caller and
      // all queued requests are failed.
      base::WeakPtr<U2fHidDevice> self = weak_factory_.GetWeakPtr();
      callback.Run(false, nullptr);
      // Executing callbacks may free |this|. Check |self| first.
      while (self && !pending_transactions_.empty()) {
        // Respond to any pending requests
        DeviceCallback pending_cb = pending_transactions_.front().second;
        pending_transactions_.pop_front();
        pending_cb.Run(false, nullptr);
      }
      break;
  }
}
// Asks the process-wide HidService to open this device. |callback| receives
// the resulting connection, or null on failure (see OnConnect()).
void U2fHidDevice::Connect(const HidService::ConnectCallback& callback) {
  DeviceClient::Get()->GetHidService()->Connect(device_info_->device_id(),
                                                callback);
}
// Completion handler for Connect(). It stores the connection (or records the
// failure) and re-enters the state machine with the original request.
void U2fHidDevice::OnConnect(std::unique_ptr<U2fApduCommand> command,
                             const DeviceCallback& callback,
                             scoped_refptr<HidConnection> connection) {
  // The request was already failed (e.g. by OnTimeout()); ignore late replies.
  if (state_ == State::DEVICE_ERROR)
    return;
  timeout_callback_.Cancel();
  if (!connection) {
    state_ = State::DEVICE_ERROR;
  } else {
    connection_ = connection;
    state_ = State::CONNECTED;
  }
  // CONNECTED -> channel allocation; DEVICE_ERROR -> fail the request.
  Transition(std::move(command), callback);
}
// Sends a CMD_INIT request carrying a fresh random 8-byte nonce. The reply is
// handled by OnAllocateChannel(), which checks that the nonce was echoed back
// before accepting the assigned channel.
void U2fHidDevice::AllocateChannel(std::unique_ptr<U2fApduCommand> command,
                                   const DeviceCallback& callback) {
  std::vector<uint8_t> nonce(8);
  crypto::RandBytes(nonce.data(), nonce.size());
  std::unique_ptr<U2fMessage> init_request =
      U2fMessage::Create(channel_id_, U2fMessage::Type::CMD_INIT, nonce);
  auto on_reply =
      base::Bind(&U2fHidDevice::OnAllocateChannel, weak_factory_.GetWeakPtr(),
                 nonce, base::Passed(&command), callback);
  WriteMessage(std::move(init_request), true, std::move(on_reply));
}
// Completion handler for the CMD_INIT exchange started by AllocateChannel().
// It validates the reply's length and echoed nonce, records the assigned
// channel id and capability byte, and then re-enters the state machine
// (IDLE on success, DEVICE_ERROR on any validation failure).
void U2fHidDevice::OnAllocateChannel(std::vector<uint8_t> nonce,
                                     std::unique_ptr<U2fApduCommand> command,
                                     const DeviceCallback& callback,
                                     bool success,
                                     std::unique_ptr<U2fMessage> message) {
  // The request was already failed (e.g. by OnTimeout()); ignore late replies.
  if (state_ == State::DEVICE_ERROR)
    return;
  timeout_callback_.Cancel();
  if (!success || !message) {
    state_ = State::DEVICE_ERROR;
    Transition(nullptr, callback);
    return;
  }
  // Channel allocation response is defined as:
  // 0: 8 byte nonce
  // 8: 4 byte channel id (big-endian)
  // 12: Protocol version id
  // 13: Major device version
  // 14: Minor device version
  // 15: Build device version
  // 16: Capabilities
  static constexpr size_t kNonceSize = 8;
  static constexpr size_t kChannelIdOffset = 8;
  static constexpr size_t kCapabilitiesOffset = 16;
  static constexpr size_t kResponseSize = 17;

  std::vector<uint8_t> payload = message->GetMessagePayload();
  if (payload.size() != kResponseSize) {
    state_ = State::DEVICE_ERROR;
    Transition(nullptr, callback);
    return;
  }
  // The device must echo our nonce; otherwise the reply answers someone else's
  // CMD_INIT.
  std::vector<uint8_t> received_nonce(payload.begin(),
                                      payload.begin() + kNonceSize);
  if (nonce != received_nonce) {
    state_ = State::DEVICE_ERROR;
    Transition(nullptr, callback);
    return;
  }
  // Widen each byte to uint32_t before shifting. A plain uint8_t is promoted to
  // (signed) int, and shifting a byte >= 0x80 left by 24 does not fit in int.
  channel_id_ =
      (static_cast<uint32_t>(payload[kChannelIdOffset]) << 24) |
      (static_cast<uint32_t>(payload[kChannelIdOffset + 1]) << 16) |
      (static_cast<uint32_t>(payload[kChannelIdOffset + 2]) << 8) |
      static_cast<uint32_t>(payload[kChannelIdOffset + 3]);
  capabilities_ = payload[kCapabilitiesOffset];
  state_ = State::IDLE;
  Transition(std::move(command), callback);
}
// Writes |message| to the device one HID packet at a time. Each packet's
// completion is handled by PacketWritten(), which either writes the next
// packet or, once all packets are sent, reads the reply (if
// |response_expected|) or reports the write result to |callback|. Fails
// immediately when there is no connection or nothing to send.
void U2fHidDevice::WriteMessage(std::unique_ptr<U2fMessage> message,
                                bool response_expected,
                                U2fHidMessageCallback callback) {
  if (!connection_ || !message || message->NumPackets() == 0) {
    std::move(callback).Run(false, nullptr);
    return;
  }
  scoped_refptr<net::IOBufferWithSize> buffer = message->PopNextPacket();
  // Forward the caller's |response_expected| instead of a hard-coded `true`.
  // Otherwise a caller that expects no reply would still wait on a Read().
  connection_->Write(
      buffer, buffer->size(),
      base::Bind(&U2fHidDevice::PacketWritten, weak_factory_.GetWeakPtr(),
                 base::Passed(&message), response_expected,
                 base::Passed(&callback)));
}
// Continuation of WriteMessage() after one packet has been written. After a
// successful write it either sends the next packet or starts reading the
// reply. In every other case it reports |success| to |callback| with no
// message.
void U2fHidDevice::PacketWritten(std::unique_ptr<U2fMessage> message,
                                 bool response_expected,
                                 U2fHidMessageCallback callback,
                                 bool success) {
  if (success) {
    if (message->NumPackets() > 0) {
      WriteMessage(std::move(message), response_expected, std::move(callback));
      return;
    }
    if (response_expected) {
      ReadMessage(std::move(callback));
      return;
    }
  }
  std::move(callback).Run(success, nullptr);
}
// Starts reading a reply from the device. The first packet goes to OnRead().
// Fails immediately if there is no open connection.
void U2fHidDevice::ReadMessage(U2fHidMessageCallback callback) {
  if (connection_) {
    connection_->Read(base::Bind(&U2fHidDevice::OnRead,
                                 weak_factory_.GetWeakPtr(),
                                 base::Passed(&callback)));
    return;
  }
  std::move(callback).Run(false, nullptr);
}
// Handles the first packet of a reply. Packets addressed to another channel
// are discarded and another Read() is issued. A reply that fits in one packet
// is delivered at once. Otherwise OnReadContinuation() gathers the rest.
void U2fHidDevice::OnRead(U2fHidMessageCallback callback,
                          bool success,
                          scoped_refptr<net::IOBuffer> buf,
                          size_t size) {
  if (!success || !buf) {
    std::move(callback).Run(success, nullptr);
    return;
  }
  std::vector<uint8_t> packet(buf->data(), buf->data() + size);
  std::unique_ptr<U2fMessage> msg =
      U2fMessage::CreateFromSerializedData(packet);
  if (!msg) {
    std::move(callback).Run(false, nullptr);
    return;
  }
  if (msg->channel_id() != channel_id_) {
    // Not ours: drop it and wait for the next packet on the same callback.
    connection_->Read(base::Bind(&U2fHidDevice::OnRead,
                                 weak_factory_.GetWeakPtr(),
                                 base::Passed(&callback)));
    return;
  }
  if (!msg->MessageComplete()) {
    // More packets follow; accumulate them into |msg|.
    connection_->Read(base::Bind(&U2fHidDevice::OnReadContinuation,
                                 weak_factory_.GetWeakPtr(),
                                 base::Passed(&msg), base::Passed(&callback)));
    return;
  }
  std::move(callback).Run(success, std::move(msg));
}
// Appends one continuation packet to the partially received |message|. It
// delivers |message| once it is complete and otherwise reads the next packet.
void U2fHidDevice::OnReadContinuation(std::unique_ptr<U2fMessage> message,
                                      U2fHidMessageCallback callback,
                                      bool success,
                                      scoped_refptr<net::IOBuffer> buf,
                                      size_t size) {
  if (!success || !buf) {
    std::move(callback).Run(success, nullptr);
    return;
  }
  std::vector<uint8_t> packet(buf->data(), buf->data() + size);
  // NOTE(review): any result of AddContinuationPacket() is ignored, and the
  // packet's channel is not compared to |channel_id_| here. Confirm that
  // U2fMessage rejects out-of-sequence or foreign packets itself.
  message->AddContinuationPacket(packet);
  if (!message->MessageComplete()) {
    connection_->Read(base::Bind(&U2fHidDevice::OnReadContinuation,
                                 weak_factory_.GetWeakPtr(),
                                 base::Passed(&message),
                                 base::Passed(&callback)));
    return;
  }
  std::move(callback).Run(success, std::move(message));
}
// Completion handler for the CMD_MSG exchange started from the IDLE state.
// On success it converts the reply into a U2fApduResponse, returns the device
// to IDLE, runs |callback|, and then starts the oldest queued transaction, if
// any. On failure it moves the device to DEVICE_ERROR, and Transition() then
// fails |callback| and all queued transactions.
void U2fHidDevice::MessageReceived(const DeviceCallback& callback,
                                   bool success,
                                   std::unique_ptr<U2fMessage> message) {
  // The request was already failed (e.g. by OnTimeout()); ignore late replies.
  if (state_ == State::DEVICE_ERROR)
    return;
  timeout_callback_.Cancel();
  if (!success) {
    state_ = State::DEVICE_ERROR;
    Transition(nullptr, callback);
    return;
  }
  // |response| stays null when no message was received (and presumably also
  // when CreateFromMessage() rejects the payload — confirm).
  std::unique_ptr<U2fApduResponse> response = nullptr;
  if (message)
    response = U2fApduResponse::CreateFromMessage(message->GetMessagePayload());
  // Become IDLE before running |callback| so that a new DeviceTransact() issued
  // from inside the callback starts immediately instead of being queued.
  state_ = State::IDLE;
  base::WeakPtr<U2fHidDevice> self = weak_factory_.GetWeakPtr();
  callback.Run(success, std::move(response));
  // Executing |callback| may have freed |this|. Check |self| first.
  if (self && !pending_transactions_.empty()) {
    // If any transactions were queued, process the first one
    std::unique_ptr<U2fApduCommand> pending_cmd =
        std::move(pending_transactions_.front().first);
    DeviceCallback pending_cb = pending_transactions_.front().second;
    pending_transactions_.pop_front();
    Transition(std::move(pending_cmd), pending_cb);
  }
}
void U2fHidDevice::TryWink(const WinkCallback& callback) {
// Only try to wink if device claims support
if (!(capabilities_ & kWinkCapability) || state_ != State::IDLE) {
callback.Run();
return;
}
std::unique_ptr<U2fMessage> wink_message = U2fMessage::Create(
channel_id_, U2fMessage::Type::CMD_WINK, std::vector<uint8_t>());
WriteMessage(
std::move(wink_message), true,
base::Bind(&U2fHidDevice::OnWink, weak_factory_.GetWeakPtr(), callback));
}
// Completion handler for TryWink(). The wink result and any reply are
// deliberately ignored; the caller only learns that the attempt finished.
void U2fHidDevice::OnWink(const WinkCallback& callback,
                          bool /* success */,
                          std::unique_ptr<U2fMessage> /* response */) {
  callback.Run();
}
void U2fHidDevice::ArmTimeout(const DeviceCallback& callback) {
DCHECK(timeout_callback_.IsCancelled());
timeout_callback_.Reset(base::Bind(&U2fHidDevice::OnTimeout,
weak_factory_.GetWeakPtr(), callback));
// Setup timeout task for 3 seconds
base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE, timeout_callback_.callback(),
base::TimeDelta::FromMilliseconds(3000));
}
void U2fHidDevice::OnTimeout(const DeviceCallback& callback) {
state_ = State::DEVICE_ERROR;
Transition(nullptr, callback);
}
// Returns a stable identifier for this device of the form "hid:<device id>".
std::string U2fHidDevice::GetId() {
  std::ostringstream stream;
  stream << "hid:" << device_info_->device_id();
  return stream.str();
}
// static
// Returns true when the current process was launched with
// --enable-u2f-hid-tests.
bool U2fHidDevice::IsTestEnabled() {
  return base::CommandLine::ForCurrentProcess()->HasSwitch(
      switches::kEnableU2fHidTest);
}
} // namespace device