Allow multiple port forwards for a single Pod (#1804)

mine
Grzegorz Burzyński 2022-10-18 14:47:34 +02:00 committed by GitHub
parent f25d7865a6
commit 0cefb3ec12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 116 additions and 10 deletions

View File

@ -47,19 +47,10 @@ func (p *PortForwardExtender) portFwdCmd(evt *tcell.EventKey) *tcell.EventKey {
p.App().Flash().Err(err)
return nil
}
pod, err := fetchPod(p.App().factory, podName)
if err != nil {
if err := ensurePodPortFwdAllowed(p.App().factory, podName); err != nil {
p.App().Flash().Err(err)
return nil
}
if pod.Status.Phase != v1.PodRunning {
p.App().Flash().Errf("pod must be running. Current status=%v", pod.Status.Phase)
return nil
}
if p.App().factory.Forwarders().IsPodForwarded(path) {
p.App().Flash().Errf("A PortForward already exists for pod %s", pod.Name)
return nil
}
if err := showFwdDialog(p, podName, startFwdCB); err != nil {
p.App().Flash().Err(err)
}
@ -83,6 +74,18 @@ func (p *PortForwardExtender) fetchPodName(path string) (string, error) {
// ----------------------------------------------------------------------------
// Helpers...
func ensurePodPortFwdAllowed(factory dao.Factory, podName string) error {
pod, err := fetchPod(factory, podName)
if err != nil {
return err
}
if pod.Status.Phase != v1.PodRunning {
return fmt.Errorf("pod must be running. Current status=%v", pod.Status.Phase)
}
return nil
}
func runForward(v ResourceViewer, pf watch.Forwarder, f *portforward.PortForwarder) {
v.App().factory.AddForwarder(pf)

View File

@ -0,0 +1,103 @@
package view
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/informers"
"github.com/derailed/k9s/internal/client"
"github.com/derailed/k9s/internal/dao"
"github.com/derailed/k9s/internal/watch"
)
func TestEnsurePodPortFwdAllowed(t *testing.T) {
testCases := []struct {
name string
podExists bool
podPhase corev1.PodPhase
expectError bool
}{
{
name: "pod_doesnt_exist",
expectError: true,
},
{
name: "pod_exists_pending",
podExists: true,
podPhase: corev1.PodPending,
expectError: true,
},
{
name: "pod_is_running",
podExists: true,
podPhase: corev1.PodRunning,
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
f := testFactory{}
if tc.podExists {
f.expectedGet = &unstructured.Unstructured{
Object: map[string]interface{}{
"status": map[string]interface{}{
"phase": tc.podPhase,
},
},
}
}
err := ensurePodPortFwdAllowed(f, "ns/name")
if tc.expectError {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
type testFactory struct {
expectedGet runtime.Object
}
var _ dao.Factory = testFactory{}
func (t testFactory) Client() client.Connection {
return nil
}
func (t testFactory) Get(string, string, bool, labels.Selector) (runtime.Object, error) {
if t.expectedGet != nil {
return t.expectedGet, nil
}
return nil, errors.New("not found")
}
func (t testFactory) List(string, string, bool, labels.Selector) ([]runtime.Object, error) {
return nil, nil
}
func (t testFactory) ForResource(string, string) (informers.GenericInformer, error) {
return nil, nil
}
func (t testFactory) CanForResource(string, string, []string) (informers.GenericInformer, error) {
return nil, nil
}
func (t testFactory) Forwarders() watch.Forwarders {
return nil
}
func (t testFactory) WaitForCacheSync() {}
func (t testFactory) DeleteForwarder(string) {}