From 922ed5d6ae5d16b6bce467f6f6b64cc54d59c5fe Mon Sep 17 00:00:00 2001 From: Mallory Hill Date: Mon, 22 Jun 2026 13:39:21 -0400 Subject: [PATCH] HYPERFLEET-1259 - fix: SQL injection protection for orderBy/search --- pkg/db/sql_helpers.go | 137 +++++++++++---- pkg/db/sql_helpers_test.go | 4 +- pkg/services/generic.go | 13 +- pkg/services/generic_test.go | 6 +- .../resource_sql_injection_test.go | 160 ++++++++++++++++++ 5 files changed, 280 insertions(+), 40 deletions(-) create mode 100644 test/integration/resource_sql_injection_test.go diff --git a/pkg/db/sql_helpers.go b/pkg/db/sql_helpers.go index d164c0e1..257cb188 100755 --- a/pkg/db/sql_helpers.go +++ b/pkg/db/sql_helpers.go @@ -61,8 +61,62 @@ var statusFieldMappings = map[string]string{ "status.conditions": "status_conditions", } +// OrderByAllowedFields defines valid columns for orderBy and search operations per resource type. +// When adding new GORM columns to model structs, update this map to make them sortable/searchable. +var OrderByAllowedFields = map[string]map[string]bool{ + "Cluster": { + "id": true, + "name": true, + "created_time": true, + "updated_time": true, + "deleted_time": true, + "kind": true, + "created_by": true, + "updated_by": true, + "deleted_by": true, + "generation": true, + "href": true, + "status_conditions": true, // mapped from status.conditions + }, + "NodePool": { + "id": true, + "name": true, + "created_time": true, + "updated_time": true, + "deleted_time": true, + "kind": true, + "created_by": true, + "updated_by": true, + "deleted_by": true, + "generation": true, + "href": true, + "owner_id": true, + "owner_kind": true, + "owner_href": true, + "status_conditions": true, + }, + "Resource": { + "id": true, + "name": true, + "created_time": true, + "updated_time": true, + "deleted_time": true, + "kind": true, + "created_by": true, + "updated_by": true, + "deleted_by": true, + "generation": true, + "href": true, + "owner_id": true, + "owner_kind": true, + "owner_href": true, + }, +} + // getField gets the sql field associated with a name. -func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) { +func getField( + name string, disallowedFields map[string]string, allowedFields map[string]bool, +) (field string, err *errors.ServiceError) { // We want to accept names with trailing and leading spaces trimmedName := strings.Trim(name, " ") @@ -126,6 +180,13 @@ func getField(name string, disallowedFields map[string]string) (field string, er err = errors.BadRequest("%s is not a valid field name", name) return } + + // Validate field name against allowlist to prevent SQL injection + if allowedFields != nil && !allowedFields[trimmedName] { + err = errors.BadRequest("field '%s' is not valid for ordering or searching", name) + return + } + field = trimmedName return } @@ -569,7 +630,8 @@ func extractConditionsWalk(n tsl.Node, conditions *[]sq.Sqlizer) (tsl.Node, *err // b. replace the field name with the SQL column name. func FieldNameWalk( n tsl.Node, - disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) { + disallowedFields map[string]string, + allowedFields map[string]bool) (newNode tsl.Node, err *errors.ServiceError) { var field string var l, r tsl.Node @@ -591,7 +653,7 @@ func FieldNameWalk( } // Check field name in the disallowedFields field names. - field, err = getField(userFieldName, disallowedFields) + field, err = getField(userFieldName, disallowedFields, allowedFields) if err != nil { return } @@ -609,7 +671,7 @@ func FieldNameWalk( err = errors.BadRequest("invalid node structure") return } - l, err = FieldNameWalk(leftNode, disallowedFields) + l, err = FieldNameWalk(leftNode, disallowedFields, allowedFields) if err != nil { return } @@ -620,7 +682,7 @@ func FieldNameWalk( switch v := n.Right.(type) { case tsl.Node: // It's a regular node, just add it. - r, err = FieldNameWalk(v, disallowedFields) + r, err = FieldNameWalk(v, disallowedFields, allowedFields) if err != nil { return } @@ -634,7 +696,7 @@ func FieldNameWalk( // Add all nodes in the right side array. for _, e := range v { - r, err = FieldNameWalk(e, disallowedFields) + r, err = FieldNameWalk(e, disallowedFields, allowedFields) if err != nil { return } @@ -661,28 +723,37 @@ func FieldNameWalk( } // cleanOrderBy takes the orderBy arg and cleans it. -func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) { - var orderField string - +func cleanOrderBy( + userArg string, disallowedFields map[string]string, allowedFields map[string]bool, +) (orderBy string, err *errors.ServiceError) { // We want to accept user params with trailing and leading spaces - trimedName := strings.Trim(userArg, " ") + trimmedName := strings.Trim(userArg, " ") // Each OrderBy can be a "" or a " asc|desc" - order := strings.Split(trimedName, " ") - direction := "none valid" + order := strings.Split(trimmedName, " ") - if len(order) == 1 { - orderField, err = getField(order[0], disallowedFields) - direction = "asc" - } else if len(order) == 2 { - orderField, err = getField(order[0], disallowedFields) - direction = order[1] - } - if err != nil || (direction != "asc" && direction != "desc") { + // Reject invalid format (e.g., subqueries with multiple spaces) + if len(order) != 1 && len(order) != 2 { err = errors.BadRequest("bad order value '%s'", userArg) return } + // Validate field name + orderField, err := getField(order[0], disallowedFields, allowedFields) + if err != nil { + return "", err + } + + // Determine direction (default to asc) + direction := "asc" + if len(order) == 2 { + direction = strings.ToLower(order[1]) + if direction != "asc" && direction != "desc" { + err = errors.BadRequest("bad order value '%s'", userArg) + return + } + } + orderBy = fmt.Sprintf("%s %s", orderField, direction) return } @@ -690,22 +761,22 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s // ArgsToOrderBy returns cleaned orderBy list. func ArgsToOrderBy( orderByArgs []string, - disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) { - - var order string - if len(orderByArgs) != 0 { - orderBy = []string{} - for _, o := range orderByArgs { - order, err = cleanOrderBy(o, disallowedFields) - if err != nil { - return - } + disallowedFields map[string]string, + allowedFields map[string]bool, +) (orderBy []string, err *errors.ServiceError) { + if len(orderByArgs) == 0 { + return nil, nil + } - // If valid add the user entered order by, to the order by list - orderBy = append(orderBy, order) + orderBy = make([]string, 0, len(orderByArgs)) + for _, arg := range orderByArgs { + order, err := cleanOrderBy(arg, disallowedFields, allowedFields) + if err != nil { + return nil, err } + orderBy = append(orderBy, order) } - return + return orderBy, nil } func GetTableName(g2 *gorm.DB) string { diff --git a/pkg/db/sql_helpers_test.go b/pkg/db/sql_helpers_test.go index f846065c..d00c2905 100644 --- a/pkg/db/sql_helpers_test.go +++ b/pkg/db/sql_helpers_test.go @@ -703,7 +703,7 @@ func TestGetField_SpecMapping(t *testing.T) { t.Run(tt.name, func(t *testing.T) { RegisterTestingT(t) - field, err := getField(tt.input, map[string]string{}) + field, err := getField(tt.input, map[string]string{}, nil) if tt.expectError { Expect(err).ToNot(BeNil()) } else { @@ -719,7 +719,7 @@ func TestGetField_SpecDisallowed(t *testing.T) { disallowed := map[string]string{"spec": "spec"} - _, err := getField("spec.is_default", disallowed) + _, err := getField("spec.is_default", disallowed, nil) Expect(err).ToNot(BeNil()) Expect(err.Reason).To(ContainSubstring("not a valid field name")) } diff --git a/pkg/services/generic.go b/pkg/services/generic.go index 864b3627..af1edb88 100755 --- a/pkg/services/generic.go +++ b/pkg/services/generic.go @@ -58,6 +58,7 @@ type listContext struct { args *ListArguments pagingMeta *api.PagingMeta disallowedFields *map[string]string + allowedFields *map[string]bool joins map[string]dao.TableRelation set map[string]bool resourceType string @@ -79,6 +80,13 @@ func (s *sqlGenericService) newListContext( if disallowedFields == nil { disallowedFields = allFieldsAllowed } + + allowedFields := db.OrderByAllowedFields[resourceTypeStr] + + if allowedFields == nil { + return nil, nil, errors.GeneralError("Could not determine what resource type to order by") + } + args.Search = strings.Trim(args.Search, " ") return &listContext{ ctx: ctx, @@ -86,6 +94,7 @@ func (s *sqlGenericService) newListContext( pagingMeta: &api.PagingMeta{Page: args.Page}, resourceList: resourceList, disallowedFields: &disallowedFields, + allowedFields: &allowedFields, resourceType: resourceTypeStr, }, reflect.New(resourceModel).Interface(), nil } @@ -150,7 +159,7 @@ func (s *sqlGenericService) buildPreload(listCtx *listContext, d *dao.GenericDao func (s *sqlGenericService) buildOrderBy(listCtx *listContext, d *dao.GenericDao) (bool, *errors.ServiceError) { if len(listCtx.args.OrderBy) != 0 { - orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, *listCtx.disallowedFields) + orderByArgs, serviceErr := db.ArgsToOrderBy(listCtx.args.OrderBy, *listCtx.disallowedFields, *listCtx.allowedFields) if serviceErr != nil { return false, serviceErr } @@ -196,7 +205,7 @@ func (s *sqlGenericService) buildSearchValues( // apply field name mapping first (status.xxx -> status_xxx, labels.xxx -> labels->>'xxx') // this must happen before treeWalkForRelatedTables to prevent treating "status" and "labels" as related resources - tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields) + tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields, *listCtx.allowedFields) if serviceErr != nil { return "", nil, serviceErr } diff --git a/pkg/services/generic_test.go b/pkg/services/generic_test.go index 607ab70c..b8a818c3 100755 --- a/pkg/services/generic_test.go +++ b/pkg/services/generic_test.go @@ -54,8 +54,8 @@ func TestSQLTranslation(t *testing.T) { // tests for sql parsing tests = []map[string]interface{}{ { - "search": "username in ('ooo.openshift')", - "sql": "username IN (?)", + "search": "created_by in ('ooo.openshift')", + "sql": "created_by IN (?)", "values": ConsistOf("ooo.openshift"), }, // Test status.conditions field mapping (use status.conditions.='' syntax for condition queries) @@ -85,7 +85,7 @@ func TestSQLTranslation(t *testing.T) { Expect(err).ToNot(HaveOccurred()) // Apply field name mapping (status.xxx -> status_xxx, labels.xxx -> labels->>'xxx') // This must happen before converting to sqlizer - tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields) + tslTree, serviceErr = db.FieldNameWalk(tslTree, *listCtx.disallowedFields, *listCtx.allowedFields) Expect(serviceErr).ToNot(HaveOccurred()) sqlizer, serviceErr := genericService.treeWalkForSqlizer(listCtx, tslTree) Expect(serviceErr).ToNot(HaveOccurred()) diff --git a/test/integration/resource_sql_injection_test.go b/test/integration/resource_sql_injection_test.go new file mode 100644 index 00000000..a43c7430 --- /dev/null +++ b/test/integration/resource_sql_injection_test.go @@ -0,0 +1,160 @@ +package integration + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/onsi/gomega" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api/openapi" + "github.com/openshift-hyperfleet/hyperfleet-api/test" +) + +func TestClusterOrderBySQLInjection(t *testing.T) { + h, client := test.RegisterIntegration(t) + gomega.RegisterTestingT(t) + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + // Create test data + _, err := h.Factories.NewClustersList("injection-test", 3) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + // Test 1: SQL function injection (pg_sleep) + orderBy := openapi.QueryParamsOrderBy("pg_sleep(5)") + params := &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + var errorResp openapi.ProblemDetails + err = json.Unmarshal(resp.Body, &errorResp) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(*errorResp.Code).To(gomega.Equal("HYPERFLEET-VAL-005")) + gomega.Expect(*errorResp.Detail).To(gomega.ContainSubstring("not valid for ordering")) + + // Test 2: PostgreSQL system identifier (current_user) + orderBy = openapi.QueryParamsOrderBy("current_user") + params = &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 3: Function execution (version) + orderBy = openapi.QueryParamsOrderBy("version()") + params = &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 4: Random function + orderBy = openapi.QueryParamsOrderBy("random()") + params = &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 5: Nonexistent column + orderBy = openapi.QueryParamsOrderBy("nonexistent_column") + params = &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 6: Valid fields still work + orderBy = openapi.QueryParamsOrderBy("name desc") + params = &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusOK)) + + // Test 7: Valid timestamp field + orderBy = openapi.QueryParamsOrderBy("created_time desc") + params = &openapi.GetClustersParams{OrderBy: &orderBy} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusOK)) +} + +func TestNodePoolOrderBySQLInjection(t *testing.T) { + h, client := test.RegisterIntegration(t) + gomega.RegisterTestingT(t) + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + // Create test data - nodepools + _, err := h.Factories.NewNodePoolsList("nodepool-injection-test", 3) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + // Test 1: SQL function injection (pg_sleep) + orderBy := openapi.QueryParamsOrderBy("pg_sleep(5)") + params := &openapi.GetNodePoolsParams{OrderBy: &orderBy} + resp, err := client.GetNodePoolsWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + var errorResp openapi.ProblemDetails + err = json.Unmarshal(resp.Body, &errorResp) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(*errorResp.Code).To(gomega.Equal("HYPERFLEET-VAL-005")) + gomega.Expect(*errorResp.Detail).To(gomega.ContainSubstring("not valid for ordering")) + + // Test 2: Nonexistent column + orderBy = openapi.QueryParamsOrderBy("malicious_column") + params = &openapi.GetNodePoolsParams{OrderBy: &orderBy} + resp, err = client.GetNodePoolsWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 3: Valid field specific to NodePool (owner_id) + orderBy = openapi.QueryParamsOrderBy("owner_id") + params = &openapi.GetNodePoolsParams{OrderBy: &orderBy} + resp, err = client.GetNodePoolsWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusOK)) + + // Test 4: Valid common field (name) + orderBy = openapi.QueryParamsOrderBy("name desc") + params = &openapi.GetNodePoolsParams{OrderBy: &orderBy} + resp, err = client.GetNodePoolsWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusOK)) +} + +func TestSearchSQLInjection(t *testing.T) { + h, client := test.RegisterIntegration(t) + gomega.RegisterTestingT(t) + account := h.NewRandAccount() + ctx := h.NewAuthenticatedContext(account) + + // Create test data + _, err := h.Factories.NewClustersList("search-test", 2) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + // Test 1: SQL function in search parameter + search := openapi.SearchParams("pg_sleep(5)='True'") + params := &openapi.GetClustersParams{Search: &search} + resp, err := client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 2: System identifier in search + search = openapi.SearchParams("current_user='admin'") + params = &openapi.GetClustersParams{Search: &search} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusBadRequest)) + + // Test 3: Valid search query + search = openapi.SearchParams("name='search-test-0'") + params = &openapi.GetClustersParams{Search: &search} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusOK)) + + // Test 4: Valid status.conditions search + search = openapi.SearchParams("status.conditions.Reconciled='False'") + params = &openapi.GetClustersParams{Search: &search} + resp, err = client.GetClustersWithResponse(ctx, params, test.WithAuthToken(ctx)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(resp.StatusCode()).To(gomega.Equal(http.StatusOK)) +}