// File: k9s/internal/dao/port_forwarder.go (196 lines, 5.2 KiB, Go).

package dao
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/derailed/k9s/internal/client"
"github.com/derailed/k9s/internal/port"
"github.com/rs/zerolog/log"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
metav1beta1 "k8s.io/apimachinery/pkg/apis/meta/v1beta1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/cli-runtime/pkg/genericclioptions"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
)
// PortForwarder tracks a port forward stream.
type PortForwarder struct {
	Factory
	genericclioptions.IOStreams

	// stopChan is closed by Stop and passed as the stop channel to the
	// port-forward session built in forwardPorts.
	stopChan chan struct{}
	// readyChan is passed as the ready channel to the port-forward session
	// built in forwardPorts.
	readyChan chan struct{}
	// active backs Active; it is set via SetActive and cleared by Stop.
	active bool
	// path is the target path recorded by Start.
	path string
	// tunnel is the address/port description recorded by Start.
	tunnel port.PortTunnel
	// age is the time Start was invoked; Age reports the time elapsed since.
	age time.Time
}
// NewPortForwarder builds a port forwarder bound to the given factory, with
// fresh (unbuffered) stop and ready channels.
func NewPortForwarder(f Factory) *PortForwarder {
	pf := PortForwarder{Factory: f}
	pf.stopChan = make(chan struct{})
	pf.readyChan = make(chan struct{})

	return &pf
}
// String renders the forwarder as "<path>|<tunnel>".
func (p *PortForwarder) String() string {
	const layout = "%s|%s"
	return fmt.Sprintf(layout, p.path, p.tunnel)
}
// Age reports the time elapsed since Start recorded the forward's creation,
// formatted with time.Duration's String method.
func (p *PortForwarder) Age() string {
	elapsed := time.Since(p.age)
	return elapsed.String()
}
// Active reports whether the forward is currently flagged as active.
func (p *PortForwarder) Active() bool {
	isActive := p.active
	return isActive
}
// SetActive flags the port forward as active (true) or inactive (false).
func (p *PortForwarder) SetActive(b bool) {
	state := b
	p.active = state
}
// Port returns the tunnel's port mapping as produced by PortTunnel.PortMap.
func (p *PortForwarder) Port() string {
	mapping := p.tunnel.PortMap()
	return mapping
}
// ContainerPort returns the container-side port of the tunnel.
func (p *PortForwarder) ContainerPort() string {
	cp := p.tunnel.ContainerPort
	return cp
}
// LocalPort returns the local-side port of the tunnel.
func (p *PortForwarder) LocalPort() string {
	lp := p.tunnel.LocalPort
	return lp
}
// ID returns the forward identifier computed by PortForwardID from the
// recorded path, the tunnel's container and its port mapping.
func (p *PortForwarder) ID() string {
	co, pm := p.tunnel.Container, p.tunnel.PortMap()
	return PortForwardID(p.path, co, pm)
}
// Container returns the name of the container targeted by the tunnel.
func (p *PortForwarder) Container() string {
	co := p.tunnel.Container
	return co
}
// Stop terminates a port forward.
//
// It marks the forwarder inactive and closes stopChan, the stop channel that
// forwardPorts hands to portforward.NewOnAddresses. Repeated calls no longer
// panic: closing an already-closed channel panics in Go, so the close is
// skipped once the channel is observed closed.
// NOTE(review): the guard is not atomic; truly concurrent Stop calls could
// still race on the close. Use a sync.Once on the struct if that can happen.
func (p *PortForwarder) Stop() {
	log.Debug().Msgf("<<< Stopping PortForward %s", p.ID())
	p.active = false
	select {
	case <-p.stopChan:
		// Already closed by an earlier Stop; nothing left to signal.
	default:
		close(p.stopChan)
	}
}
// FQN returns the forward's fully-qualified name in the form
// "<path>:<container>".
func (p *PortForwarder) FQN() string {
	return strings.Join([]string{p.path, p.tunnel.Container}, ":")
}
// HasPortMapping reports whether this forward's tunnel uses exactly the given
// port mapping string.
func (p *PortForwarder) HasPortMapping(portMap string) bool {
	current := p.tunnel.PortMap()
	return current == portMap
}
// Start initiates a port forward session for a given pod and ports.
//
// The path is split by client.Namespaced into a namespace and a name. Only
// the part of the name before the first "|" is used as the pod name. Start
// records path, tt and the current time before validating anything.
//
// It then:
//   - requires "get" on v1/pods;
//   - requires the pod to be in the Running phase;
//   - requires "create" on v1/pods:portforward.
//
// If all checks pass, it builds a bare core/v1 REST client and hands the
// pods/<name>/portforward POST URL to forwardPorts, listening on tt.Address
// with tt.PortMap(). The returned forwarder is not started here: this method
// does not call ForwardPorts on it.
func (p *PortForwarder) Start(path string, tt port.PortTunnel) (*portforward.PortForwarder, error) {
	p.path, p.tunnel, p.age = path, tt, time.Now()

	ns, n := client.Namespaced(path)
	auth, err := p.Client().CanI(ns, "v1/pods", []string{client.GetVerb})
	if err != nil {
		return nil, err
	}
	if !auth {
		return nil, fmt.Errorf("user is not authorized to get pods")
	}

	podName := strings.Split(n, "|")[0]
	var res Pod
	res.Init(p, client.NewGVR("v1/pods"))
	pod, err := res.GetInstance(client.FQN(ns, podName))
	if err != nil {
		return nil, err
	}
	if pod.Status.Phase != v1.PodRunning {
		return nil, fmt.Errorf("unable to forward port because pod is not running. Current status=%v", pod.Status.Phase)
	}

	auth, err = p.Client().CanI(ns, "v1/pods:portforward", []string{client.CreateVerb})
	if err != nil {
		return nil, err
	}
	if !auth {
		return nil, fmt.Errorf("user is not authorized to update portforward")
	}

	baseCfg, err := p.Client().RestConfig()
	if err != nil {
		return nil, err
	}
	// Work on a copy. The overrides below tailor the config to a bare core/v1
	// client and must not leak into a config that may be shared with other
	// clients.
	cfg := rest.CopyConfig(baseCfg)
	cfg.GroupVersion = &schema.GroupVersion{Group: "", Version: "v1"}
	cfg.APIPath = "/api"
	// Only the serializer is needed here; the parameter codec is not used.
	codecs, _ := codec()
	cfg.NegotiatedSerializer = codecs.WithoutConversion()

	clt, err := rest.RESTClientFor(cfg)
	if err != nil {
		return nil, err
	}
	req := clt.Post().
		Resource("pods").
		Namespace(ns).
		Name(podName).
		SubResource("portforward")

	return p.forwardPorts("POST", req.URL(), tt.Address, tt.PortMap())
}
// forwardPorts builds an SPDY dialer against target using the client's REST
// config and returns a port forwarder that listens on addr for the given port
// mapping. The forwarder is wired to this instance's stop/ready channels and
// its Out/ErrOut streams.
func (p *PortForwarder) forwardPorts(method string, target *url.URL, addr, portMap string) (*portforward.PortForwarder, error) {
	restCfg, err := p.Client().Config().RESTConfig()
	if err != nil {
		return nil, err
	}

	rt, up, err := spdy.RoundTripperFor(restCfg)
	if err != nil {
		return nil, err
	}

	httpClient := &http.Client{Transport: rt}
	d := spdy.NewDialer(up, httpClient, method, target)
	addresses, ports := []string{addr}, []string{portMap}

	return portforward.NewOnAddresses(d, addresses, ports, p.stopChan, p.readyChan, p.Out, p.ErrOut)
}
// ----------------------------------------------------------------------------
// Helpers...
// PortForwardID computes a port-forward identifier by joining its parts with
// "|". When path already contains a "|" (i.e. it carries its own container
// segment), co is omitted and the result is "<path>|<portMap>". Otherwise the
// result is "<path>|<co>|<portMap>".
func PortForwardID(path, co, portMap string) string {
	parts := []string{path}
	if !strings.Contains(path, "|") {
		parts = append(parts, co)
	}
	parts = append(parts, portMap)

	return strings.Join(parts, "|")
}
// codec builds a minimal codec factory and parameter codec over a fresh
// scheme. The scheme:
//   - has metav1.AddToGroupVersion applied for the core ("" group, "v1")
//     group version;
//   - registers Table and TableOptions under both core/v1 and
//     metav1beta1.SchemeGroupVersion.
func codec() (serializer.CodecFactory, runtime.ParameterCodec) {
	s := runtime.NewScheme()
	coreV1 := schema.GroupVersion{Group: "", Version: "v1"}
	metav1.AddToGroupVersion(s, coreV1)

	for _, target := range []schema.GroupVersion{coreV1, metav1beta1.SchemeGroupVersion} {
		s.AddKnownTypes(target, &metav1beta1.Table{}, &metav1beta1.TableOptions{})
	}

	factory := serializer.NewCodecFactory(s)
	params := runtime.NewParameterCodec(s)
	return factory, params
}