// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "vm_tools/cicerone/tremplin_listener_impl.h"

#include <arpa/inet.h>
#include <inttypes.h>
#include <stdio.h>

#include <memory>
#include <string>

#include <base/bind.h>
#include <base/logging.h>
#include <base/strings/string_util.h>
#include <base/threading/thread_task_runner_handle.h>

#include "vm_tools/cicerone/service.h"

namespace vm_tools {
namespace cicerone {

TremplinListenerImpl::TremplinListenerImpl(
    base::WeakPtr<vm_tools::cicerone::Service> service)
    : service_(service), task_runner_(base::ThreadTaskRunnerHandle::Get()) {}

void TremplinListenerImpl::OverridePeerAddressForTesting(
    const std::string& testing_peer_address) {
  base::AutoLock lock_scope(testing_peer_address_lock_);
  testing_peer_address_ = testing_peer_address;
}

grpc::Status TremplinListenerImpl::TremplinReady(
    grpc::ServerContext* ctx,
    const vm_tools::tremplin::TremplinStartupInfo* request,
    vm_tools::tremplin::EmptyMessage* response) {
  uint32_t cid = ExtractCidFromPeerAddress(ctx);
  if (cid == 0) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failed parsing vsock cid for TremplinListener");
  }

  bool result = false;
  base::WaitableEvent event(base::WaitableEvent::ResetPolicy::AUTOMATIC,
                            base::WaitableEvent::InitialState::NOT_SIGNALED);
  task_runner_->PostTask(
      FROM_HERE, base::Bind(&vm_tools::cicerone::Service::ConnectTremplin,
                            service_, cid, &result, &event));
  event.Wait();
  if (!result) {
    LOG(ERROR) << "Received TremplinReady but could not find matching VM: "
               << ctx->peer();
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Cannot find VM for TremplinListener");
  }

  return grpc::Status::OK;
}

grpc::Status TremplinListenerImpl::UpdateCreateStatus(
    grpc::ServerContext* ctx,
    const vm_tools::tremplin::ContainerCreationProgress* request,
    vm_tools::tremplin::EmptyMessage* response) {
  uint32_t cid = ExtractCidFromPeerAddress(ctx);
  if (cid == 0) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failed parsing vsock cid for TremplinListener");
  }

  bool result = false;
  base::WaitableEvent event(base::WaitableEvent::ResetPolicy::AUTOMATIC,
                            base::WaitableEvent::InitialState::NOT_SIGNALED);
  if (request->status() == tremplin::ContainerCreationProgress::DOWNLOADING) {
    task_runner_->PostTask(
        FROM_HERE,
        base::Bind(&vm_tools::cicerone::Service::LxdContainerDownloading,
                   service_, cid, request->container_name(),
                   request->download_progress(), &result, &event));
  } else {
    vm_tools::cicerone::Service::CreateStatus status;
    switch (request->status()) {
      case tremplin::ContainerCreationProgress::CREATED:
        status = vm_tools::cicerone::Service::CreateStatus::CREATED;
        break;
      case tremplin::ContainerCreationProgress::DOWNLOAD_TIMED_OUT:
        status = vm_tools::cicerone::Service::CreateStatus::DOWNLOAD_TIMED_OUT;
        break;
      case tremplin::ContainerCreationProgress::CANCELLED:
        status = vm_tools::cicerone::Service::CreateStatus::CANCELLED;
        break;
      case tremplin::ContainerCreationProgress::FAILED:
        status = vm_tools::cicerone::Service::CreateStatus::FAILED;
        break;
      default:
        status = vm_tools::cicerone::Service::CreateStatus::UNKNOWN;
        break;
    }
    task_runner_->PostTask(
        FROM_HERE, base::Bind(&vm_tools::cicerone::Service::LxdContainerCreated,
                              service_, cid, request->container_name(), status,
                              request->failure_reason(), &result, &event));
  }

  event.Wait();
  if (!result) {
    LOG(ERROR)
        << "Received UpdateCreateStatus RPC but could not find matching VM: "
        << ctx->peer();
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Cannot find VM for TremplinListener");
  }

  return grpc::Status::OK;
}

