From a908fca3a54f8b17131d322045249ce24ca8e184 Mon Sep 17 00:00:00 2001 From: Stefan Majer Date: Tue, 24 Mar 2026 15:24:10 +0100 Subject: [PATCH] Simplify client tests --- go/client/client_test.go | 186 ++++++++++++----------------- go/client/test_interceptor.go | 72 +++++++++++ go/client/test_interceptor_test.go | 51 ++++++++ 3 files changed, 201 insertions(+), 108 deletions(-) create mode 100644 go/client/test_interceptor.go create mode 100644 go/client/test_interceptor_test.go diff --git a/go/client/client_test.go b/go/client/client_test.go index 686439b7..3dc2399e 100644 --- a/go/client/client_test.go +++ b/go/client/client_test.go @@ -12,6 +12,7 @@ import ( "os" "strings" "testing" + "testing/synctest" "time" @@ -19,7 +20,6 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/metal-stack/api/go/client" apiv2 "github.com/metal-stack/api/go/metalstack/api/v2" - "github.com/metal-stack/api/go/metalstack/api/v2/apiv2connect" infrav2 "github.com/metal-stack/api/go/metalstack/infra/v2" "github.com/metal-stack/api/go/metalstack/infra/v2/infrav2connect" "github.com/stretchr/testify/require" @@ -27,60 +27,88 @@ import ( func Test_Client(t *testing.T) { var ( - vs = &mockVersionService{} - ts = &mockTokenService{} - mux = http.NewServeMux() log = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) ) - mux.Handle(apiv2connect.NewVersionServiceHandler(vs)) - mux.Handle(apiv2connect.NewTokenServiceHandler(ts)) - server := httptest.NewTLSServer(mux) - server.EnableHTTP2 = true - defer func() { - server.Close() - }() - - tokenString, err := generateToken(2 * time.Second) - require.NoError(t, err) - - c, err := client.New(&client.DialConfig{ - BaseURL: server.URL, - Token: tokenString, - Transport: server.Client().Transport, - TokenRenewal: &client.TokenRenewal{ - PersistTokenFn: func(token string) error { - ts.token = token - t.Log("token persisted:", token) - return nil + synctest.Test(t, func(t *testing.T) { + tokenString, err := generateToken(2 * time.Second) + require.NoError(t, err) + var renewedToken string + + c, err := client.New(&client.DialConfig{ + BaseURL: "http://localhost", + Token: tokenString, + + Interceptors: []connect.Interceptor{ + client.NewTestInterceptor(t, []client.ClientCall{ + { + WantRequest: &apiv2.VersionServiceGetRequest{}, + WantResponse: func() connect.AnyResponse { + return connect.NewResponse(&apiv2.VersionServiceGetResponse{ + Version: &apiv2.Version{Version: "1.0"}, + }) + }, + }, + { + WantRequest: &apiv2.VersionServiceGetRequest{}, + WantResponse: func() connect.AnyResponse { + return connect.NewResponse(&apiv2.VersionServiceGetResponse{ + Version: &apiv2.Version{Version: "1.0"}, + }) + }, + }, + { + WantRequest: &apiv2.TokenServiceRefreshRequest{}, + WantResponse: func() connect.AnyResponse { + tokenString, err := generateToken(2 * time.Second) + require.NoError(t, err) + + return connect.NewResponse(&apiv2.TokenServiceRefreshResponse{ + Secret: tokenString, + }) + }, + }, + { + WantRequest: &apiv2.VersionServiceGetRequest{}, + WantResponse: func() connect.AnyResponse { + return connect.NewResponse(&apiv2.VersionServiceGetResponse{ + Version: &apiv2.Version{Version: "1.0"}, + }) + }, + }, + }), + }, + TokenRenewal: &client.TokenRenewal{ + PersistTokenFn: func(token string) error { + renewedToken = token + return nil + }, }, - }, - Log: log, + Log: log, + }) + + require.NoError(t, err) + v, err := c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) + require.NoError(t, err) + require.NotNil(t, v) + require.Equal(t, "1.0", v.Version.Version) + require.Empty(t, renewedToken) + + time.Sleep(1 * time.Second) + v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) + require.NoError(t, err) + require.NotNil(t, v) + require.Equal(t, "1.0", v.Version.Version) + require.Empty(t, renewedToken) + + time.Sleep(3 * time.Second) + v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) + require.NoError(t, err) + require.NotNil(t, v) + require.Equal(t, "1.0", v.Version.Version) + require.NotEmpty(t, renewedToken) + require.NotEqual(t, renewedToken, tokenString, "haven't changed") }) - require.NoError(t, err) - v, err := c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) - require.NoError(t, err) - require.NotNil(t, v) - require.Equal(t, "1.0", v.Version.Version) - require.False(t, ts.wasCalled) - require.Equal(t, tokenString, vs.token) - - time.Sleep(1 * time.Second) - v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) - require.NoError(t, err) - require.NotNil(t, v) - require.Equal(t, "1.0", v.Version.Version) - require.False(t, ts.wasCalled) - require.Equal(t, tokenString, vs.token) - - time.Sleep(1 * time.Second) - v, err = c.Apiv2().Version().Get(t.Context(), &apiv2.VersionServiceGetRequest{}) - require.NoError(t, err) - require.NotNil(t, v) - require.Equal(t, "1.0", v.Version.Version) - - require.True(t, ts.wasCalled) - require.NotEqual(t, tokenString, ts.token, "token must have changed") } func generateToken(duration time.Duration) (string, error) { @@ -102,64 +130,6 @@ func generateToken(duration time.Duration) (string, error) { return tokenString, nil } -type mockVersionService struct { - token string -} - -func (m *mockVersionService) Get(ctx context.Context, req *apiv2.VersionServiceGetRequest) (*apiv2.VersionServiceGetResponse, error) { - callinfo, _ := connect.CallInfoForHandlerContext(ctx) - authHeader := callinfo.RequestHeader().Get("Authorization") - - _, token, found := strings.Cut(authHeader, "Bearer ") - - if !found { - return nil, fmt.Errorf("unable to extract token from header:%s", authHeader) - } - - m.token = token - return &apiv2.VersionServiceGetResponse{Version: &apiv2.Version{Version: "1.0"}}, nil -} - -type mockTokenService struct { - wasCalled bool - token string -} - -// Create implements apiv2connect.TokenServiceHandler. -func (m *mockTokenService) Create(context.Context, *apiv2.TokenServiceCreateRequest) (*apiv2.TokenServiceCreateResponse, error) { - panic("unimplemented") -} - -// Get implements apiv2connect.TokenServiceHandler. -func (m *mockTokenService) Get(context.Context, *apiv2.TokenServiceGetRequest) (*apiv2.TokenServiceGetResponse, error) { - panic("unimplemented") -} - -// List implements apiv2connect.TokenServiceHandler. -func (m *mockTokenService) List(context.Context, *apiv2.TokenServiceListRequest) (*apiv2.TokenServiceListResponse, error) { - panic("unimplemented") -} - -// Refresh implements apiv2connect.TokenServiceHandler. -func (m *mockTokenService) Refresh(ctx context.Context, _ *apiv2.TokenServiceRefreshRequest) (*apiv2.TokenServiceRefreshResponse, error) { - token, err := generateToken(2 * time.Second) - if err != nil { - return nil, err - } - m.wasCalled = true - return &apiv2.TokenServiceRefreshResponse{Token: &apiv2.Token{}, Secret: token}, nil -} - -// Revoke implements apiv2connect.TokenServiceHandler. -func (m *mockTokenService) Revoke(context.Context, *apiv2.TokenServiceRevokeRequest) (*apiv2.TokenServiceRevokeResponse, error) { - panic("unimplemented") -} - -// Update implements apiv2connect.TokenServiceHandler. -func (m *mockTokenService) Update(context.Context, *apiv2.TokenServiceUpdateRequest) (*apiv2.TokenServiceUpdateResponse, error) { - panic("unimplemented") -} - func Test_ClientInterceptors(t *testing.T) { var ( bs = &mockBMCService{} diff --git a/go/client/test_interceptor.go b/go/client/test_interceptor.go new file mode 100644 index 00000000..6724d5d6 --- /dev/null +++ b/go/client/test_interceptor.go @@ -0,0 +1,72 @@ +package client + +import ( + "context" + "reflect" + "testing" + + "connectrpc.com/connect" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/runtime/protoimpl" + "google.golang.org/protobuf/testing/protocmp" +) + +type TestClientInterceptor struct { + t *testing.T + calls []ClientCall + count int +} + +type ClientCall struct { + WantRequest proto.Message + WantResponse func() connect.AnyResponse + WantError *connect.Error +} + +func NewTestInterceptor(t *testing.T, calls []ClientCall) *TestClientInterceptor { + return &TestClientInterceptor{ + t: t, + calls: calls, + } +} + +func (t *TestClientInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { + defer func() { t.count++ }() + + if t.count >= len(t.calls) { + t.t.Errorf("received an unexpected client call of type %T: %v", ar.Any(), ar.Any()) + t.t.FailNow() + } + + call := t.calls[t.count] + + if diff := cmp.Diff(call.WantRequest, ar.Any(), protocmp.Transform(), IgnoreUnexported(), cmpopts.IgnoreTypes(protoimpl.MessageState{})); diff != "" { + t.t.Errorf("request diff (+got -want):\n %s", diff) + t.t.FailNow() + } + + if call.WantError != nil { + return nil, call.WantError + } + + return call.WantResponse(), nil + } +} + +func (t *TestClientInterceptor) WrapStreamingClient(connect.StreamingClientFunc) connect.StreamingClientFunc { + t.t.Errorf("streaming not supported") + return nil +} + +func (t *TestClientInterceptor) WrapStreamingHandler(connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + t.t.Errorf("streaming not supported") + return nil +} + +func IgnoreUnexported() cmp.Option { + // the exporter opt allows all unexported fields: https://github.com/google/go-cmp/pull/176 + return cmp.Exporter(func(reflect.Type) bool { return true }) +} diff --git a/go/client/test_interceptor_test.go b/go/client/test_interceptor_test.go new file mode 100644 index 00000000..c8f938a3 --- /dev/null +++ b/go/client/test_interceptor_test.go @@ -0,0 +1,51 @@ +package client_test + +import ( + "log/slog" + "testing" + + "connectrpc.com/connect" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + client "github.com/metal-stack/api/go/client" + apiv2 "github.com/metal-stack/api/go/metalstack/api/v2" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/runtime/protoimpl" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestInterceptor(t *testing.T) { + cl, err := client.New(&client.DialConfig{ + BaseURL: "http://this-is-just-for-testing", + Interceptors: []connect.Interceptor{ + client.NewTestInterceptor(t, []client.ClientCall{ + { + WantRequest: &apiv2.IPServiceGetRequest{ + Ip: "1.2.3.4", + }, + WantResponse: func() connect.AnyResponse { + return connect.NewResponse(&apiv2.IPServiceGetResponse{ + Ip: &apiv2.IP{Ip: "1.2.3.4"}, + }) + }, + }, + }), + }, + UserAgent: "cli-test", + Log: slog.Default(), + }) + require.NoError(t, err) + + resp, err := cl.Apiv2().IP().Get(t.Context(), &apiv2.IPServiceGetRequest{ + Ip: "1.2.3.4", + }) + require.NoError(t, err) + + if diff := cmp.Diff(&apiv2.IPServiceGetResponse{ + Ip: &apiv2.IP{ + Ip: "1.2.3.4", + }, + }, resp, protocmp.Transform(), client.IgnoreUnexported(), cmpopts.IgnoreTypes(protoimpl.MessageState{})); diff != "" { + t.Errorf("diff = %s", diff) + } +}