blob: c74668458ca77517c5c80506e9b6cfbeba7f5b44 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package main
import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"regexp"
"strconv"
"strings"
pb "chromiumos/vm_tools/tremplin_proto"
"github.com/lxc/lxd/client"
"github.com/lxc/lxd/shared/api"
"google.golang.org/grpc"
)
// downloadRegexp pulls the download stage and its progress percentage out of
// the "download_progress" string found in LXD operation metadata. Submatch 1
// is the stage name and submatch 2 is the integer percentage, e.g.:
//
//	"metadata: 100% (5.23MB/s)" matches ("metadata", "100")
//	"rootfs: 23% (358.09kB/s)"  matches ("rootfs", "23")
var downloadRegexp = regexp.MustCompile(`([[:alpha:]]+): ([[:digit:]]+)% [0-9A-Za-z /.()]*$`)
// getContainerName maps an LXD resource path such as "/1.0/containers/foo" to
// the bare container name ("foo").
//
// The path must split on "/" into exactly four parts, and the third part must
// be "containers". The leading empty part and the API version are not checked.
func getContainerName(s string) (string, error) {
	parts := strings.Split(s, "/")
	switch {
	case len(parts) != 4:
		return "", fmt.Errorf("invalid source path: %q", s)
	case parts[2] != "containers":
		return "", fmt.Errorf("source path is not a container: %q", s)
	default:
		return parts[3], nil
	}
}
// getDownloadPercentage reads the "download_progress" string from an
// api.Operation's Metadata map and returns the progress percentage it reports.
//
// An error is returned if the key is missing or not a string, if the string
// does not match downloadRegexp, or if the percentage does not fit in an
// int32. Progress reported for the "metadata" stage is returned as 0, because
// the entire rootfs still has to be downloaded after it.
func getDownloadPercentage(opMetadata map[string]interface{}) (int32, error) {
	raw, isString := opMetadata["download_progress"].(string)
	if !isString {
		return 0, errors.New("could not read operation download progress")
	}

	groups := downloadRegexp.FindStringSubmatch(raw)
	if len(groups) == 0 {
		return 0, fmt.Errorf("didn't find download status in %q", raw)
	}
	stage, percentText := groups[1], groups[2]

	pct, err := strconv.ParseInt(percentText, 10, 32)
	if err != nil {
		return 0, fmt.Errorf("failed to convert download percent to int: %q", percentText)
	}

	// The metadata download counts for nothing toward the overall total.
	if stage == "metadata" {
		return 0, nil
	}
	return int32(pct), nil
}
// tremplinServer is used to implement the gRPC tremplin.Server.
type tremplinServer struct {
// lxd is the LXD client the RPC handlers use to look up, create, update,
// start, and exec into containers.
lxd lxd.ContainerServer
// grpcServer is presumably the gRPC server this service is registered on;
// none of the methods in this section reference it — confirm against main.
grpcServer *grpc.Server
// listenerClient talks to the host-side TremplinListener service; container
// creation progress is reported through its UpdateCreateStatus call.
listenerClient pb.TremplinListenerClient
}
// execProgram runs a program in a container to completion, capturing its
// return value, stdout, and stderr.
//
// args is the full command line, and args[0] is the program to run. The call
// blocks until the LXD exec operation has finished and all of the program's
// output has been copied into the sinks. ret is the exit status reported in
// the operation's "return" metadata.
//
// A non-nil err means the program could not be run or its status could not be
// read. A non-zero ret with a nil err means the program ran and failed.
func (s *tremplinServer) execProgram(containerName string, args []string) (ret int, stdout string, stderr string, err error) {
	// args[0] is used in error messages below; reject an empty command up front
	// instead of panicking.
	if len(args) == 0 {
		return 0, "", "", errors.New("execProgram: empty command")
	}
	req := api.ContainerExecPost{
		Command:     args,
		WaitForWS:   true,
		Interactive: false,
	}
	stdoutSink := &stdioSink{}
	stderrSink := &stdioSink{}
	execArgs := &lxd.ContainerExecArgs{
		Stdin:  &stdioSink{},
		Stdout: stdoutSink,
		Stderr: stderrSink,
		// The LXD client closes DataDone once its stdout/stderr copy goroutines
		// finish. op.Wait() only tracks the operation status, and it can return
		// before all output has been written to the sinks.
		DataDone: make(chan bool),
	}
	op, err := s.lxd.ExecContainer(containerName, req, execArgs)
	if err != nil {
		return 0, "", "", err
	}
	if err = op.Wait(); err != nil {
		return 0, "", "", err
	}
	// Don't read the sinks until every byte of output has arrived; otherwise
	// stdout/stderr can be truncated.
	<-execArgs.DataDone

	opAPI := op.Get()
	// JSON numbers decode as float64, so that is the type of the return value.
	retVal, ok := opAPI.Metadata["return"].(float64)
	if !ok {
		return 0, "", "", fmt.Errorf("return value for %q is not a float64", args[0])
	}
	return int(retVal), stdoutSink.String(), stderrSink.String(), nil
}
// handleCreateOperation is the LXD operation handler that CreateContainer
// registers through op.AddHandler. It translates each update of the create
// operation into a ContainerCreationProgress message and forwards it to the
// host through listenerClient.UpdateCreateStatus.
//
// Some updates cannot be interpreted: the operation does not name exactly one
// container, the resource path is bad, or a running operation has unparseable
// download progress. Those are logged and dropped without notifying the host.
func (s *tremplinServer) handleCreateOperation(op api.Operation) {
	resources := op.Resources["containers"]
	if n := len(resources); n != 1 {
		log.Printf("Got %v containers instead of 1", n)
		return
	}
	containerName, err := getContainerName(resources[0])
	if err != nil {
		log.Printf("Failed to get container name for operation: %v", err)
		return
	}

	progress := &pb.ContainerCreationProgress{ContainerName: containerName}
	switch op.StatusCode {
	case api.Success:
		progress.Status = pb.ContainerCreationProgress_CREATED
	case api.Running:
		percent, parseErr := getDownloadPercentage(op.Metadata)
		if parseErr != nil {
			log.Printf("Failed to parse download percentage: %v", parseErr)
			return
		}
		progress.Status = pb.ContainerCreationProgress_DOWNLOADING
		progress.DownloadProgress = percent
	case api.Cancelled, api.Failure:
		progress.Status = pb.ContainerCreationProgress_FAILED
		progress.FailureReason = op.Err
	default:
		progress.Status = pb.ContainerCreationProgress_UNKNOWN
		progress.FailureReason = fmt.Sprintf("unhandled create status: %s", op.Status)
	}

	if _, err := s.listenerClient.UpdateCreateStatus(context.Background(), progress); err != nil {
		log.Printf("Could not update create status on host: %v", err)
	}
}
// CreateContainer implements tremplin.CreateContainer.
//
// It resolves in.ImageAlias on the simplestreams image server at
// in.ImageServer. It then starts an LXD operation that creates container
// in.ContainerName from that image.
//
// The RPC does not wait for creation to finish. It returns CREATING once the
// operation is started, and handleCreateOperation forwards later progress to
// the host. EXISTS is returned if the container can already be fetched.
// Failures are reported in the response's Status/FailureReason; the returned
// error is always nil.
func (s *tremplinServer) CreateContainer(ctx context.Context, in *pb.CreateContainerRequest) (*pb.CreateContainerResponse, error) {
	log.Printf("Received CreateContainer RPC: %s", in.ContainerName)
	response := &pb.CreateContainerResponse{}
	// The lookup error is deliberately ignored. Only a container that can
	// actually be fetched counts as existing; otherwise creation is attempted,
	// and any real LXD problem surfaces from the calls below.
	container, _, _ := s.lxd.GetContainer(in.ContainerName)
	if container != nil {
		response.Status = pb.CreateContainerResponse_EXISTS
		return response, nil
	}
	imageServer, err := lxd.ConnectSimpleStreams(in.ImageServer, nil)
	if err != nil {
		response.Status = pb.CreateContainerResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to connect to simplestreams image server: %v", err)
		return response, nil
	}
	alias, _, err := imageServer.GetImageAlias(in.ImageAlias)
	if err != nil {
		response.Status = pb.CreateContainerResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to get alias: %v", err)
		return response, nil
	}
	image, _, err := imageServer.GetImage(alias.Target)
	if err != nil {
		response.Status = pb.CreateContainerResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to get image for alias: %v", err)
		return response, nil
	}
	containersPost := api.ContainersPost{
		Name: in.ContainerName,
		Source: api.ContainerSource{
			Type:  "image",
			Alias: alias.Name,
		},
	}
	op, err := s.lxd.CreateContainerFromImage(imageServer, *image, containersPost)
	if err != nil {
		response.Status = pb.CreateContainerResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to create container from image: %v", err)
		return response, nil
	}
	// Subscribe to operation updates so creation progress reaches the host.
	if _, err := op.AddHandler(s.handleCreateOperation); err != nil {
		// Without the handler the host would never hear that creation finished,
		// so don't claim CREATING. Don't log.Fatal either: that would kill the
		// whole daemon from inside an RPC handler.
		// NOTE: the create operation has already been started and may still
		// complete on its own.
		log.Printf("Failed to add create operation handler: %v", err)
		response.Status = pb.CreateContainerResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to add create operation handler: %v", err)
		return response, nil
	}
	response.Status = pb.CreateContainerResponse_CREATING
	return response, nil
}
// StartContainer implements tremplin.StartContainer.
//
// If the named container is already running, RUNNING is returned and nothing
// else is done. Otherwise the container's device list is replaced with
// bind-mount ("disk") devices for each value supplied in the request:
//   - the container token
//   - the SSH authorized keys
//   - the SSH host key
//
// Each value is first written to a file under /run. The container is then
// started and the start operation is waited on. Failures are reported in the
// response's Status/FailureReason; the returned error is always nil.
func (s *tremplinServer) StartContainer(ctx context.Context, in *pb.StartContainerRequest) (*pb.StartContainerResponse, error) {
	log.Printf("Received StartContainer RPC: %s", in.ContainerName)
	response := &pb.StartContainerResponse{}
	// fail marks the response FAILED with a formatted reason and returns it.
	fail := func(format string, a ...interface{}) (*pb.StartContainerResponse, error) {
		response.Status = pb.StartContainerResponse_FAILED
		response.FailureReason = fmt.Sprintf(format, a...)
		return response, nil
	}

	container, etag, err := s.lxd.GetContainer(in.ContainerName)
	if err != nil {
		return fail("failed to find container: %v", err)
	}
	if container.StatusCode == api.Running {
		response.Status = pb.StartContainerResponse_RUNNING
		return response, nil
	}
	name := container.Name

	// Prepare SSH keys and token. Every pre-existing device is dropped; only
	// the bind mounts added below end up on the container.
	containerPut := container.Writable()
	containerPut.Devices = map[string]map[string]string{}

	sshDir := fmt.Sprintf("/run/sshd/%s", name)
	// NOTE(review): 0644 gives this directory no search (x) bit, so only
	// processes that bypass permission checks (e.g. root) can reach the files
	// inside it. Those files are also written 0644 and include the SSH host
	// private key. Confirm this combination is intended before changing either
	// mode.
	if err := os.MkdirAll(sshDir, 0644); err != nil {
		return fail("failed to create ssh key dir: %v", err)
	}

	type bindMount struct {
		device        string // LXD device name
		content       string // file contents; empty means "skip this mount"
		hostPath      string // file written on the host
		containerPath string // mount point inside the container
	}
	mounts := []bindMount{
		{"container_token", in.Token, fmt.Sprintf("/run/tokens/%s_token", name), "/dev/.container_token"},
		{"ssh_authorized_keys", in.ContainerPublicKey, sshDir + "/authorized_keys", "/dev/.ssh/ssh_authorized_keys"},
		{"ssh_host_key", in.HostPrivateKey, sshDir + "/ssh_host_key", "/dev/.ssh/ssh_host_key"},
	}
	for _, m := range mounts {
		// The request may omit any of these values; no device is added for those.
		if len(m.content) == 0 {
			continue
		}
		if err := ioutil.WriteFile(m.hostPath, []byte(m.content), 0644); err != nil {
			return fail("failed to write %q: %v", m.hostPath, err)
		}
		containerPut.Devices[m.device] = map[string]string{
			"type":   "disk",
			"source": m.hostPath,
			"path":   m.containerPath,
		}
	}

	op, err := s.lxd.UpdateContainer(name, containerPut, etag)
	if err != nil {
		return fail("failed to set up devices: %v", err)
	}
	if err := op.Wait(); err != nil {
		return fail("failed to wait for container update: %v", err)
	}

	startReq := api.ContainerStatePut{Action: "start", Timeout: -1}
	op, err = s.lxd.UpdateContainerState(name, startReq, "")
	if err != nil {
		return fail("failed to start container: %v", err)
	}
	if err := op.Wait(); err != nil {
		return fail("failed to wait for container startup: %v", err)
	}

	result := op.Get()
	switch result.StatusCode {
	case api.Success:
		response.Status = pb.StartContainerResponse_STARTED
	case api.Cancelled, api.Failure:
		response.Status = pb.StartContainerResponse_FAILED
		response.FailureReason = result.Err
	}
	return response, nil
}
// GetContainerUsername implements tremplin.GetContainerUsername.
//
// It reports the name of the user with uid 1000 in the named container, as
// printed by `id -nu 1000`. The container must exist and be running.
// Failures are reported in the response's Status/FailureReason; the returned
// error is always nil.
func (s *tremplinServer) GetContainerUsername(ctx context.Context, in *pb.GetContainerUsernameRequest) (*pb.GetContainerUsernameResponse, error) {
	log.Printf("Received GetContainerUsername RPC: %s", in.ContainerName)
	response := &pb.GetContainerUsernameResponse{}
	c, _, err := s.lxd.GetContainer(in.ContainerName)
	if err != nil {
		response.Status = pb.GetContainerUsernameResponse_CONTAINER_NOT_FOUND
		response.FailureReason = fmt.Sprintf("failed to find container: %v", err)
		return response, nil
	}
	if c.StatusCode != api.Running {
		response.Status = pb.GetContainerUsernameResponse_CONTAINER_NOT_RUNNING
		response.FailureReason = fmt.Sprintf("container not running, status is: %d", c.StatusCode)
		return response, nil
	}
	ret, stdout, stderr, err := s.execProgram(in.ContainerName, []string{"id", "-nu", "1000"})
	if err != nil {
		response.Status = pb.GetContainerUsernameResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to run id program: %v", err)
		return response, nil
	}
	if ret != 0 {
		response.Status = pb.GetContainerUsernameResponse_USER_NOT_FOUND
		response.FailureReason = fmt.Sprintf("failed to get user for uid: %v", stderr)
		return response, nil
	}
	response.Status = pb.GetContainerUsernameResponse_SUCCESS
	// id terminates the name with a newline. Strip it so callers get the bare
	// username rather than "name\n".
	response.Username = strings.TrimSpace(stdout)
	return response, nil
}
// SetUpUser implements tremplin.SetUpUser.
//
// If uid 1000 already exists in the container, it is left alone and EXISTS is
// returned. Otherwise it:
//   - creates a user named in.ContainerUsername with uid 1000, /bin/bash as
//     its shell, and a home directory (useradd -m);
//   - adds the user, best effort, to a fixed set of supplementary groups;
//   - enables loginctl linger for the user.
//
// Failures are reported in the response's Status/FailureReason; the returned
// error is always nil.
func (s *tremplinServer) SetUpUser(ctx context.Context, in *pb.SetUpUserRequest) (*pb.SetUpUserResponse, error) {
	log.Printf("Received SetUpUser RPC: %s (username %s)", in.ContainerName, in.ContainerUsername)
	response := &pb.SetUpUserResponse{}
	// Check if uid 1000 exists first - if it does, leave it alone.
	ret, stdout, _, err := s.execProgram(in.ContainerName, []string{"id", "-nu", "1000"})
	if err != nil {
		response.Status = pb.SetUpUserResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to check uid: %v", err)
		return response, nil
	}
	if ret == 0 {
		response.Status = pb.SetUpUserResponse_EXISTS
		// id terminates the name with a newline; keep it out of the message.
		response.FailureReason = fmt.Sprintf("user exists: %s", strings.TrimSpace(stdout))
		return response, nil
	}
	ret, _, stderr, err := s.execProgram(in.ContainerName,
		[]string{"useradd", "-u", "1000", "-s", "/bin/bash", "-m", in.ContainerUsername})
	if err != nil {
		response.Status = pb.SetUpUserResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to run useradd: %v", err)
		return response, nil
	}
	if ret != 0 {
		response.Status = pb.SetUpUserResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to add user: %s", stderr)
		return response, nil
	}
	groups := []string{
		"audio",
		"cdrom",
		"dialout",
		"floppy",
		"plugdev",
		"sudo",
		"users",
		"video",
	}
	// Add groups, but don't fail - groups might not exist in the container.
	// Only a failure to run usermod at all aborts setup; a non-zero exit is
	// logged and skipped.
	for _, group := range groups {
		groupRet, _, groupStderr, groupErr := s.execProgram(in.ContainerName,
			[]string{"usermod", "-aG", group, in.ContainerUsername})
		if groupErr != nil {
			response.Status = pb.SetUpUserResponse_FAILED
			response.FailureReason = fmt.Sprintf("failed to run usermod: %v", groupErr)
			return response, nil
		}
		if groupRet != 0 {
			log.Printf("Could not add %s to group %s: %s", in.ContainerUsername, group, strings.TrimSpace(groupStderr))
		}
	}
	// Enable loginctl linger for the target user.
	ret, _, stderr, err = s.execProgram(in.ContainerName,
		[]string{"loginctl", "enable-linger", in.ContainerUsername})
	if err != nil {
		response.Status = pb.SetUpUserResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to run loginctl: %v", err)
		return response, nil
	}
	if ret != 0 {
		response.Status = pb.SetUpUserResponse_FAILED
		response.FailureReason = fmt.Sprintf("failed to enable linger: %s", stderr)
		return response, nil
	}
	response.Status = pb.SetUpUserResponse_SUCCESS
	return response, nil
}