| // Copyright 2018 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package main |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io/ioutil" |
| "log" |
| "os" |
| "regexp" |
| "strconv" |
| "strings" |
| |
| pb "chromiumos/vm_tools/tremplin_proto" |
| "github.com/lxc/lxd/client" |
| "github.com/lxc/lxd/shared/api" |
| "google.golang.org/grpc" |
| ) |
| |
// downloadRegexp extracts the download type and progress percentage from
// download operation metadata.
//
// Example matches:
//
//	"metadata: 100% (5.23MB/s)" matches ("metadata", "100")
//	"rootfs: 23% (358.09kB/s)" matches ("rootfs", "23")
//
// It is compiled once at package initialization; MustCompile only panics if
// the pattern itself is invalid, which would be a programming error.
var downloadRegexp = regexp.MustCompile(`([[:alpha:]]+): ([[:digit:]]+)% [0-9A-Za-z /.()]*$`)
| |
// getContainerName converts an LXD source path (/1.0/containers/foo) to a
// container name.
//
// The path must have exactly the form "/<version>/containers/<name>" with a
// non-empty name; anything else yields an error.
func getContainerName(s string) (string, error) {
	components := strings.Split(s, "/")
	// Expected components are: "", "1.0", "containers", "<container name>".
	// The leading "" enforces that the path is absolute.
	if len(components) != 4 || components[0] != "" {
		return "", fmt.Errorf("invalid source path: %q", s)
	}
	if components[2] != "containers" {
		return "", fmt.Errorf("source path is not a container: %q", s)
	}
	name := components[3]
	if name == "" {
		// "/1.0/containers/" would otherwise produce an empty name with no error.
		return "", fmt.Errorf("source path has empty container name: %q", s)
	}
	return name, nil
}
| |
| // getDownloadPercentage extracts the download progress (as a percentage) |
| // from an api.Operation's Metadata map. |
| func getDownloadPercentage(opMetadata map[string]interface{}) (int32, error) { |
| progress, ok := opMetadata["download_progress"].(string) |
| if !ok { |
| return 0, errors.New("could not read operation download progress") |
| } |
| |
| matches := downloadRegexp.FindStringSubmatch(progress) |
| if matches == nil { |
| return 0, fmt.Errorf("didn't find download status in %q", progress) |
| } |
| |
| downloadPercent, err := strconv.ParseInt(matches[2], 10, 32) |
| if err != nil { |
| return 0, fmt.Errorf("failed to convert download percent to int: %q", matches[2]) |
| } |
| |
| // Count metadata download as 0% of the total, since the entire rootfs still |
| // needs to be downloaded. |
| if matches[1] == "metadata" { |
| downloadPercent = 0 |
| } |
| |
| return int32(downloadPercent), nil |
| } |
| |
| // server is used to implement the gRPC tremplin.Server. |
| type tremplinServer struct { |
| lxd lxd.ContainerServer |
| grpcServer *grpc.Server |
| listenerClient pb.TremplinListenerClient |
| } |
| |
| // execProgram runs a program in a container to completion, capturing its |
| // return value, stdout, and stderr. |
| func (s *tremplinServer) execProgram(containerName string, args []string) (ret int, stdout string, stderr string, err error) { |
| req := api.ContainerExecPost{ |
| Command: args, |
| WaitForWS: true, |
| Interactive: false, |
| } |
| |
| stdoutSink := &stdioSink{} |
| stderrSink := &stdioSink{} |
| execArgs := &lxd.ContainerExecArgs{ |
| Stdin: &stdioSink{}, |
| Stdout: stdoutSink, |
| Stderr: stderrSink, |
| } |
| |
| op, err := s.lxd.ExecContainer(containerName, req, execArgs) |
| if err != nil { |
| return 0, "", "", err |
| } |
| |
| if err = op.Wait(); err != nil { |
| return 0, "", "", err |
| } |
| |
| retVal, ok := op.Metadata["return"].(float64) |
| if !ok { |
| return 0, "", "", fmt.Errorf("return value for %q is not a float64", args[0]) |
| } |
| return int(retVal), stdoutSink.String(), stderrSink.String(), nil |
| } |
| |
| func (s *tremplinServer) handleCreateOperation(op api.Operation) { |
| containers := op.Resources["containers"] |
| |
| if len(containers) != 1 { |
| log.Printf("Got %v containers instead of 1", len(containers)) |
| return |
| } |
| |
| name, err := getContainerName(containers[0]) |
| if err != nil { |
| log.Printf("Failed to get container name for operation: %v", err) |
| return |
| } |
| |
| req := &pb.ContainerCreationProgress{ |
| ContainerName: name, |
| } |
| |
| switch op.StatusCode { |
| case api.Success: |
| req.Status = pb.ContainerCreationProgress_CREATED |
| case api.Running: |
| req.Status = pb.ContainerCreationProgress_DOWNLOADING |
| downloadPercent, err := getDownloadPercentage(op.Metadata) |
| if err != nil { |
| log.Printf("Failed to parse download percentage: %v", err) |
| return |
| } |
| req.DownloadProgress = downloadPercent |
| case api.Cancelled, api.Failure: |
| req.Status = pb.ContainerCreationProgress_FAILED |
| req.FailureReason = op.Err |
| default: |
| req.Status = pb.ContainerCreationProgress_UNKNOWN |
| req.FailureReason = fmt.Sprintf("unhandled create status: %s", op.Status) |
| } |
| |
| _, err = s.listenerClient.UpdateCreateStatus(context.Background(), req) |
| if err != nil { |
| log.Printf("Could not update create status on host: %v", err) |
| return |
| } |
| } |
| |
| // CreateContainer implements tremplin.CreateContainer. |
| func (s *tremplinServer) CreateContainer(ctx context.Context, in *pb.CreateContainerRequest) (*pb.CreateContainerResponse, error) { |
| log.Printf("Received CreateContainer RPC: %s", in.ContainerName) |
| |
| response := &pb.CreateContainerResponse{} |
| |
| container, _, err := s.lxd.GetContainer(in.ContainerName) |
| if err != nil { |
| response.Status = pb.CreateContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to check if container exists: %v", err) |
| return response, nil |
| } |
| if container != nil { |
| response.Status = pb.CreateContainerResponse_EXISTS |
| return response, nil |
| } |
| |
| imageServer, err := lxd.ConnectSimpleStreams(in.ImageServer, nil) |
| if err != nil { |
| response.Status = pb.CreateContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to connect to simplestreams image server: %v", err) |
| return response, nil |
| } |
| |
| alias, _, err := imageServer.GetImageAlias(in.ImageAlias) |
| if err != nil { |
| response.Status = pb.CreateContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to get alias: %v", err) |
| return response, nil |
| } |
| |
| image, _, err := imageServer.GetImage(alias.Target) |
| if err != nil { |
| response.Status = pb.CreateContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to get image for alias: %v", err) |
| return response, nil |
| } |
| |
| containersPost := api.ContainersPost{ |
| Name: in.ContainerName, |
| Source: api.ContainerSource{ |
| Type: "image", |
| Alias: alias.Name, |
| }, |
| } |
| op, err := s.lxd.CreateContainerFromImage(imageServer, *image, containersPost) |
| if err != nil { |
| response.Status = pb.CreateContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to create container from image: %v", err) |
| return response, nil |
| } |
| |
| _, err = op.AddHandler(func(op api.Operation) { s.handleCreateOperation(op) }) |
| if err != nil { |
| log.Fatal("Failed to add create operation handler: ", err) |
| } |
| |
| response.Status = pb.CreateContainerResponse_CREATING |
| |
| return response, nil |
| } |
| |
| // StartContainer implements tremplin.StartContainer. |
| func (s *tremplinServer) StartContainer(ctx context.Context, in *pb.StartContainerRequest) (*pb.StartContainerResponse, error) { |
| log.Printf("Received StartContainer RPC: %s", in.ContainerName) |
| |
| response := &pb.StartContainerResponse{} |
| |
| container, etag, err := s.lxd.GetContainer(in.ContainerName) |
| if err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to find container: %v", err) |
| return response, nil |
| } |
| |
| if container.StatusCode == api.Running { |
| response.Status = pb.StartContainerResponse_RUNNING |
| return response, nil |
| } |
| |
| // Prepare SSH keys and token. |
| // Clear out all existing devices for the container. |
| containerPut := container.Writable() |
| containerPut.Devices = map[string]map[string]string{} |
| err = os.MkdirAll(fmt.Sprintf("/run/sshd/%s", container.Name), 0644) |
| if err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to create ssh key dir: %v", err) |
| return response, nil |
| } |
| bindMounts := []struct { |
| name string |
| content string |
| source string |
| dest string |
| }{ |
| { |
| name: "container_token", |
| content: in.Token, |
| source: fmt.Sprintf("/run/tokens/%s_token", container.Name), |
| dest: "/dev/.container_token", |
| }, |
| { |
| name: "ssh_authorized_keys", |
| content: in.ContainerPublicKey, |
| source: fmt.Sprintf("/run/sshd/%s/authorized_keys", container.Name), |
| dest: "/dev/.ssh/ssh_authorized_keys", |
| }, |
| { |
| name: "ssh_host_key", |
| content: in.HostPrivateKey, |
| source: fmt.Sprintf("/run/sshd/%s/ssh_host_key", container.Name), |
| dest: "/dev/.ssh/ssh_host_key", |
| }, |
| } |
| for _, b := range bindMounts { |
| // Disregard bind mounts without values. |
| if b.content == "" { |
| continue |
| } |
| |
| err = ioutil.WriteFile(b.source, []byte(b.content), 0644) |
| if err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to write %q: %v", b.source, err) |
| return response, nil |
| } |
| |
| containerPut.Devices[b.name] = map[string]string{ |
| "source": b.source, |
| "path": b.dest, |
| "type": "disk", |
| } |
| } |
| |
| op, err := s.lxd.UpdateContainer(container.Name, containerPut, etag) |
| if err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to set up devices: %v", err) |
| return response, nil |
| } |
| |
| if err = op.Wait(); err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to wait for container update: %v", err) |
| return response, nil |
| } |
| |
| reqState := api.ContainerStatePut{ |
| Action: "start", |
| Timeout: -1, |
| } |
| op, err = s.lxd.UpdateContainerState(container.Name, reqState, "") |
| if err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to start container: %v", err) |
| return response, nil |
| } |
| |
| if err = op.Wait(); err != nil { |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to wait for container startup: %v", err) |
| return response, nil |
| } |
| |
| switch op.StatusCode { |
| case api.Success: |
| response.Status = pb.StartContainerResponse_STARTED |
| case api.Cancelled, api.Failure: |
| response.Status = pb.StartContainerResponse_FAILED |
| response.FailureReason = op.Err |
| } |
| |
| return response, nil |
| } |
| |
| // GetContainerUsername implements tremplin.GetContainerUsername. |
| func (s *tremplinServer) GetContainerUsername(ctx context.Context, in *pb.GetContainerUsernameRequest) (*pb.GetContainerUsernameResponse, error) { |
| log.Printf("Received GetContainerUsername RPC: %s", in.ContainerName) |
| |
| response := &pb.GetContainerUsernameResponse{} |
| |
| c, _, err := s.lxd.GetContainer(in.ContainerName) |
| if err != nil { |
| response.Status = pb.GetContainerUsernameResponse_CONTAINER_NOT_FOUND |
| response.FailureReason = fmt.Sprintf("failed to find container: %v", err) |
| return response, nil |
| } |
| if c.StatusCode != api.Running { |
| response.Status = pb.GetContainerUsernameResponse_CONTAINER_NOT_RUNNING |
| response.FailureReason = fmt.Sprintf("container not running, status is: %d", c.StatusCode) |
| return response, nil |
| } |
| |
| ret, stdout, stderr, err := s.execProgram(in.ContainerName, []string{"id", "-nu", "1000"}) |
| if err != nil { |
| response.Status = pb.GetContainerUsernameResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to run id program: %v", err) |
| return response, nil |
| } |
| if ret != 0 { |
| response.Status = pb.GetContainerUsernameResponse_USER_NOT_FOUND |
| response.FailureReason = fmt.Sprintf("failed to get user for uid: %v", stderr) |
| return response, nil |
| } |
| |
| response.Status = pb.GetContainerUsernameResponse_SUCCESS |
| response.Username = stdout |
| return response, nil |
| } |
| |
| // SetUpUser implements tremplin.SetUpUser. |
| func (s *tremplinServer) SetUpUser(ctx context.Context, in *pb.SetUpUserRequest) (*pb.SetUpUserResponse, error) { |
| log.Printf("Received SetUpUser RPC: %s (username %s)", in.ContainerName, in.ContainerUsername) |
| |
| response := &pb.SetUpUserResponse{} |
| |
| // Check if uid 1000 exists first - if it does, leave it alone. |
| ret, stdout, stderr, err := s.execProgram(in.ContainerName, []string{"id", "-nu", "1000"}) |
| if err != nil { |
| response.Status = pb.SetUpUserResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to check uid: %v", err) |
| return response, nil |
| } |
| if ret == 0 { |
| response.Status = pb.SetUpUserResponse_EXISTS |
| response.FailureReason = fmt.Sprintf("user exists: %s", stdout) |
| return response, nil |
| } |
| |
| ret, stdout, stderr, err = s.execProgram(in.ContainerName, |
| []string{"useradd", "-u", "1000", "-s", "/bin/bash", "-m", in.ContainerUsername}) |
| if err != nil { |
| response.Status = pb.SetUpUserResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to run useradd: %v", err) |
| return response, nil |
| } |
| if ret != 0 { |
| response.Status = pb.SetUpUserResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to add user: %s", stderr) |
| return response, nil |
| } |
| |
| groups := []string{ |
| "audio", |
| "cdrom", |
| "dialout", |
| "floppy", |
| "plugdev", |
| "sudo", |
| "users", |
| "video", |
| } |
| |
| // Add groups, but don't fail - groups might not exist in the container. |
| for _, group := range groups { |
| ret, stdout, stderr, err = s.execProgram(in.ContainerName, |
| []string{"usermod", "-aG", group, in.ContainerUsername}) |
| if err != nil { |
| response.Status = pb.SetUpUserResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to run useradd: %v", err) |
| return response, nil |
| } |
| } |
| |
| // Enable loginctl linger for the target user. |
| ret, stdout, stderr, err = s.execProgram(in.ContainerName, |
| []string{"loginctl", "enable-linger", in.ContainerUsername}) |
| if err != nil { |
| response.Status = pb.SetUpUserResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to run loginctl: %v", err) |
| return response, nil |
| } |
| if ret != 0 { |
| response.Status = pb.SetUpUserResponse_FAILED |
| response.FailureReason = fmt.Sprintf("failed to enable linger: %s", stderr) |
| return response, nil |
| } |
| |
| response.Status = pb.SetUpUserResponse_SUCCESS |
| |
| return response, nil |
| } |