grpc::Status TremplinListenerImpl::UpdateStartStatus(
    grpc::ServerContext* ctx,
    const vm_tools::tremplin::ContainerStartProgress* request,
    vm_tools::tremplin::EmptyMessage* response) {
  uint32_t cid = ExtractCidFromPeerAddress(ctx);
  if (cid == 0) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failed parsing vsock cid for TremplinListener");
  }

  bool result = false;
  base::WaitableEvent event(base::WaitableEvent::ResetPolicy::AUTOMATIC,
                            base::WaitableEvent::InitialState::NOT_SIGNALED);
  vm_tools::cicerone::Service::StartStatus status;
  switch (request->status()) {
    case tremplin::ContainerStartProgress::STARTED:
      status = vm_tools::cicerone::Service::StartStatus::STARTED;
      break;
    case tremplin::ContainerStartProgress::CANCELLED:
      status = vm_tools::cicerone::Service::StartStatus::CANCELLED;
      break;
    case tremplin::ContainerStartProgress::FAILED:
      status = vm_tools::cicerone::Service::StartStatus::FAILED;
      break;
    default:
      status = vm_tools::cicerone::Service::StartStatus::UNKNOWN;
      break;
  }
  task_runner_->PostTask(
      FROM_HERE, base::Bind(&vm_tools::cicerone::Service::LxdContainerStarting,
                            service_, cid, request->container_name(), status,
                            request->failure_reason(), &result, &event));

  event.Wait();
  if (!result) {
    LOG(ERROR)
        << "Received UpdateStartStatus RPC but could not find matching VM: "
        << ctx->peer();
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Cannot find VM for TremplinListener");
  }

  return grpc::Status::OK;
}

grpc::Status TremplinListenerImpl::UpdateExportStatus(
    grpc::ServerContext* ctx,
    const vm_tools::tremplin::ContainerExportProgress* request,
    vm_tools::tremplin::EmptyMessage* response) {
  uint32_t cid = ExtractCidFromPeerAddress(ctx);
  if (cid == 0) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failed parsing vsock cid for TremplinListener");
  }

  ExportLxdContainerProgressSignal progress_signal;
  if (!ExportLxdContainerProgressSignal::Status_IsValid(
          static_cast<int>(request->status()))) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Invalid status field in protobuf request");
  }
  progress_signal.set_status(
      static_cast<ExportLxdContainerProgressSignal::Status>(request->status()));
  progress_signal.set_container_name(request->container_name());
  progress_signal.set_progress_percent(request->progress_percent());
  progress_signal.set_failure_reason(request->failure_reason());
  base::WaitableEvent event(base::WaitableEvent::ResetPolicy::AUTOMATIC,
                            base::WaitableEvent::InitialState::NOT_SIGNALED);
  bool result = false;
  task_runner_->PostTask(
      FROM_HERE,
      base::Bind(&vm_tools::cicerone::Service::ContainerExportProgress,
                 service_, cid, &progress_signal, &result, &event));
  event.Wait();
  if (!result) {
    LOG(ERROR) << "Failure updating container export progress";
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failure in UpdateExportStatus");
  }

  return grpc::Status::OK;
}

grpc::Status TremplinListenerImpl::UpdateImportStatus(
    grpc::ServerContext* ctx,
    const vm_tools::tremplin::ContainerImportProgress* request,
    vm_tools::tremplin::EmptyMessage* response) {
  uint32_t cid = ExtractCidFromPeerAddress(ctx);
  if (cid == 0) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failed parsing vsock cid for TremplinListener");
  }

  ImportLxdContainerProgressSignal progress_signal;
  if (!ImportLxdContainerProgressSignal::Status_IsValid(
          static_cast<int>(request->status()))) {
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Invalid status field in protobuf request");
  }
  progress_signal.set_status(
      static_cast<ImportLxdContainerProgressSignal::Status>(request->status()));
  progress_signal.set_container_name(request->container_name());
  progress_signal.set_progress_percent(request->progress_percent());
  progress_signal.set_failure_reason(request->failure_reason());
  base::WaitableEvent event(base::WaitableEvent::ResetPolicy::AUTOMATIC,
                            base::WaitableEvent::InitialState::NOT_SIGNALED);
  bool result = false;
  task_runner_->PostTask(
      FROM_HERE,
      base::Bind(&vm_tools::cicerone::Service::ContainerImportProgress,
                 service_, cid, &progress_signal, &result, &event));
  event.Wait();
  if (!result) {
    LOG(ERROR) << "Failure updating container import progress";
    return grpc::Status(grpc::FAILED_PRECONDITION,
                        "Failure in UpdateImportStatus");
  }

  return grpc::Status::OK;
}

// Returns 0 on failure, otherwise returns the 32-bit vsock cid.
uint32_t TremplinListenerImpl::ExtractCidFromPeerAddress(
    grpc::ServerContext* ctx) {
  uint32_t cid = 0;
  std::string peer_address = ctx->peer();
  {
    base::AutoLock lock_scope(testing_peer_address_lock_);
    if (!testing_peer_address_.empty()) {
      peer_address = testing_peer_address_;
    }
  }
  if (sscanf(peer_address.c_str(), "vsock:%" SCNu32, &cid) != 1) {
    LOG(WARNING) << "Failed to parse peer address " << peer_address;
    return 0;
  }
  return cid;
}

}  // namespace cicerone
}  // namespace vm_tools
