From cb040cb6703884a13cedf75f0a9fe214a090d97e Mon Sep 17 00:00:00 2001 From: kartik Date: Mon, 8 Jun 2026 15:56:14 +0530 Subject: [PATCH 1/9] feat: add PostgreSQL parser and reporting capabilities - Implemented a new PostgreSQL parser in `pgparser` that provides structural analysis of SQL queries, including support for SELECT, INSERT, UPDATE, and DELETE statements. - Added unit tests for the PostgreSQL parser to ensure accurate parsing and structural fact extraction. - Introduced a `ConsoleReporter` for outputting analysis results to the terminal with color-coded severity levels. - Created a `JSONReporter` for outputting analysis results in JSON format. - Defined a `Reporter` interface for consistent reporting of analysis results across different formats. - Updated the main `sqlguard` package to support the new PostgreSQL parser and reporting features. --- .github/ISSUE_TEMPLATE/bug_report.md | 43 ++ .github/ISSUE_TEMPLATE/config.yml | 8 + .github/ISSUE_TEMPLATE/feature_request.md | 31 ++ .github/PULL_REQUEST_TEMPLATE.md | 27 + .github/dependabot.yml | 56 +++ .github/workflows/ci.yml | 97 ++++ .github/workflows/codeql.yml | 60 +++ .gitignore | 4 + .golangci.yml | 33 ++ .sqlguard.example.yml | 73 +++ CHANGELOG.md | 61 +++ CONTRIBUTING.md | 96 ++++ Makefile | 194 ++++++++ README.md | 514 +++++++++++++++++++ SECURITY.md | 56 +++ analyzer/analyzer.go | 165 +++++++ analyzer/analyzer_test.go | 434 ++++++++++++++++ analyzer/fallback.go | 575 ++++++++++++++++++++++ analyzer/fallback_test.go | 111 +++++ analyzer/parser.go | 18 + analyzer/profile_test.go | 144 ++++++ analyzer/redact.go | 157 ++++++ analyzer/redact_policy_test.go | 51 ++ analyzer/redact_test.go | 115 +++++ analyzer/registry.go | 151 ++++++ analyzer/result.go | 19 + analyzer/rules.go | 269 ++++++++++ analyzer/severity.go | 26 + analyzer/statement.go | 138 ++++++ analyzer/suppress.go | 69 +++ cmd/sqlguard/db.go | 31 ++ cmd/sqlguard/explain.go | 87 ++++ cmd/sqlguard/main.go | 18 + cmd/sqlguard/resolve_test.go | 178 +++++++ cmd/sqlguard/root.go | 58 +++ cmd/sqlguard/scan.go | 398 +++++++++++++++ cmd/sqlguard/scan_test.go | 399 +++++++++++++++ codecov.yml | 31 ++ config/config.go | 317 ++++++++++++ config/config_test.go | 197 ++++++++ config/middleware.go | 61 +++ config/middleware_test.go | 50 ++ explain/explain.go | 295 +++++++++++ explain/explain_test.go | 52 ++ go.mod | 17 + go.sum | 24 + integrations/bunguard/bunguard.go | 75 +++ integrations/bunguard/bunguard_test.go | 185 +++++++ integrations/bunguard/go.mod | 21 + integrations/bunguard/go.sum | 26 + integrations/entguard/entguard.go | 123 +++++ integrations/entguard/entguard_test.go | 193 ++++++++ integrations/entguard/go.mod | 13 + integrations/entguard/go.sum | 16 + integrations/gormguard/go.mod | 18 + integrations/gormguard/go.sum | 12 + integrations/gormguard/gormguard.go | 134 +++++ integrations/gormguard/gormguard_test.go | 200 ++++++++ integrations/pgxguard/go.mod | 19 + integrations/pgxguard/go.sum | 30 ++ integrations/pgxguard/pgxguard.go | 145 ++++++ integrations/pgxguard/pgxguard_test.go | 247 ++++++++++ integrations/sqlxguard/go.mod | 12 + integrations/sqlxguard/go.sum | 11 + integrations/sqlxguard/sqlxguard.go | 157 ++++++ integrations/sqlxguard/sqlxguard_test.go | 187 +++++++ integrations/xormguard/go.mod | 18 + integrations/xormguard/go.sum | 94 ++++ integrations/xormguard/xormguard.go | 74 +++ integrations/xormguard/xormguard_test.go | 166 +++++++ middleware/cache.go | 84 ++++ middleware/cache_test.go | 125 +++++ middleware/dedup.go | 74 +++ middleware/dedup_test.go | 163 ++++++ middleware/driver.go | 390 +++++++++++++++ middleware/driver_fallback_test.go | 111 +++++ middleware/driver_test.go | 265 ++++++++++ middleware/guard.go | 137 ++++++ middleware/n_plus_one.go | 124 +++++ middleware/n_plus_one_test.go | 168 +++++++ middleware/options.go | 98 ++++ parsers/mysqlparser/go.mod | 9 + parsers/mysqlparser/go.sum | 2 + parsers/mysqlparser/mysqlparser.go | 136 +++++ parsers/mysqlparser/mysqlparser_test.go | 136 +++++ parsers/pgparser/go.mod | 35 ++ parsers/pgparser/go.sum | 347 +++++++++++++ parsers/pgparser/pgparser.go | 149 ++++++ parsers/pgparser/pgparser_test.go | 137 ++++++ reporter/console.go | 57 +++ reporter/console_test.go | 86 ++++ reporter/json.go | 56 +++ reporter/json_test.go | 77 +++ reporter/reporter.go | 8 + sqlguard.go | 40 ++ 95 files changed, 11198 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/codeql.yml create mode 100644 .golangci.yml create mode 100644 .sqlguard.example.yml create mode 100644 CHANGELOG.md create mode 100644 CONTRIBUTING.md create mode 100644 Makefile create mode 100644 README.md create mode 100644 SECURITY.md create mode 100644 analyzer/analyzer.go create mode 100644 analyzer/analyzer_test.go create mode 100644 analyzer/fallback.go create mode 100644 analyzer/fallback_test.go create mode 100644 analyzer/parser.go create mode 100644 analyzer/profile_test.go create mode 100644 analyzer/redact.go create mode 100644 analyzer/redact_policy_test.go create mode 100644 analyzer/redact_test.go create mode 100644 analyzer/registry.go create mode 100644 analyzer/result.go create mode 100644 analyzer/rules.go create mode 100644 analyzer/severity.go create mode 100644 analyzer/statement.go create mode 100644 analyzer/suppress.go create mode 100644 cmd/sqlguard/db.go create mode 100644 cmd/sqlguard/explain.go create mode 100644 cmd/sqlguard/main.go create mode 100644 cmd/sqlguard/resolve_test.go create mode 100644 cmd/sqlguard/root.go create mode 100644 cmd/sqlguard/scan.go create mode 100644 cmd/sqlguard/scan_test.go create mode 100644 codecov.yml create mode 100644 config/config.go create mode 100644 config/config_test.go create mode 100644 config/middleware.go create mode 100644 config/middleware_test.go create mode 100644 explain/explain.go create mode 100644 explain/explain_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 integrations/bunguard/bunguard.go create mode 100644 integrations/bunguard/bunguard_test.go create mode 100644 integrations/bunguard/go.mod create mode 100644 integrations/bunguard/go.sum create mode 100644 integrations/entguard/entguard.go create mode 100644 integrations/entguard/entguard_test.go create mode 100644 integrations/entguard/go.mod create mode 100644 integrations/entguard/go.sum create mode 100644 integrations/gormguard/go.mod create mode 100644 integrations/gormguard/go.sum create mode 100644 integrations/gormguard/gormguard.go create mode 100644 integrations/gormguard/gormguard_test.go create mode 100644 integrations/pgxguard/go.mod create mode 100644 integrations/pgxguard/go.sum create mode 100644 integrations/pgxguard/pgxguard.go create mode 100644 integrations/pgxguard/pgxguard_test.go create mode 100644 integrations/sqlxguard/go.mod create mode 100644 integrations/sqlxguard/go.sum create mode 100644 integrations/sqlxguard/sqlxguard.go create mode 100644 integrations/sqlxguard/sqlxguard_test.go create mode 100644 integrations/xormguard/go.mod create mode 100644 integrations/xormguard/go.sum create mode 100644 integrations/xormguard/xormguard.go create mode 100644 integrations/xormguard/xormguard_test.go create mode 100644 middleware/cache.go create mode 100644 middleware/cache_test.go create mode 100644 middleware/dedup.go create mode 100644 middleware/dedup_test.go create mode 100644 middleware/driver.go create mode 100644 middleware/driver_fallback_test.go create mode 100644 middleware/driver_test.go create mode 100644 middleware/guard.go create mode 100644 middleware/n_plus_one.go create mode 100644 middleware/n_plus_one_test.go create mode 100644 middleware/options.go create mode 100644 parsers/mysqlparser/go.mod create mode 100644 parsers/mysqlparser/go.sum create mode 100644 parsers/mysqlparser/mysqlparser.go create mode 100644 parsers/mysqlparser/mysqlparser_test.go create mode 100644 parsers/pgparser/go.mod create mode 100644 parsers/pgparser/go.sum create mode 100644 parsers/pgparser/pgparser.go create mode 100644 parsers/pgparser/pgparser_test.go create mode 100644 reporter/console.go create mode 100644 reporter/console_test.go create mode 100644 reporter/json.go create mode 100644 reporter/json_test.go create mode 100644 reporter/reporter.go create mode 100644 sqlguard.go diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..f8c374a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,43 @@ +--- +name: Bug report +about: Report incorrect behavior, a false positive/negative, or a crash +title: "" +labels: bug +assignees: "" +--- + +**Do not file security vulnerabilities here** — see [SECURITY.md](../../SECURITY.md). + +## What happened + +A clear description of the bug. + +## Expected behavior + +What you expected instead. For a false positive/negative, say which **rule** +(e.g. `select-star`) fired or failed to fire. + +## Reproduction + +The SQL or Go snippet, and how it was issued: + +```sql +-- query (redacted is fine) +``` + +```go +// minimal repro +``` + +## Environment + +- sqlguard version / commit: +- Affected module(s) (root, `integrations/`, `parsers/`): +- Parser in use (default fallback / pgparser / mysqlparser): +- Entry surface (runtime middleware / CLI `scan` / CLI `explain` / integration): +- Go version: +- Database + dialect (if relevant): + +## Additional context + +Logs (redaction-safe), config (`.sqlguard.yml`), or anything else useful. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..d9cc535 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Security vulnerability + url: https://github.com/KARTIKrocks/sqlguard/security/advisories/new + about: Report security issues privately — please do not open a public issue. + - name: Question / discussion + url: https://github.com/KARTIKrocks/sqlguard/discussions + about: Ask usage questions or discuss ideas here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..829c5ac --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,31 @@ +--- +name: Feature request +about: Suggest a new rule, integration, or capability +title: "" +labels: enhancement +assignees: "" +--- + +## Problem + +What are you trying to catch or do that sqlguard can't today? + +## Proposed solution + +What you'd like to see. If you're proposing a **new detection rule**, include: + +- the SQL anti-pattern it should flag, +- example queries that should and should **not** trigger it, +- a suggested severity (info / warning / critical), +- any tunable (and its default). + +If you're proposing a **new integration**, name the ORM/driver and its +hook/seam. + +## Alternatives considered + +Other approaches, workarounds, or existing rules/config that almost fit. + +## Additional context + +Anything else — links, prior art, willingness to send a PR. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..12ab443 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,27 @@ +## Summary + +What does this PR change, and why? + +Closes # + +## Type of change + +- [ ] Bug fix +- [ ] New detection rule +- [ ] New integration / parser +- [ ] Feature / enhancement +- [ ] Docs only +- [ ] Refactor / chore + +## Checklist + +- [ ] `make ci` passes (fmt-check, vet, lint, test-race) across all modules +- [ ] Added/updated tests (and, where practical, a failure-mode check) +- [ ] Updated docs as needed (`README.md`, `CLAUDE.md`, `.sqlguard.example.yml`) +- [ ] Added an entry under `## [Unreleased]` in `CHANGELOG.md` +- [ ] No new third-party deps in `analyzer` / `middleware` / `reporter` +- [ ] Findings stay redaction-safe (no raw literals leak into a `Result`) + +## Notes for reviewers + +Anything reviewers should focus on — tricky areas, trade-offs, follow-ups. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..e47fffe --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,56 @@ +version: 2 + +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly + groups: + go-dependencies: + patterns: + - "*" + + - package-ecosystem: gomod + directory: /integrations/gormguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/sqlxguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/pgxguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/bunguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/xormguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/entguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /parsers/pgparser + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /parsers/mysqlparser + schedule: + interval: weekly + + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d6ba9b2 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,97 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ["1.26"] + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: ${{ matrix.go-version }} + + - name: Run tests + run: go test ./... -count=1 -race + + - name: Test integrations (gormguard) + run: cd integrations/gormguard && go test ./... -count=1 -race + + - name: Test integrations (sqlxguard) + run: cd integrations/sqlxguard && go test ./... -count=1 -race + + - name: Test integrations (pgxguard) + run: cd integrations/pgxguard && go test ./... -count=1 -race + + - name: Test integrations (bunguard) + run: cd integrations/bunguard && go test ./... -count=1 -race + + - name: Test integrations (xormguard) + run: cd integrations/xormguard && go test ./... -count=1 -race + + - name: Test integrations (entguard) + run: cd integrations/entguard && go test ./... -count=1 -race + + - name: Test parsers (pgparser) + run: cd parsers/pgparser && go test ./... -count=1 -race + + - name: Test parsers (mysqlparser) + run: cd parsers/mysqlparser && go test ./... -count=1 -race + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + + - uses: golangci/golangci-lint-action@v9 + with: + version: v2.11 + args: --timeout=5m + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + + - name: Build CLI + run: go build -o bin/sqlguard ./cmd/sqlguard + + coverage: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + + # `make coverage` runs every module and merges into a single coverage.out + # (root go test does not reach the satellite modules). + - name: Generate merged coverage + run: make coverage + + - name: Upload to Codecov + uses: codecov/codecov-action@v5 + with: + files: ./coverage.out + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..90b96ee --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,60 @@ +name: CodeQL + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + # Weekly re-scan so newly published CodeQL queries flag old code too. + - cron: "0 6 * * 1" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name != 'schedule' }} + +permissions: + security-events: write + contents: read + +jobs: + analyze: + name: Analyze (Go) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version: "1.26" + + - name: Initialize CodeQL + uses: github/codeql-action/init@v4 + with: + languages: go + # Build the modules ourselves (below) so the tracer sees all nine. + build-mode: manual + queries: security-extended + + # Each integration/parser carries its own go.mod (heavy deps kept opt-in), + # so `go build ./...` from root does NOT reach them. Build every module + # under the CodeQL tracer so all nine are analyzed — same MODULES loop the + # Makefile uses; a satellite must not silently skip scanning. + - name: Build all modules + run: | + set -e + for mod in . \ + ./integrations/gormguard ./integrations/sqlxguard \ + ./integrations/pgxguard ./integrations/bunguard \ + ./integrations/xormguard ./integrations/entguard \ + ./parsers/pgparser ./parsers/mysqlparser; do + echo "==> Building $mod" + (cd "$mod" && go build ./...) + done + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v4 + with: + category: "/language:go" diff --git a/.gitignore b/.gitignore index aaadf73..1151f4d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ *.dll *.so *.dylib +bin # Test binary, built with `go test -c` *.test @@ -17,6 +18,9 @@ coverage.* *.coverprofile profile.cov +# FE +sqlguard-website + # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..4d5beff --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,33 @@ +version: "2" + +linters: + enable: + - errcheck + - govet + - staticcheck + - unused + - ineffassign + - misspell + - gocritic + - gocyclo + - revive + - prealloc + settings: + gocyclo: + min-complexity: 15 + revive: + rules: + - name: exported + exclusions: + rules: + - linters: + - errcheck + path: _test\.go + - linters: + - errcheck + path: examples/ + +formatters: + enable: + - gofmt + - goimports diff --git a/.sqlguard.example.yml b/.sqlguard.example.yml new file mode 100644 index 0000000..556d714 --- /dev/null +++ b/.sqlguard.example.yml @@ -0,0 +1,73 @@ +# sqlguard configuration. Copy to `.sqlguard.yml` at your project root +# (sqlguard discovers it by walking up from the scanned/working directory +# until it hits the file or the git root). +# +# Every key is optional; omitting the file runs all rules at their defaults. + +version: 1 + +# strict: true turns "soft" problems (unknown keys, unknown rule names, +# invalid severities) into hard errors instead of warnings. Leave false so a +# config written for a newer sqlguard still loads on an older binary. +strict: false + +rules: + # Turn rules off entirely. + disable: + - orderby-without-limit + + # Whitelist mode: when non-empty, ONLY these rules run (disable is ignored). + # only: + # - delete-without-where + # - update-without-where + + # Override the reported severity per rule: info | warning | critical | off + # ("off" is equivalent to disabling the rule). + severity: + select-star: info + select-without-limit: "off" + + # Per-rule tunables. Keys are rule-specific. + settings: + leading-wildcard: + # Don't flag LIKE/ILIKE '%x%' style patterns whose searchable term is + # shorter than this many characters. + min-length: 3 + in-list-too-large: + # Flag IN (...) value lists with more than this many elements + # (default 100). Subquery INs are never counted. + max-length: 100 + large-offset: + # Flag a literal OFFSET above this (default 1000) — deep pagination. + # Parameterized offsets (OFFSET $1 / ?) can't be evaluated statically. + threshold: 1000 + +# Redact literal values (strings/numbers) out of Result.Query before it +# reaches any reporter/log. ON by default — leave it on so customer data in +# query literals never lands in your logs. Result.Fingerprint (a PII-free, +# value-free query identity) is emitted regardless. Set to false ONLY for +# local debugging where the query text is trusted. +redact: true + +# Runtime slow-query threshold (middleware). Go duration string. +slow-query: + threshold: 200ms + +# Runtime de-duplication of repeated static findings (middleware). The same +# finding (rule + query fingerprint) is reported at most once per window, so a +# recurring query doesn't flood your logs. Default 1m. Set "0" to disable +# (report every occurrence). Slow-query and N+1 have their own emission policy. +dedup: + window: 1m + +# Static scanner only: skip files whose path matches any of these regexes. +scan: + exclude-paths: + - "(^|/)legacy/" + - "_gen\\.go$" + +# Inline suppressions (no config needed): +# In SQL: SELECT * FROM t -- sqlguard:ignore +# DELETE FROM t /* sqlguard:ignore:delete-without-where */ +# In Go: // sqlguard:ignore (on or above the db call) +# db.Query(q) // sqlguard:ignore:select-star diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..e602e58 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,61 @@ +# Changelog + +All notable changes to this project are documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +Each Go module in this repo (root, `integrations/*`, `parsers/*`) is tagged with +the same version in lockstep. + +## [Unreleased] + +## [0.1.0] - 2026-06-08 + +Initial public release. + +### Added + +- **Runtime middleware** that intercepts at the `database/sql` **driver** layer + (`Register` / `OpenDB`), so any query — including those issued by ORMs and + query builders — is analyzed and you get back a real `*sql.DB`. Zero + third-party dependencies in the core. +- **Analyzer with 19 detection rules** across static, runtime, and EXPLAIN + surfaces: `select-star`, `leading-wildcard`, `non-sargable-predicate`, + `add-not-null-without-default`, `implicit-join`, `cartesian-join`, + `in-list-too-large`, `large-offset`, `select-distinct`, `delete-without-where`, + `update-without-where`, `insert-without-columns`, `select-without-limit`, + `orderby-without-limit`, `n-plus-one`, `slow-query`, `seq-scan`, + `full-table-scan`, `high-cost`. +- **Redaction by default**: every `Result.Query` is redacted (literals → `?`) + before it leaves the process, and every `Result.Fingerprint` is a PII-free, + low-cardinality query identity safe as a metric label. Opt out with + `WithRawQuery()` / `redact: false`. +- **N+1 detection** (windowed) and **slow-query detection** with configurable + thresholds. +- **Finding de-duplication** — each finding (rule + fingerprint) is reported at + most once per window (default 1m) to keep hot queries from flooding logs + (`WithFindingDedup`). +- **Per-query analysis cache** — an LRU keyed on the exact query string so + recurring queries are parsed and checked once (`WithAnalysisCacheSize`). +- **Pluggable parser**: a zero-dependency, never-erroring `FallbackParser` by + default, with opt-in real grammars in separate modules — `parsers/pgparser` + (PostgreSQL) and `parsers/mysqlparser` (MySQL) — via `WithParser`. +- **File configuration** (`.sqlguard.yml`, discovered up to the git root): + enable/disable rules, `only` whitelist, per-rule severity overrides, per-rule + settings, `redact`, `slow-query`, `dedup`, and scanner `exclude-paths`. + Lenient by default; `strict: true` makes unknown keys/rules fatal. +- **Inline suppressions** — in-SQL `-- sqlguard:ignore[:rules]` (honored at + runtime and statically) and Go-source `// sqlguard:ignore[:rules]` (honored by + the scanner). +- **CLI** (`cmd/sqlguard`): `scan` for static analysis of Go source (with + literal/constant resolution via `go/types`) and `explain` for live EXPLAIN + plan analysis. `explain` never executes the statement — it validates input and + runs inside an always-rolled-back read-only transaction. +- **ORM / driver integrations**, each a separate opt-in module built on the + shared `middleware.Guard` core (redaction, fingerprints, parser seam, + slow-query, N+1, and a `ResetN1()` per-request hook): `integrations/gormguard`, + `integrations/sqlxguard`, `integrations/pgxguard` (native pgx / pgxpool), + `integrations/bunguard`, `integrations/xormguard`, `integrations/entguard`. + +[Unreleased]: https://github.com/KARTIKrocks/sqlguard/compare/v0.1.0...HEAD +[0.1.0]: https://github.com/KARTIKrocks/sqlguard/releases/tag/v0.1.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..44f70e9 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,96 @@ +# Contributing to sqlguard + +Thanks for your interest in improving sqlguard! This guide covers the +project-specific things that aren't obvious from a quick look at the repo. + +## Project layout + +sqlguard is a **multi-module repo** — nine Go modules on Go 1.26, kept in +lockstep: + +- **root** (`github.com/KARTIKrocks/sqlguard`) — core analyzer, middleware, + reporter, config, and CLI. Deliberately near-zero-dependency. +- **`parsers/pgparser`, `parsers/mysqlparser`** — opt-in real SQL grammars, + isolated so their heavy parser deps never enter a consumer's build. +- **`integrations/{gormguard,sqlxguard,pgxguard,bunguard,xormguard,entguard}`** — + ORM/driver adapters, each a separate module so its deps stay opt-in. + +The satellite modules use a local `replace` directive pointing at the root, so +you can develop across modules without publishing. + +> **Important:** `go test ./...` (and `go build` / `go vet` / `go mod tidy`) +> from the root does **not** reach the satellite modules. Always use the +> Makefile targets, which loop over every module. + +## Development workflow + +```bash +make setup # install pinned golangci-lint + goimports (one-time) +make all # tidy, fmt, vet, lint, build, test across all nine modules +make ci # what CI runs: fmt-check, vet, lint, test-race +make test-race # race detector (required for anything touching middleware) +make help # list every target +``` + +Before opening a PR, run `make ci` and make sure it's green. + +- Run a single test: `go test ./middleware/ -run TestName -count=1`. +- Use `-race` for anything touching `middleware` (the driver chain and + `QueryTracker` are concurrent). +- After any dependency change, run `make tidy` (tidies all nine modules — tidying + only the root leaves the others stale). + +## Conventions + +- **Pre-1.0, no backward-compatibility burden.** Prefer the clean design over + preserving an existing public API; don't add deprecation shims or compat + layers. +- Modern Go idioms are expected (range-over-int, `any`, compile-time interface + asserts `var _ I = (*T)(nil)`). +- Keep the **core dependency-light**: `analyzer`, `middleware`, and `reporter` + must stay free of third-party deps and of YAML. `config` is the only + YAML-aware package. +- **Redaction is the default.** Never let raw literal values reach a `Result` + that leaves the process. There is one canonical normalizer (`analyzer.Redact` + / `Fingerprint`) — don't add a second. +- See [`CLAUDE.md`](CLAUDE.md) for the deeper architecture notes and invariants, + and [`PRODUCTION_READINESS.md`](PRODUCTION_READINESS.md) for the roadmap. + +## Adding a detection rule + +Rules self-register. Write the rule, then add one `analyzer.Register(RuleSpec{ +... })` call in `analyzer/rules.go` (a stable name, default severity, and a +settings-aware factory). Being addressable by name is what makes enable/disable, +severity overrides, per-rule settings, and suppressions all work uniformly — +do **not** hand-maintain a rule list. Rules read the normalized `Statement`, +never raw SQL. + +If your rule has a tunable, read it from `Settings` in the factory and document +it in [`.sqlguard.example.yml`](.sqlguard.example.yml). + +## Adding an integration + +Every integration must build on the exported `middleware.Guard` core — +`integrations/pgxguard` is the reference. Hand-rolling analysis silently loses +redaction, fingerprints, the parser seam, config, N+1, and dedup. Each +integration should expose `ResetN1()` for per-request scoping. + +## Pull requests + +1. Fork and branch from `main`. +2. Keep changes focused; update docs (`README.md`, `CLAUDE.md`, + `.sqlguard.example.yml`) when behavior or config changes. +3. Add tests for new behavior; where practical, also prove the failure mode + (e.g. a bug-reintroduction check). +4. Add a line under `## [Unreleased]` in [`CHANGELOG.md`](CHANGELOG.md). +5. Run `make ci` and ensure it passes. + +## Reporting security issues + +Please do **not** open a public issue for security vulnerabilities. See +[`SECURITY.md`](SECURITY.md) for the private reporting process. + +## License + +By contributing, you agree that your contributions are licensed under the +project's [MIT License](LICENSE). diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d205223 --- /dev/null +++ b/Makefile @@ -0,0 +1,194 @@ +GOLANGCI_LINT_VERSION := v2.12.2 +GOIMPORTS_VERSION := v0.45.0 + +MODULE_PATH := github.com/KARTIKrocks/sqlguard + +# Sub-modules carry their own go.mod (heavy/opt-in deps kept out of the core +# import graph). `go test ./...` from root does NOT reach them, so every +# all-modules target loops over MODULES. +SUB_MODULES = \ + ./integrations/gormguard \ + ./integrations/sqlxguard \ + ./integrations/pgxguard \ + ./integrations/bunguard \ + ./integrations/xormguard \ + ./integrations/entguard \ + ./parsers/pgparser \ + ./parsers/mysqlparser +MODULES = . $(SUB_MODULES) + +.PHONY: all help setup deps ci test test-v test-race coverage lint lint-fix fix fmt fmt-check vet tidy build cli install bench clean release-prep + +all: tidy fmt vet lint build test + +## Show available targets +help: + @echo "Available targets:" + @echo " all - Tidy, format, vet, lint, build, test (all modules)" + @echo " setup - Install development tools" + @echo " deps - Download module dependencies (all modules)" + @echo " ci - CI pipeline (fmt-check, vet, lint, test-race)" + @echo " test - Run tests across all modules" + @echo " test-v - Run tests with verbose output (all modules)" + @echo " test-race - Run tests with race detector (all modules)" + @echo " coverage - Run tests with merged coverage report (all modules)" + @echo " vet - Run go vet (all modules)" + @echo " lint - Run golangci-lint (all modules)" + @echo " lint-fix - Run golangci-lint with --fix (root module)" + @echo " fix - fmt + lint-fix" + @echo " fmt - Format code (gofmt -s + goimports)" + @echo " fmt-check - Verify formatting without modifying files" + @echo " tidy - Run go mod tidy (all modules)" + @echo " build - Build all packages (all modules)" + @echo " cli - Build the sqlguard CLI to bin/sqlguard" + @echo " install - Install the CLI to \$$GOPATH/bin" + @echo " bench - Run benchmarks (all modules)" + @echo " clean - Remove build/coverage artifacts" + @echo " release-prep - Pin sub-modules to a release version (VERSION=vX.Y.Z)" + +## Install development tools (skips if already present) +setup: + @command -v golangci-lint >/dev/null 2>&1 || { \ + echo "Installing golangci-lint $(GOLANGCI_LINT_VERSION)..."; \ + go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION); \ + } + @command -v goimports >/dev/null 2>&1 || { \ + echo "Installing goimports $(GOIMPORTS_VERSION)..."; \ + go install golang.org/x/tools/cmd/goimports@$(GOIMPORTS_VERSION); \ + } + +## Download module dependencies across all modules +deps: + @for mod in $(MODULES); do \ + echo "==> Downloading deps $$mod"; \ + (cd $$mod && go mod download) || exit 1; \ + done + +## CI: run formatting check, vet, lint and tests with race detector +ci: fmt-check vet lint test-race + +## Build all packages across all modules (compile check) +build: + @for mod in $(MODULES); do \ + echo "==> Building $$mod"; \ + (cd $$mod && go build ./...) || exit 1; \ + done + +## Build the CLI binary +cli: + @echo "==> Building bin/sqlguard" + @go build -o bin/sqlguard ./cmd/sqlguard + +## Install the CLI to $GOPATH/bin +install: + go install ./cmd/sqlguard + +## Run tests across all modules +test: + @for mod in $(MODULES); do \ + echo "==> Testing $$mod"; \ + (cd $$mod && go test -count=1 ./...) || exit 1; \ + done + +## Run tests with verbose output across all modules +test-v: + @for mod in $(MODULES); do \ + echo "==> Testing (verbose) $$mod"; \ + (cd $$mod && go test -v -count=1 ./...) || exit 1; \ + done + +## Run tests with race detector across all modules +test-race: + @for mod in $(MODULES); do \ + echo "==> Testing (race) $$mod"; \ + (cd $$mod && go test -race -count=1 ./...) || exit 1; \ + done + +## Run tests with coverage and generate a merged report across all modules +coverage: + @echo "mode: atomic" > coverage.out + @for mod in $(MODULES); do \ + echo "==> Coverage $$mod"; \ + (cd $$mod && go test -race -covermode=atomic -coverprofile=cover.tmp ./...) || exit 1; \ + if [ -f $$mod/cover.tmp ]; then tail -n +2 $$mod/cover.tmp >> coverage.out && rm $$mod/cover.tmp; fi; \ + done + @go tool cover -func=coverage.out | tail -1 + @echo "Full report: go tool cover -html=coverage.out" + +## Run linter across all modules +lint: setup + @for mod in $(MODULES); do \ + echo "==> Linting $$mod"; \ + (cd $$mod && golangci-lint run --timeout=5m ./...) || exit 1; \ + done + +## Run golangci-lint with auto-fix (root module) +lint-fix: setup + golangci-lint run --fix ./... + +## Fix code formatting and linting issues +fix: fmt lint-fix + +## Format code (recurses the whole tree, all modules) +fmt: setup + @gofmt -s -w . + @goimports -w . + +## Check formatting without modifying files (used in CI) +fmt-check: setup + @test -z "$$(gofmt -s -l . | tee /dev/stderr)" || { echo "Unformatted files found. Run 'make fmt'."; exit 1; } + @test -z "$$(goimports -l . | tee /dev/stderr)" || { echo "Unordered imports found. Run 'make fmt'."; exit 1; } + +## Run go vet across all modules +vet: + @for mod in $(MODULES); do \ + echo "==> Vetting $$mod"; \ + (cd $$mod && go vet ./...) || exit 1; \ + done + +## Run go mod tidy across all modules +tidy: + @for mod in $(MODULES); do \ + echo "==> Tidying $$mod"; \ + (cd $$mod && go mod tidy) || exit 1; \ + done + +## Run benchmarks across all modules +bench: + @for mod in $(MODULES); do \ + echo "==> Benchmarking $$mod"; \ + (cd $$mod && go test -bench=. -benchmem -run='^$$' ./...) || exit 1; \ + done + +## Remove build and coverage artifacts +clean: + @rm -f coverage*.out cover.tmp coverage.txt coverage.html + @find . -name cover.tmp -delete 2>/dev/null || true + @rm -rf dist/ build/ bin/ + +## Prepare sub-modules for release: drop the local replace and pin the parent +## version. Usage: make release-prep VERSION=v0.1.0 +## Run this AFTER the root module tag for VERSION exists and is pushed, then +## commit and tag the sub-modules. Restore replaces afterwards for local dev +## (git checkout -- '**/go.mod') or develop against the published version. +release-prep: +ifndef VERSION + $(error VERSION is required. Usage: make release-prep VERSION=v0.1.0) +endif + @for mod in $(SUB_MODULES); do \ + echo "==> release-prep $$mod"; \ + (cd $$mod && go mod edit -dropreplace $(MODULE_PATH) -require $(MODULE_PATH)@$(VERSION)) || exit 1; \ + done + @echo "" + @echo "Done! Sub-modules now require $(MODULE_PATH)@$(VERSION) (replace dropped)." + @echo "Next steps (root tag $(VERSION) must already be pushed):" + @echo " git add -A && git commit -m 'Prepare release $(VERSION)'" + @echo " git tag integrations/gormguard/$(VERSION)" + @echo " git tag integrations/sqlxguard/$(VERSION)" + @echo " git tag integrations/pgxguard/$(VERSION)" + @echo " git tag integrations/bunguard/$(VERSION)" + @echo " git tag integrations/xormguard/$(VERSION)" + @echo " git tag integrations/entguard/$(VERSION)" + @echo " git tag parsers/pgparser/$(VERSION)" + @echo " git tag parsers/mysqlparser/$(VERSION)" + @echo " git push origin main --tags" diff --git a/README.md b/README.md new file mode 100644 index 0000000..7886379 --- /dev/null +++ b/README.md @@ -0,0 +1,514 @@ +# sqlguard + +[![Go Reference](https://pkg.go.dev/badge/github.com/KARTIKrocks/sqlguard.svg)](https://pkg.go.dev/github.com/KARTIKrocks/sqlguard) +[![Go Report Card](https://goreportcard.com/badge/github.com/KARTIKrocks/sqlguard)](https://goreportcard.com/report/github.com/KARTIKrocks/sqlguard) +[![Go Version](https://img.shields.io/github/go-mod/go-version/KARTIKrocks/sqlguard)](go.mod) +[![CI](https://github.com/KARTIKrocks/sqlguard/actions/workflows/ci.yml/badge.svg)](https://github.com/KARTIKrocks/sqlguard/actions/workflows/ci.yml) +[![CodeQL](https://github.com/KARTIKrocks/sqlguard/actions/workflows/codeql.yml/badge.svg)](https://github.com/KARTIKrocks/sqlguard/actions/workflows/codeql.yml) +[![GitHub tag](https://img.shields.io/github/v/tag/KARTIKrocks/sqlguard)](https://github.com/KARTIKrocks/sqlguard/releases) +[![License](https://img.shields.io/github/license/KARTIKrocks/sqlguard)](LICENSE) +[![codecov](https://codecov.io/gh/KARTIKrocks/sqlguard/branch/main/graph/badge.svg)](https://codecov.io/gh/KARTIKrocks/sqlguard) + +Production-safe SQL query analyzer for Go applications. + +Detects slow queries, dangerous SQL patterns, and performance issues — both at runtime and statically. Think of it as `golangci-lint` for SQL queries. + +## Install + +```bash +go get github.com/KARTIKrocks/sqlguard +``` + +CLI tool: + +```bash +go install github.com/KARTIKrocks/sqlguard/cmd/sqlguard@latest +``` + +## Detection Rules + +| Rule | Severity | Description | +| ------------------------------ | -------- | ------------------------------------------------------------------------------------------------ | +| `select-star` | WARNING | `SELECT *` — selects all columns unnecessarily | +| `leading-wildcard` | WARNING | `LIKE '%...'` (and `ILIKE`) — index cannot be used | +| `non-sargable-predicate` | WARNING | `WHERE LOWER(col) = ...` — function on column defeats its index | +| `add-not-null-without-default` | WARNING | `ALTER TABLE ... ADD COLUMN ... NOT NULL` without `DEFAULT` — fails / rewrites a populated table | +| `implicit-join` | WARNING | `FROM a, b` — comma join; a forgotten condition becomes a cartesian product | +| `cartesian-join` | WARNING | Multiple tables with no join condition or `WHERE` — a cartesian product (incl. `CROSS JOIN`) | +| `in-list-too-large` | WARNING | `IN (...)` value list with more than `max-length` (default 100) elements | +| `large-offset` | WARNING | `OFFSET` above `threshold` (default 1000) — deep pagination scans/discards skipped rows | +| `select-distinct` | INFO | `SELECT DISTINCT` — often masks duplicate rows from an unintended join | +| `delete-without-where` | CRITICAL | `DELETE` without `WHERE` — deletes all rows | +| `update-without-where` | CRITICAL | `UPDATE` without `WHERE` — updates all rows | +| `insert-without-columns` | WARNING | `INSERT` without an explicit column list (`VALUES` or `... SELECT`) — breaks on schema change | +| `select-without-limit` | WARNING | `SELECT` without `LIMIT` or `WHERE` — may return excessive rows | +| `orderby-without-limit` | INFO | `ORDER BY` without `LIMIT` — sorts entire result set | +| `n-plus-one` | WARNING | Same query pattern repeated N times (runtime only) | +| `slow-query` | WARNING | Query exceeds latency threshold (runtime only) | +| `seq-scan` | WARNING | Sequential scan detected via EXPLAIN (postgres) | +| `full-table-scan` | WARNING | Full table scan detected via EXPLAIN (mysql) | +| `high-cost` | WARNING | High cost operation in query plan | + +## Configuration + +Drop a `.sqlguard.yml` at your project root. sqlguard discovers it by walking +up from the scanned (or working) directory until it finds the file or the git +root. The CLI takes `--config ` and `--no-config`; the file is optional +— without it every rule runs at its default. A fully-commented template lives +at [`.sqlguard.example.yml`](.sqlguard.example.yml). + +```yaml +version: 1 +rules: + disable: [orderby-without-limit] + severity: + select-star: info # info | warning | critical | off + select-without-limit: "off" # "off" disables the rule + settings: + leading-wildcard: + min-length: 3 # ignore short LIKE '%x%' patterns + in-list-too-large: + max-length: 100 # flag IN (...) lists longer than this + large-offset: + threshold: 1000 # flag literal OFFSET above this +redact: true # redact literals out of Result.Query (default) +slow-query: + threshold: 200ms # runtime middleware threshold +dedup: + window: 1m # report each repeated finding at most once per window ("0" disables) +scan: + exclude-paths: ["(^|/)legacy/"] # static scanner only, regex +``` + +Unknown keys and rule names are warnings, not errors, so a config written for +a newer sqlguard still loads on an older binary; set `strict: true` to make +them fatal. `only: [rule, ...]` switches to whitelist mode. + +**Inline suppressions** — no config required: + +```sql +SELECT * FROM users -- sqlguard:ignore +DELETE FROM users /* sqlguard:ignore:delete-without-where */ +``` + +```go +// sqlguard:ignore +db.Exec("DELETE FROM users") +db.Query("SELECT * FROM users") // sqlguard:ignore:select-star +``` + +In-SQL directives work at runtime _and_ in the static scanner; the Go-source +form is honored by the scanner when it sits on or directly above the call. + +Apply the same config to the runtime middleware: + +```go +opts, _ := config.Middleware("", ".") // discover from cwd +sqlguard.Register("sqlguard-pg", "pgx", opts...) +``` + +## Security & redaction + +sqlguard's findings flow into logs, so by **default it never emits raw +literal values**. Before any `Result` leaves the process its `Query` is +redacted — single-quoted strings and numeric literals become `?`, while +keywords, identifiers (including `"quoted"` / `` `backtick` `` names) and +structure are preserved: + +``` +[SQLGUARD WARNING] select-star + Query: SELECT * FROM users WHERE email = ? +``` + +Every `Result` also carries a `Fingerprint`: the redacted query with +whitespace collapsed and `IN (?, ?, ?)` folded to `(?)`. It is a stable, +PII-free, low-cardinality identity — safe as a metrics label or log key, and +the same value the N+1 detector groups on. The JSON reporter emits it as +`fingerprint`. + +Opt out only where the query text is trusted (local debugging): + +```go +a := analyzer.Default().WithRawQuery() // standalone analyzer +sqlguard.Register("pg", "pgx", middleware.WithAnalyzer(a)) +``` + +or `redact: false` in `.sqlguard.yml`. `Fingerprint` is populated either way. + +## Usage + +### Runtime Middleware + +sqlguard wraps at the `database/sql` **driver** layer, so you get back a real +`*sql.DB` and every query is analyzed automatically — including queries issued +by ORMs and query builders (sqlc, ent, sqlx, gorm, pgx-stdlib). There is no +wrapper type to thread through your code and no method list to keep in sync. + +```go +import ( + "database/sql" + "github.com/KARTIKrocks/sqlguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "time" +) + +func main() { + // Register an analyzed driver by wrapping an existing one... + sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(5, 2*time.Second), + ) + db, _ := sql.Open("sqlguard-pg", "...") // db is a plain *sql.DB + + // ...or wrap a driver.Connector directly (e.g. pgx stdlib): + // db := sqlguard.OpenDB(connector, middleware.WithN1Detection(5, time.Second)) + + // Use as normal — warnings are logged automatically + db.Query("SELECT * FROM users") + // Output: + // [SQLGUARD WARNING] select-star + // Query: SELECT * FROM users + // Issue: SELECT * detected. Selecting all columns can hurt performance. + // Fix: Select only the columns you need. +} +``` + +### N+1 Query Detection + +The middleware detects when the same query pattern executes repeatedly — a classic N+1 problem: + +```go +sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithN1Detection(5, 2*time.Second), // flag after 5 similar queries in 2s +) +db, _ := sql.Open("sqlguard-pg", "...") +``` + +N+1 patterns are detected within the configured time window. On the raw +`database/sql` driver path you get back a plain `*sql.DB`, so detection is +process-wide (windowed) — there is no handle to scope it per request. The +integration adapters (`gormguard`, `pgxguard`, `sqlxguard`, `bunguard`, +`xormguard`, `entguard`) hold the guard and expose `ResetN1()` to scope +detection to a single unit of work; call it at a request boundary. + +### Noise control (finding de-duplication) + +A recurring query would otherwise re-emit the same static warning on every +execution. By default the runtime middleware reports each finding (rule + query +fingerprint) **at most once per minute**, so a hot query doesn't flood your +logs. Tune or disable it: + +```go +sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithFindingDedup(5*time.Minute), // quieter +) +sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithFindingDedup(0), // disable: report every occurrence +) +``` + +Or set `dedup.window` in `.sqlguard.yml`. Slow-query and N+1 findings have +their own emission policy and are unaffected. + +The middleware also memoizes analysis per distinct query string (an LRU keyed on +the exact query — correct even for the literal-sensitive rules), so a recurring +query is parsed and rule-checked once rather than on every execution. A repeated +query then costs a cache lookup instead of a full parse (≈1000× cheaper, zero +allocations in the repeat case). Default 1024 entries; tune with +`middleware.WithAnalysisCacheSize(n)` or disable with `n == 0`. + +### CLI Static Scanner + +Scan your Go source code for SQL issues without running the application: + +```bash +# Scan current directory +sqlguard scan . + +# Scan specific package +sqlguard scan ./internal/repository + +# JSON output (for CI pipelines) +sqlguard scan --format json ./... +``` + +Exit code is **1** when issues are found, **0** when clean — works with CI/CD pipelines. + +### EXPLAIN Plan Analyzer + +Connect to a live database and analyze query plans: + +```bash +# PostgreSQL +sqlguard explain --db "postgres://user:pass@localhost/mydb?sslmode=disable" \ + "SELECT * FROM orders WHERE user_id = 42" + +# MySQL +sqlguard explain --dialect mysql --db "user:pass@tcp(localhost:3306)/mydb" \ + "SELECT * FROM orders WHERE user_id = 42" + +# JSON output +sqlguard explain --db "..." --format json "SELECT * FROM orders" +``` + +Detects sequential scans, missing indexes, filesort, and high-cost operations. + +For safety the EXPLAIN runs inside a **read-only transaction that is always +rolled back** (Postgres and MySQL), and `ANALYZE` is never used — the +statement is planned, never executed. Input is validated with a +comment- and string-literal-aware multi-statement check (a `;` hidden in a +comment or string can't smuggle a second statement). Only `SELECT`/`WITH` is +allowed by default; pass `--allow-dml` to EXPLAIN an `INSERT/UPDATE/DELETE` +(still rolled back). DDL/`SET`/transaction-control is always refused. + +### GORM Integration + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/gormguard +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard/integrations/gormguard" + "github.com/KARTIKrocks/sqlguard/middleware" +) + +gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + +// Register as GORM plugin — hooks into all queries automatically +gormguard.Register(gormDB) + +// Or customize via the standard middleware options +gormguard.Register(gormDB, + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +) +``` + +### sqlx Integration + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/sqlxguard +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard/integrations/sqlxguard" + "github.com/KARTIKrocks/sqlguard/middleware" +) + +sqlxDB := sqlx.MustConnect("postgres", dsn) + +db := sqlxguard.WrapSqlx(sqlxDB, + middleware.WithSlowQueryThreshold(500*time.Millisecond), +) + +var users []User +db.Select(&users, "SELECT * FROM users") // warns about SELECT * +``` + +### pgx Integration (native pgx / pgxpool) + +The `database/sql` driver wrapper covers pgx-stdlib (`pgx/v5/stdlib`). For the +**native pgx APIs** (`pgxpool.Pool`, `pgx.Conn` — which bypass `database/sql` +entirely) use `pgxguard`. It hooks pgx's own tracer seam, so every +`Query`/`QueryRow`/`Exec` and every `SendBatch` is analyzed without a wrapper +type or a method list. + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/pgxguard +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard/integrations/pgxguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jackc/pgx/v5/pgxpool" +) + +cfg, _ := pgxpool.ParseConfig(dsn) +pgxguard.ApplyPool(cfg, + middleware.WithSlowQueryThreshold(50*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +) +pool, _ := pgxpool.NewWithConfig(ctx, cfg) +``` + +`Apply` (for `*pgx.ConnConfig`) and `ApplyPool` (for `*pgxpool.Config`) +**compose** with any tracer already installed via pgx's own `multitracer`, +so sqlguard coexists with `otelpgx`, `ddtrace` and friends rather than +silently overwriting them. Configuration is the standard `middleware.Option` +set — same as the driver wrapper, no parallel surface to learn. + +Coverage: `Query` / `QueryRow` / `Exec` (via `pgx.QueryTracer`) and +`SendBatch` (via `pgx.BatchTracer`). Prepared-statement execution is already +covered by `QueryTracer`, so `PrepareTracer` is deliberately omitted to avoid +double-reporting. `CopyFrom` carries no SQL and is out of scope. + +### bun / xorm Integrations + +bun and xorm build SQL through their own query layers and expose native +before/after hook seams. `bunguard` and `xormguard` plug into those seams and +run every statement through the same shared core — same `middleware.Option` +set, no parallel surface. + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/bunguard +go get github.com/KARTIKrocks/sqlguard/integrations/xormguard +``` + +```go +// bun — register a QueryHook +db.AddQueryHook(bunguard.New( + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +)) + +// xorm — register a Hook +engine.AddHook(xormguard.New( + middleware.WithSlowQueryThreshold(500*time.Millisecond), +)) +``` + +### ent Integration + +ent runs on `database/sql`, so the simplest coverage is to point `entsql` at a +`*sql.DB` from `sqlguard.Register`/`OpenDB`. `entguard` is the dedicated +alternative: it decorates ent's own `dialect.Driver`, so it covers every +`Exec`/`Query` (and transactions it opens) regardless of how the `*sql.DB` was +created. + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/entguard +``` + +```go +drv, _ := entsql.Open(dialect.Postgres, dsn) +guarded := entguard.Wrap(drv, + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +) +client := ent.NewClient(ent.Driver(guarded)) +``` + +Every adapter (`gormguard`, `bunguard`, `xormguard`, `entguard`, `pgxguard`, +`sqlxguard`) exposes a `ResetN1()` you can call at a per-request boundary to +scope N+1 detection to one unit of work. + +### SQL Parsers (accuracy vs. zero dependencies) + +By default the analyzer uses a **zero-dependency fallback parser**: it strips +SQL comments and string-literal contents before pattern matching, so keywords +inside comments/strings and identifiers like `update_at` no longer cause false +positives. It never errors — SQL it can't fully understand still yields a +best-effort result, so analysis never breaks your query path. + +For **exact, structural analysis**, opt into a real grammar. These live in +separate modules so the core stays dependency-free: + +```bash +go get github.com/KARTIKrocks/sqlguard/parsers/pgparser # PostgreSQL (pure Go, no cgo) +go get github.com/KARTIKrocks/sqlguard/parsers/mysqlparser # MySQL (pure Go, no cgo) +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/KARTIKrocks/sqlguard/parsers/pgparser" +) + +sqlguard.Register("sqlguard-pg", "pgx", middleware.WithParser(pgparser.New())) +db, _ := sql.Open("sqlguard-pg", dsn) + +// Or with the standalone analyzer: +a := analyzer.Default().WithParser(pgparser.New()) +``` + +A real parser drives the false-positive-prone facts (statement kind, +WHERE/LIMIT/ORDER BY/FROM presence, `SELECT *`, `SELECT DISTINCT`, `OFFSET`, +explicit INSERT columns) from the AST instead of regex. CTEs, subqueries, and +dialect syntax are handled correctly; anything the grammar rejects (dynamic SQL, +driver placeholders) transparently degrades to the fallback parser. + +A few facts stay lexical heuristics even with a real parser, because they read +literal values the AST discards or are intentionally text-level: IN-list size +(`in-list-too-large`), comma/cartesian joins (`implicit-join` / +`cartesian-join`), and the literal/text checks (`leading-wildcard`, +`non-sargable-predicate`, `add-not-null-without-default`). These keep their +zero-dependency, best-effort behavior regardless of the parser. + +### Custom Rules + +```go +import "github.com/KARTIKrocks/sqlguard/analyzer" + +// Create analyzer with only the rules you want +a := analyzer.New( + analyzer.CheckDeleteWithoutWhere, + analyzer.CheckUpdateWithoutWhere, +) + +// Or use all defaults +a := analyzer.Default() + +// Analyze a query +results := a.Analyze("DELETE FROM users") +for _, r := range results { + fmt.Printf("[%s] %s: %s\n", r.Severity, r.RuleName, r.Message) +} +``` + +## Development + +```bash +make help # List all targets +make all # tidy, fmt, vet, lint, build, test (all modules) +make build # Compile all modules; `make cli` builds bin/sqlguard +make test # Run tests across all modules (test-race adds -race) +make lint # Run golangci-lint across all modules +make fmt # gofmt -s + goimports +make tidy # go mod tidy across all modules +make install # Install the CLI to $GOPATH/bin +``` + +## Coverage + +The middleware wraps the `database/sql` **driver** chain, so _every_ query +is analyzed regardless of how it's issued (`Query`/`Exec`/`Prepare`/`Tx`, +context variants, and any ORM/query builder on top — sqlc, ent, sqlx, gorm, +pgx-stdlib). There is no method allowlist to keep in sync; you get back a +real `*sql.DB`. + +Opt-in adapter modules, each built on the same `middleware.Guard` core, +extend coverage to APIs that bypass or sit above the `database/sql` driver +path: + +- **`pgxguard`** — native pgx / pgxpool (which never goes through + `database/sql`), via pgx's own tracer seam. Composes with existing tracers + (otelpgx, ddtrace) via `multitracer`. Covers `Query`/`QueryRow`/`Exec` and + `SendBatch`. +- **`gormguard`** / **`bunguard`** / **`xormguard`** — hook each ORM's native + before/after callback seam (`gorm.Plugin`, `bun.QueryHook`, xorm + `contexts.Hook`). +- **`entguard`** — decorates ent's `dialect.Driver` (Exec/Query + the + transactions it opens). +- **`sqlxguard`** — sqlx-only helpers that build SQL outside the driver path: + `Select` / `SelectContext`, `Get` / `GetContext`, `Queryx`, `NamedExec` / + `NamedExecContext`. + +All six inherit redaction-by-default, stable fingerprints, the parser seam, +and slow-query/N+1 detection from the shared core, and expose `ResetN1()` for +per-request scoping. + +## Limitations + +- The static scanner resolves inline literals, same/cross-package constants, + constant concatenation, and `fmt.Sprintf` with a constant format string + (via `go/types`); it cannot resolve values only known at runtime. +- The default fallback parser is best-effort; for exact structural analysis use a real parser module (see _SQL Parsers_ above) +- EXPLAIN analyzer requires a live database connection; only Postgres and MySQL dialects are supported + +## License + +[MIT](LICENSE) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..87fd9ce --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,56 @@ +# Security Policy + +## Supported versions + +sqlguard is pre-1.0. Security fixes are applied to the latest released minor +version. Until 1.0, only the most recent `0.x` release is supported. + +| Version | Supported | +| ------------ | --------- | +| latest `0.x` | ✅ | +| older | ❌ | + +## Reporting a vulnerability + +**Please do not report security vulnerabilities through public GitHub issues, +discussions, or pull requests.** + +Instead, use one of the following private channels: + +1. **GitHub private vulnerability reporting** (preferred) — open a report via + the repository's **Security → Report a vulnerability** tab + (`https://github.com/KARTIKrocks/sqlguard/security/advisories/new`). +2. **Email** — `kartik.rajput622001@gmail.com` with a subject line starting + `[sqlguard security]`. + +Please include: + +- the affected module(s) and version/commit, +- a description of the issue and its impact, +- steps to reproduce (a minimal repro or PoC is ideal), +- any suggested remediation. + +You can expect an acknowledgement within **5 business days**. We'll keep you +informed as we investigate and work on a fix, and we'll credit you in the +release notes / advisory unless you prefer to remain anonymous. + +## Scope and threat model + +sqlguard is a _defensive_ tool — it analyzes SQL for risky patterns and is +designed to fail safe. A few invariants are part of its security contract; bugs +that break them are in scope: + +- **Redaction by default.** A `Result` must never carry raw literal values out + of the process: `Result.Query` is redacted and `Result.Fingerprint` must be + PII-free. A path that leaks literals to a reporter/log is a security bug. +- **EXPLAIN never executes the statement.** The `explain` analyzer validates + input (comment/string-aware multi-statement rejection, `SELECT`/`WITH`-only by + default) and runs every plan inside an always-rolled-back, read-only + transaction. A way to make `explain` mutate data or run a second statement is + in scope. +- **The middleware must not alter query semantics or results.** It observes; + it must not change what the underlying driver returns. + +Out of scope: vulnerabilities in third-party dependencies (report those +upstream; we'll bump once fixed), and misuse such as deliberately disabling +redaction with `WithRawQuery()` / `redact: false`. diff --git a/analyzer/analyzer.go b/analyzer/analyzer.go new file mode 100644 index 0000000..5be8118 --- /dev/null +++ b/analyzer/analyzer.go @@ -0,0 +1,165 @@ +package analyzer + +import "maps" + +// Rule checks a normalized Statement and returns a Result if an issue is +// found. It returns the result and true if an issue was detected, or a zero +// Result and false otherwise. +// +// Rules operate on the parsed Statement, not the raw SQL string, so a query +// is parsed once per Analyze call and every rule sees the same dialect- +// agnostic view. +type Rule func(s *Statement) (Result, bool) + +// boundRule is a rule together with its registry name and the default +// severity from its RuleSpec. The name is "" for rules supplied directly via +// New (anonymous rules); profile overrides and suppressions only apply to +// named, registry-built rules. hasSeverity distinguishes a registry-built rule +// (whose severity is the spec's DefaultSeverity, the single source of truth) +// from an anonymous rule (which carries its own severity in the Result it +// returns); since SeverityInfo is the zero value, a flag is needed rather than +// a sentinel. +type boundRule struct { + name string + check Rule + severity Severity + hasSeverity bool +} + +// Analyzer holds a set of rules and a Parser, and runs the rules against +// SQL queries. Configuration (disabled rules, severity overrides, per-rule +// settings) is resolved once at construction into the bound rule set and the +// severity map; the per-query Analyze path does no config work. +type Analyzer struct { + rules []boundRule + parser Parser + severity map[string]Severity + // rawQuery, when true, leaves Result.Query unredacted. Default is false + // (redact): the safe default for a tool whose findings flow into logs. + rawQuery bool +} + +// New creates an Analyzer with the given anonymous rules, using the +// zero-dependency FallbackParser. Use WithParser to supply a real dialect +// parser. Rules added this way are not subject to profile overrides (they +// have no registry name); use Default/DefaultWithProfile for configurable +// built-in rules. +func New(rules ...Rule) *Analyzer { + bound := make([]boundRule, len(rules)) + for i, r := range rules { + bound[i] = boundRule{check: r} + } + return &Analyzer{rules: bound, parser: NewFallbackParser()} +} + +// WithParser returns a copy of the Analyzer that uses the given Parser. +// Passing nil resets it to the FallbackParser. +func (a *Analyzer) WithParser(p Parser) *Analyzer { + if p == nil { + p = NewFallbackParser() + } + cp := *a + cp.parser = p + return &cp +} + +// WithRawQuery returns a copy of the Analyzer that leaves Result.Query +// unredacted (the raw SQL, literals and all). Redaction is on by default so +// literal values never reach a log sink; opt out only for local debugging +// where the query text is trusted. Fingerprint is always populated either +// way. +func (a *Analyzer) WithRawQuery() *Analyzer { + cp := *a + cp.rawQuery = true + return &cp +} + +// PrepareQuery returns the query field and fingerprint for a Result built +// outside the rule path (e.g. the runtime slow-query and N+1 findings), +// applying the same redaction policy as Analyze so every emitted Result is +// consistent. display is redacted unless the Analyzer was built +// WithRawQuery; fingerprint is always the PII-free identity. +func (a *Analyzer) PrepareQuery(raw string) (display, fingerprint string) { + fingerprint = Fingerprint(raw) + if a.rawQuery { + return raw, fingerprint + } + return Redact(raw), fingerprint +} + +// Default creates an Analyzer with all registered built-in rules and the +// fallback parser, using each rule's default settings and severity. +func Default() *Analyzer { + return DefaultWithProfile(Profile{}) +} + +// DefaultWithProfile builds an Analyzer from the rule registry with the given +// Profile applied: disabled/whitelisted rules are filtered, per-rule settings +// are passed to each rule's factory, and severity overrides are precomputed. +// The config package uses this to turn a .sqlguard.yml into an Analyzer +// without analyzer ever importing config or YAML. +func DefaultWithProfile(p Profile) *Analyzer { + var bound []boundRule + for _, spec := range specs() { + if p.skip(spec.Name) { + continue + } + bound = append(bound, boundRule{ + name: spec.Name, + check: spec.Factory(p.Settings[spec.Name]), + severity: spec.DefaultSeverity, + hasSeverity: true, + }) + } + var sev map[string]Severity + if len(p.Severity) > 0 { + sev = make(map[string]Severity, len(p.Severity)) + maps.Copy(sev, p.Severity) + } + return &Analyzer{rules: bound, parser: NewFallbackParser(), severity: sev, rawQuery: p.RawQuery} +} + +// Analyze parses the query once and runs all rules against it. If the +// configured parser returns an error, it degrades to the FallbackParser so +// analysis never breaks the caller's query path. Findings for rules named in +// an in-SQL `sqlguard:ignore` directive are suppressed, and severity +// overrides from the active Profile are applied. +func (a *Analyzer) Analyze(query string) []Result { + stmt, err := a.parser.Parse(query) + if err != nil || stmt == nil { + stmt, _ = NewFallbackParser().Parse(query) + } + + ignoreAll, ignored := parseIgnoreDirective(query) + + display, fingerprint := a.PrepareQuery(query) + + results := make([]Result, 0, len(a.rules)) + for _, br := range a.rules { + if ignoreAll { + break + } + r, ok := br.check(stmt) + if !ok { + continue + } + if r.RuleName != "" && ignored[r.RuleName] { + continue + } + // Severity precedence: the spec's DefaultSeverity is the single source + // of truth for a registry-built rule (the rule body no longer sets + // one); a profile override, when present, wins over that. + if br.hasSeverity { + r.Severity = br.severity + } + if a.severity != nil { + if s, has := a.severity[r.RuleName]; has { + r.Severity = s + } + } + r.Query = display + r.Fingerprint = fingerprint + results = append(results, r) + } + return results +} diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go new file mode 100644 index 0000000..76419c8 --- /dev/null +++ b/analyzer/analyzer_test.go @@ -0,0 +1,434 @@ +package analyzer + +import "testing" + +// run parses q with the fallback parser and applies a single rule, returning +// whether the rule fired. Rules operate on a parsed Statement now, so tests +// go through the parser the same way Analyze does. +func run(t *testing.T, rule Rule, q string) bool { + t.Helper() + st, err := NewFallbackParser().Parse(q) + if err != nil { + t.Fatalf("fallback parser returned error for %q: %v", q, err) + } + _, ok := rule(st) + return ok +} + +func TestCheckSelectStar(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"basic select star", "SELECT * FROM users", true}, + {"lowercase", "select * from users", true}, + {"with where", "SELECT * FROM users WHERE id = 1", true}, + {"qualified star", "SELECT u.* FROM users u", true}, + {"specific columns", "SELECT id, name FROM users", false}, + {"count star", "SELECT COUNT(*) FROM users", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckSelectStar, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckLeadingWildcard(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"leading wildcard", "SELECT * FROM users WHERE email LIKE '%gmail.com%'", true}, + {"trailing only", "SELECT * FROM users WHERE name LIKE 'John%'", false}, + {"double quotes", `SELECT * FROM users WHERE email LIKE "%gmail%"`, true}, + {"ilike leading wildcard", "SELECT * FROM users WHERE email ILIKE '%gmail%'", true}, + {"ilike trailing only", "SELECT * FROM users WHERE name ILIKE 'John%'", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckLeadingWildcard, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckDeleteWithoutWhere(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no where", "DELETE FROM users", true}, + {"with where", "DELETE FROM users WHERE id = 1", false}, + {"not a delete", "SELECT * FROM users", false}, + {"where in string literal", "DELETE FROM logs WHERE msg = 'no WHERE clause'", false}, + {"fake where in string", "DELETE FROM users SET bio = 'I live WHERE the sun shines'", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckDeleteWithoutWhere, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckUpdateWithoutWhere(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no where", "UPDATE users SET name = 'test'", true}, + {"with where", "UPDATE users SET name = 'test' WHERE id = 1", false}, + {"not an update", "SELECT * FROM users", false}, + {"where in string literal", "UPDATE users SET bio = 'I live WHERE the sun shines'", true}, + {"real where after string", "UPDATE users SET bio = 'hello' WHERE id = 1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckUpdateWithoutWhere, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckInsertWithoutColumns(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no columns", "INSERT INTO users VALUES ('alice', 'alice@test.com')", true}, + {"with columns", "INSERT INTO users (name, email) VALUES ('alice', 'alice@test.com')", false}, + {"not an insert", "SELECT * FROM users", false}, + {"insert select no columns", "INSERT INTO users SELECT name, email FROM staging", true}, + {"insert select with columns", "INSERT INTO users (name, email) SELECT name, email FROM staging", false}, + {"qualified table no columns", "INSERT INTO public.users VALUES ('alice')", true}, + {"mysql set form", "INSERT INTO users SET name = 'alice', email = 'a@test.com'", false}, + {"default values", "INSERT INTO users DEFAULT VALUES", false}, + {"cte insert no columns", "WITH s AS (SELECT 1) INSERT INTO users SELECT * FROM s", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckInsertWithoutColumns, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckSelectWithoutLimit(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no limit no where", "SELECT id FROM users", true}, + {"with limit", "SELECT id FROM users LIMIT 10", false}, + {"with where", "SELECT id FROM users WHERE id = 1", false}, + {"with both", "SELECT id FROM users WHERE id > 0 LIMIT 10", false}, + {"not a select", "DELETE FROM users", false}, + {"select without from", "SELECT 1", false}, + {"select version", "SELECT version()", false}, + {"select current_timestamp", "SELECT CURRENT_TIMESTAMP", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckSelectWithoutLimit, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckOrderByWithoutLimit(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"order without limit", "SELECT id FROM users ORDER BY name", true}, + {"order with limit", "SELECT id FROM users ORDER BY name LIMIT 10", false}, + {"no order by", "SELECT id FROM users", false}, + {"window order by", "SELECT row_number() OVER (ORDER BY id) FROM users", false}, + {"ordered aggregate", "SELECT GROUP_CONCAT(x ORDER BY y) FROM t", false}, + {"window order by with top-level order by", "SELECT rank() OVER (ORDER BY a) FROM t ORDER BY b", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckOrderByWithoutLimit, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckNonSargablePredicate(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"lower on column", "SELECT id FROM users WHERE LOWER(email) = 'x'", true}, + {"date on column", "SELECT id FROM events WHERE DATE(created_at) = '2020-01-01'", true}, + {"cast on column", "SELECT id FROM users WHERE CAST(id AS text) = '5'", true}, + {"coalesce on column", "SELECT id FROM users WHERE COALESCE(deleted, false) = false", true}, + {"like on wrapped column", "SELECT id FROM users WHERE UPPER(name) LIKE 'A%'", true}, + {"function on value side", "SELECT id FROM users WHERE email = LOWER('X')", false}, + {"now on value side", "SELECT id FROM events WHERE created_at > NOW()", false}, + {"bare column", "SELECT id FROM users WHERE email = 'x'", false}, + {"in list not a function", "SELECT id FROM users WHERE id IN (1, 2, 3)", false}, + {"function in select list", "SELECT LOWER(name) FROM users", false}, + {"function in order by", "SELECT id FROM users WHERE active = true ORDER BY LOWER(name)", false}, + {"commented out predicate", "SELECT id FROM users -- WHERE LOWER(email) = 'x'", false}, + {"predicate after subquery clause", "SELECT id FROM users WHERE id IN (SELECT uid FROM o ORDER BY x LIMIT 1) AND LOWER(name) = 'a'", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckNonSargablePredicate, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckAddNotNullWithoutDefault(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"add not null no default", "ALTER TABLE users ADD COLUMN age int NOT NULL", true}, + {"add not null without column kw", "ALTER TABLE users ADD age int NOT NULL", true}, + {"numeric type with comma", "ALTER TABLE t ADD COLUMN bal numeric(10,2) NOT NULL", true}, + {"multi action one unsafe", "ALTER TABLE t ADD COLUMN a int NOT NULL, ADD COLUMN b int DEFAULT 5", true}, + {"not null with default", "ALTER TABLE users ADD COLUMN age int NOT NULL DEFAULT 0", false}, + {"default before not null", "ALTER TABLE users ADD COLUMN age int DEFAULT 0 NOT NULL", false}, + {"nullable column", "ALTER TABLE users ADD COLUMN age int", false}, + {"set not null on existing", "ALTER TABLE users ALTER COLUMN age SET NOT NULL", false}, + {"add check constraint is not null", "ALTER TABLE users ADD CONSTRAINT chk CHECK (age IS NOT NULL)", false}, + {"not an alter", "INSERT INTO users (age) VALUES (1)", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckAddNotNullWithoutDefault, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckImplicitJoin(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"two table comma join", "SELECT * FROM a, b WHERE a.id = b.id", true}, + {"three tables", "SELECT * FROM a, b, c WHERE a.id = b.id AND b.id = c.id", true}, + {"comma plus explicit join", "SELECT * FROM a, b JOIN c ON b.id = c.id", true}, + {"explicit join only", "SELECT * FROM a JOIN b ON a.id = b.id", false}, + {"single table", "SELECT * FROM users WHERE id = 1", false}, + {"select list comma not from", "SELECT id, name FROM users", false}, + {"comma inside in list", "SELECT * FROM users WHERE id IN (1, 2, 3)", false}, + {"comma inside function", "SELECT * FROM generate_series(1, 10)", false}, + {"comma inside subquery", "SELECT * FROM (SELECT a, b FROM t) sub", false}, + {"from inside extract", "SELECT EXTRACT(YEAR FROM created_at) FROM events", false}, + {"extract then comma join", "SELECT EXTRACT(YEAR FROM ts) FROM events, logs WHERE events.id = logs.id", true}, + {"no from", "SELECT 1, 2", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckImplicitJoin, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckCartesianJoin(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"comma join no where", "SELECT * FROM a, b", true}, + {"three tables no where", "SELECT * FROM a, b, c", true}, + {"explicit cross join", "SELECT * FROM a CROSS JOIN b", true}, + {"cross join with where", "SELECT * FROM a CROSS JOIN b WHERE a.x = 1", false}, + {"comma join with where", "SELECT * FROM a, b WHERE a.id = b.id", false}, + {"join with on", "SELECT * FROM a JOIN b ON a.id = b.id", false}, + {"join with using", "SELECT * FROM a JOIN b USING (id)", false}, + {"natural join", "SELECT * FROM a NATURAL JOIN b", false}, + {"single table", "SELECT * FROM users", false}, + {"subquery cross product", "SELECT * FROM (SELECT * FROM x WHERE y = 1) sub, t", true}, + {"cross join only in subquery", "SELECT x FROM (SELECT * FROM a CROSS JOIN b) sub", false}, + {"conditioned join only in subquery", "SELECT x FROM (SELECT * FROM a JOIN b ON a.id = b.id) sub", false}, + {"no from", "SELECT 1, 2", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckCartesianJoin, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckInListTooLarge(t *testing.T) { + rule := inListRule(5) // flag IN lists with more than 5 elements + tests := []struct { + name string + query string + wantHit bool + }{ + {"over threshold", "SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5, 6)", true}, + {"at threshold", "SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5)", false}, + {"under threshold", "SELECT * FROM t WHERE id IN (1, 2, 3)", false}, + {"not in over threshold", "SELECT * FROM t WHERE id NOT IN (1, 2, 3, 4, 5, 6)", true}, + {"placeholders over threshold", "SELECT * FROM t WHERE id IN (?, ?, ?, ?, ?, ?)", true}, + {"string literals over threshold", "SELECT * FROM t WHERE c IN ('a', 'b', 'c', 'd', 'e', 'f')", true}, + {"subquery not counted", "SELECT * FROM t WHERE id IN (SELECT id FROM other)", false}, + {"no in list", "SELECT * FROM t WHERE id = 5", false}, + {"function commas not an in list", "SELECT * FROM t WHERE x = greatest(1, 2, 3, 4, 5, 6, 7)", false}, + {"largest of multiple lists", "SELECT * FROM t WHERE a IN (1, 2) AND b IN (1, 2, 3, 4, 5, 6, 7)", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, rule, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckLargeOffset(t *testing.T) { + rule := largeOffsetRule(1000) + tests := []struct { + name string + query string + wantHit bool + }{ + {"large offset", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET 5000", true}, + {"at threshold", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET 1000", false}, + {"small offset", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET 40", false}, + {"no offset", "SELECT * FROM t ORDER BY id LIMIT 20", false}, + {"parameterized offset", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET $1", false}, + {"offset rows fetch", "SELECT * FROM t ORDER BY id OFFSET 5000 ROWS FETCH NEXT 20 ROWS ONLY", true}, + {"mysql limit offset comma", "SELECT * FROM t ORDER BY id LIMIT 5000, 20", true}, + {"mysql limit small offset", "SELECT * FROM t ORDER BY id LIMIT 40, 20", false}, + {"offset as column name", "SELECT offset FROM t WHERE offset = 5000", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, rule, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckSelectDistinct(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"basic distinct", "SELECT DISTINCT name FROM users", true}, + {"lowercase", "select distinct id from t", true}, + {"distinct on postgres", "SELECT DISTINCT ON (dept) name FROM emp", true}, + {"distinct parens", "SELECT DISTINCT(name) FROM users", true}, + {"distinctrow mysql", "SELECT DISTINCTROW name FROM users", true}, + {"distinct in subquery", "SELECT * FROM (SELECT DISTINCT x FROM t) s", true}, + {"no distinct", "SELECT name FROM users", false}, + {"count distinct aggregate", "SELECT COUNT(DISTINCT name) FROM users", false}, + {"distinct in aggregate with group", "SELECT id, COUNT(DISTINCT x) FROM t GROUP BY id", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckSelectDistinct, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestDefaultAnalyzer(t *testing.T) { + a := Default() + + results := a.Analyze("DELETE FROM users") + if len(results) == 0 { + t.Fatal("expected at least one result for DELETE without WHERE") + } + if results[0].Severity != SeverityCritical { + t.Errorf("expected critical severity, got %s", results[0].Severity) + } + + results = a.Analyze("SELECT id FROM users WHERE id = 1") + if len(results) != 0 { + t.Errorf("expected no results for safe query, got %d", len(results)) + } +} + +// TestSpecDefaultSeverityIsAuthoritative locks in that a registry-built rule's +// reported severity comes from its RuleSpec.DefaultSeverity — the single +// source of truth — not from a literal in the rule body. The rule here +// deliberately returns the zero severity (Info); Analyze must report Critical. +func TestSpecDefaultSeverityIsAuthoritative(t *testing.T) { + const name = "zz-spec-severity-probe" + Register(RuleSpec{ + Name: name, + DefaultSeverity: SeverityCritical, + Factory: func(Settings) Rule { + return func(*Statement) (Result, bool) { + return Result{RuleName: name}, true // no Severity set + } + }, + }) + // Don't leak the probe into the global registry; Default() would pick it up. + t.Cleanup(func() { + registryMu.Lock() + delete(registry, name) + registryMu.Unlock() + }) + + a := DefaultWithProfile(Profile{Only: map[string]bool{name: true}}) + got := a.Analyze("SELECT 1") + if len(got) != 1 { + t.Fatalf("expected 1 result, got %d", len(got)) + } + if got[0].Severity != SeverityCritical { + t.Errorf("severity = %s, want CRITICAL (from spec DefaultSeverity)", got[0].Severity) + } + + // A profile override still wins over the spec default. + a = DefaultWithProfile(Profile{ + Only: map[string]bool{name: true}, + Severity: map[string]Severity{name: SeverityInfo}, + }) + if got := a.Analyze("SELECT 1"); len(got) != 1 || got[0].Severity != SeverityInfo { + t.Errorf("profile override not applied: %+v", got) + } +} diff --git a/analyzer/fallback.go b/analyzer/fallback.go new file mode 100644 index 0000000..dc9dd31 --- /dev/null +++ b/analyzer/fallback.go @@ -0,0 +1,575 @@ +package analyzer + +import ( + "regexp" + "strconv" + "strings" +) + +// FallbackParser is the zero-dependency Parser. It removes SQL comments and +// string-literal contents before pattern matching, so keywords inside +// comments or strings (and identifiers like update_at) no longer cause +// false positives. It is best-effort and never returns an error: SQL it +// cannot fully understand still yields a usable Statement with Exact=false. +type FallbackParser struct{} + +// NewFallbackParser returns the default zero-dependency parser. +func NewFallbackParser() *FallbackParser { return &FallbackParser{} } + +var ( + // I?LIKE matches both LIKE and Postgres' case-insensitive ILIKE; the \b + // before it keeps the "I" from matching inside words like DISLIKE. + fbLeadingWildcardRe = regexp.MustCompile(`(?i)\bI?LIKE\s+['"]\s*%`) + // Best-effort capture of a LIKE/ILIKE pattern's literal body. Does not model + // embedded/escaped quotes; the fallback is heuristic by contract. + fbLikeLiteralRe = regexp.MustCompile(`(?i)\bI?LIKE\s+['"]([^'"]*)['"]`) + fbWhereRe = regexp.MustCompile(`(?i)\bWHERE\b`) + fbLimitRe = regexp.MustCompile(`(?i)\bLIMIT\b`) + fbOrderByRe = regexp.MustCompile(`(?i)\bORDER\s+BY\b`) + fbFromRe = regexp.MustCompile(`(?i)\bFROM\b`) + fbSelectStarRe = regexp.MustCompile(`(?i)\bSELECT\s+(?:DISTINCT\s+)?(?:[a-z_][a-z0-9_]*\s*\.\s*)?\*`) + // fbSelectDistinctRe anchors DISTINCT to right after SELECT, so an + // aggregate-level DISTINCT (COUNT(DISTINCT x)) does not match. + fbSelectDistinctRe = regexp.MustCompile(`(?i)\bSELECT\s+DISTINCT(?:ROW)?\b`) + fbIntoRe = regexp.MustCompile(`(?i)\bINTO\b`) + // fbInsertDataRe marks the start of an INSERT's data clause, after the + // target table (and its optional column list). VALUES? covers MySQL's + // singular VALUE; SELECT/WITH/TABLE cover INSERT ... SELECT and friends. + fbInsertDataRe = regexp.MustCompile(`(?i)\b(VALUES?|SELECT|WITH|TABLE|SET|DEFAULT)\b`) + fbLeadKindRe = regexp.MustCompile(`(?i)^\s*\(*\s*(SELECT|INSERT|UPDATE|DELETE|WITH)\b`) + fbDMLWordRe = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE)\b`) + + // fbWhereRegionEndRe marks the first clause keyword that ends the WHERE + // region, so a function in ORDER BY / GROUP BY / HAVING isn't read as a + // WHERE predicate. + fbWhereRegionEndRe = regexp.MustCompile(`(?i)\b(GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|WINDOW|FETCH|FOR\s+UPDATE)\b`) + // fbFuncOnColumnRe matches IDENT(args): a function/cast call whose + // closing paren is immediately followed by a comparison operator. The + // operator-after-paren shape is what restricts it to the column side of a + // predicate (WHERE LOWER(c) = ...), not the value side (WHERE c = ABS(x)). + fbFuncOnColumnRe = regexp.MustCompile(`(?i)\b([a-z_][a-z0-9_]*)\s*\(([^()]*)\)\s*(?:=|<>|!=|<=|>=|<|>|\bLIKE\b|\bIN\b|\bBETWEEN\b)`) + // fbArgIdentRe checks that a function's arguments contain a column-like + // identifier, so NOW() and LOWER('x') (literal blanked to '') are skipped. + fbArgIdentRe = regexp.MustCompile(`[a-zA-Z_]`) + + fbAlterTableRe = regexp.MustCompile(`(?i)^\s*ALTER\s+TABLE\b`) + // fbAddActionRe matches the start of an ALTER action and captures the + // first token after ADD [COLUMN] — a column name for a column add, or a + // keyword (CONSTRAINT, CHECK, ...) for the forms we must skip. + fbAddActionRe = regexp.MustCompile(`(?i)\bADD\s+(?:COLUMN\s+)?(\w+)`) + fbNotNullRe = regexp.MustCompile(`(?i)\bNOT\s+NULL\b`) + fbDefaultRe = regexp.MustCompile(`(?i)\bDEFAULT\b`) + + // fbFromRegionEndRe marks the first clause keyword that ends the FROM + // region, so commas after it (an IN list, GROUP BY, etc.) aren't read as + // join separators. + fbFromRegionEndRe = regexp.MustCompile(`(?i)\b(WHERE|GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|WINDOW|FETCH|FOR|UNION|EXCEPT|INTERSECT)\b`) + fbJoinRe = regexp.MustCompile(`(?i)\bJOIN\b`) + // fbJoinCondRe matches anything that conditions a join — an ON/USING + // predicate or a NATURAL join (which joins on common columns) — so a join + // carrying one of these is not treated as a cartesian product. + fbJoinCondRe = regexp.MustCompile(`(?i)\b(ON|USING|NATURAL)\b`) + // fbInListRe matches the opening of an IN value list (NOT IN matches too). + fbInListRe = regexp.MustCompile(`(?i)\bIN\s*\(`) + // fbSubqueryStartRe recognizes an IN (...) body that is a subquery (or set) + // rather than a value list, so it is not counted. + fbSubqueryStartRe = regexp.MustCompile(`(?i)^\(*\s*(SELECT|WITH|VALUES|TABLE)\b`) + // fbOffsetRe captures a literal standard OFFSET n (incl. OFFSET n ROWS). + fbOffsetRe = regexp.MustCompile(`(?i)\bOFFSET\s+(\d+)`) + // fbLimitOffsetRe captures the offset of MySQL's LIMIT offset, count form. + fbLimitOffsetRe = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+)\s*,\s*\d+`) +) + +// fbNonSargableSkipFuncs are tokens that can appear as IDENT before "(" but +// are SQL keywords, not functions wrapping a column. +var fbNonSargableSkipFuncs = map[string]bool{ + "in": true, "exists": true, "any": true, "all": true, + "some": true, "and": true, "or": true, "not": true, +} + +// fbAddNonColumnKeywords are the tokens following ADD that mean the action is +// not a column add (so a stray NOT NULL, e.g. inside a CHECK constraint, isn't +// mistaken for a NOT NULL column). +var fbAddNonColumnKeywords = map[string]bool{ + "constraint": true, "primary": true, "foreign": true, + "unique": true, "check": true, "key": true, "index": true, +} + +// Parse implements Parser. It always returns a non-nil Statement and a nil +// error. +func (p *FallbackParser) Parse(sql string) (*Statement, error) { + st := &Statement{Raw: sql, Exact: false} + + noComments := stripComments(sql) + + // Leading-wildcard LIKE is detected before literal contents are blanked, + // because the pattern lives inside the literal. Comments are already gone, + // so a commented-out LIKE won't trigger. + st.LeadingWildcardLike = fbLeadingWildcardRe.MatchString(noComments) + if st.LeadingWildcardLike { + st.LeadingWildcardTermLen = leadingWildcardTermLen(noComments) + } + + sanitized := blankStringLiterals(noComments) + + st.Kind = detectKind(sanitized) + st.HasWhere = fbWhereRe.MatchString(sanitized) + st.HasLimit = fbLimitRe.MatchString(sanitized) + st.HasOrderBy = hasTopLevelOrderBy(sanitized) + st.HasFrom = fbFromRe.MatchString(sanitized) + st.SelectStar = fbSelectStarRe.MatchString(sanitized) + st.SelectDistinct = fbSelectDistinctRe.MatchString(sanitized) + st.NonSargablePredicate = hasNonSargablePredicate(sanitized) + st.AddNotNullNoDefault = hasUnsafeAddNotNull(sanitized) + st.ImplicitCommaJoin = hasImplicitCommaJoin(sanitized) + st.CartesianJoin = hasCartesianJoin(sanitized) + st.MaxInListLen = maxInListLen(sanitized) + st.OffsetValue = maxOffset(sanitized) + + if st.Kind == StmtInsert { + st.InsertColumnsListed = insertColumnsListed(sanitized) + } + + return st, nil +} + +// insertColumnsListed reports whether an INSERT names its target columns +// explicitly. It inspects the span between INTO and the data clause +// (VALUES / SELECT / WITH / TABLE): an explicit column list shows up there as a +// "(". The "VALUES"-only shape the old regex matched missed INSERT ... SELECT +// (and CTE-prefixed inserts), which carry the same schema-change risk. MySQL's +// "SET col = ..." names its columns, and "DEFAULT VALUES" inserts no data, so +// both count as listed (no positional column-order risk to warn about). +// Comment-free, literal-blanked input expected; heuristic by contract. +func insertColumnsListed(sanitized string) bool { + loc := fbIntoRe.FindStringIndex(sanitized) + if loc == nil { + return true // no INTO found — can't tell, don't flag + } + rest := sanitized[loc[1]:] + data := fbInsertDataRe.FindStringIndex(rest) + if data == nil { + return true // no recognizable data clause — don't flag + } + switch strings.ToUpper(strings.TrimSpace(rest[data[0]:data[1]])) { + case "SET", "DEFAULT": + return true + } + // Columns are listed iff a "(" appears between the table name and the data + // clause. A bare table reference (incl. schema.table) has no parens there. + return strings.Contains(rest[:data[0]], "(") +} + +// leadingWildcardTermLen returns the length of the longest searchable term +// (the LIKE literal with surrounding '%' trimmed) among patterns that begin +// with a wildcard. Comment-free input is expected. +func leadingWildcardTermLen(noComments string) int { + max := 0 + for _, m := range fbLikeLiteralRe.FindAllStringSubmatch(noComments, -1) { + body := strings.TrimSpace(m[1]) + if !strings.HasPrefix(body, "%") { + continue + } + if n := len(strings.Trim(body, "%")); n > max { + max = n + } + } + return max +} + +// hasNonSargablePredicate reports whether the WHERE clause applies a function +// or cast to a column (WHERE LOWER(email) = ...), which defeats an index on +// that column. Input must be comment-free and have its string literals +// blanked. Scope is limited to the WHERE region so functions in the SELECT +// list, ORDER BY, or GROUP BY don't false-fire. +func hasNonSargablePredicate(sanitized string) bool { + region := whereRegion(sanitized) + if region == "" { + return false + } + for _, m := range fbFuncOnColumnRe.FindAllStringSubmatch(region, -1) { + if fbNonSargableSkipFuncs[strings.ToLower(m[1])] { + continue // a keyword like IN(...) / EXISTS(...), not a function + } + if !fbArgIdentRe.MatchString(m[2]) { + continue // no column in the args (e.g. NOW(), LOWER('x')) + } + return true + } + return false +} + +// whereRegion returns the slice of sanitized SQL from the WHERE keyword up to +// the next clause keyword (ORDER BY, GROUP BY, HAVING, LIMIT, ...), or "" when +// there is no WHERE clause. +func whereRegion(sanitized string) string { + loc := fbWhereRe.FindStringIndex(sanitized) + if loc == nil { + return "" + } + region := sanitized[loc[1]:] + // End the region at the first clause keyword that sits at the WHERE's own + // nesting level. A keyword inside a subquery in the WHERE (e.g. ORDER BY / + // LIMIT in "WHERE id IN (SELECT ... ORDER BY x LIMIT 1)") is at depth > 0 + // and must not cut the region short, which would drop predicates after it. + for _, end := range fbWhereRegionEndRe.FindAllStringIndex(region, -1) { + if parenDepthBefore(region, end[0]) == 0 { + return region[:end[0]] + } + } + return region +} + +// hasUnsafeAddNotNull reports whether an ALTER TABLE adds a NOT NULL column +// without a DEFAULT — which errors or rewrites the table on a populated table. +// Input must be comment-free with string literals blanked. The statement is +// split on top-level commas so each ADD action is judged independently (and a +// numeric type's own comma, e.g. NUMERIC(10,2), isn't a split point). +func hasUnsafeAddNotNull(sanitized string) bool { + if !fbAlterTableRe.MatchString(sanitized) { + return false + } + for _, seg := range splitTopLevelCommas(sanitized) { + m := fbAddActionRe.FindStringSubmatch(seg) + if m == nil { + continue // not an ADD action + } + if fbAddNonColumnKeywords[strings.ToLower(m[1])] { + continue // ADD CONSTRAINT / CHECK / KEY / ... — not a column add + } + if fbNotNullRe.MatchString(seg) && !fbDefaultRe.MatchString(seg) { + return true + } + } + return false +} + +// splitTopLevelCommas splits s on commas that are not nested inside +// parentheses. Used to separate the actions of a multi-action ALTER TABLE +// without breaking parenthesized type specs or expressions. +func splitTopLevelCommas(s string) []string { + var segs []string + depth, start := 0, 0 + for i := 0; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + case ',': + if depth == 0 { + segs = append(segs, s[start:i]) + start = i + 1 + } + } + } + return append(segs, s[start:]) +} + +// hasImplicitCommaJoin reports whether the FROM clause lists multiple tables +// separated by top-level commas (FROM a, b) rather than explicit JOIN syntax. +// Input must be comment-free with string literals blanked. The FROM region is +// found and bounded paren-aware so that a FROM inside EXTRACT(... FROM ...) or +// a subquery, and commas inside function calls / subqueries / IN lists, are +// not mistaken for join separators. +func hasImplicitCommaJoin(sanitized string) bool { + region := fromRegion(sanitized) + if region == "" { + return false + } + return len(splitTopLevelCommas(region)) > 1 +} + +// hasCartesianJoin reports whether the FROM clause joins multiple tables (via +// a top-level comma, CROSS JOIN, or a bare JOIN) with no join condition +// (ON/USING/NATURAL) anywhere in the FROM region and no top-level WHERE — an +// unconditioned cartesian product. It is deliberately conservative: if any +// join condition or WHERE filter is present it does not fire, so mixed queries +// (some joins conditioned) yield a false negative rather than a false positive. +// Input must be comment-free with string literals blanked. +func hasCartesianJoin(sanitized string) bool { + region := fromRegion(sanitized) + if region == "" { + return false + } + multiTable := len(splitTopLevelCommas(region)) > 1 || hasTopLevelJoin(region) + if !multiTable { + return false + } + if fbJoinCondRe.MatchString(region) || hasTopLevelWhere(sanitized) { + return false + } + return true +} + +// hasTopLevelWhere reports whether a WHERE keyword appears at parenthesis depth +// zero (a WHERE inside a subquery does not filter an outer cartesian product). +func hasTopLevelWhere(sanitized string) bool { + for _, loc := range fbWhereRe.FindAllStringIndex(sanitized, -1) { + if parenDepthBefore(sanitized, loc[0]) == 0 { + return true + } + } + return false +} + +// hasTopLevelJoin reports whether a JOIN keyword appears at parenthesis depth +// zero within the FROM region — an outer-level join (incl. CROSS/bare JOIN), +// not one inside a subquery. region is a depth-zero slice from fromRegion, so +// depth is measured relative to it. This is the JOIN counterpart of the +// paren-aware comma split: without it a JOIN in a FROM-clause subquery would be +// read as an outer cartesian product. +func hasTopLevelJoin(region string) bool { + for _, loc := range fbJoinRe.FindAllStringIndex(region, -1) { + if parenDepthBefore(region, loc[0]) == 0 { + return true + } + } + return false +} + +// hasTopLevelOrderBy reports whether an ORDER BY appears at parenthesis depth +// zero — a result-set sort, not a window-function (OVER (ORDER BY ...)), +// ordered-aggregate (GROUP_CONCAT(... ORDER BY ...), WITHIN GROUP (ORDER BY +// ...)), or subquery ordering, none of which sort the statement's result set. +func hasTopLevelOrderBy(sanitized string) bool { + for _, loc := range fbOrderByRe.FindAllStringIndex(sanitized, -1) { + if parenDepthBefore(sanitized, loc[0]) == 0 { + return true + } + } + return false +} + +// maxInListLen returns the largest element count among the statement's +// IN (...) value lists. IN (SELECT ...) / IN (VALUES ...) subqueries are not +// counted. Input must be comment-free with string literals blanked — commas +// between blanked literals survive, so element counting is unaffected. +func maxInListLen(sanitized string) int { + max := 0 + for _, loc := range fbInListRe.FindAllStringIndex(sanitized, -1) { + inner, ok := parenContent(sanitized, loc[1]-1) // loc[1]-1 is the "(" + if !ok { + continue + } + if strings.TrimSpace(inner) == "" || fbSubqueryStartRe.MatchString(strings.TrimSpace(inner)) { + continue + } + if n := len(splitTopLevelCommas(inner)); n > max { + max = n + } + } + return max +} + +// maxOffset returns the largest literal offset in the statement, considering +// both the standard OFFSET n form and MySQL's LIMIT offset, count form. A +// parameterized offset (OFFSET $1 / ?) matches neither and yields 0. Input +// must be comment-free with string literals blanked (numeric literals, which +// is what an offset is, are left intact). An offset literal too large for int +// is ignored (treated as no offset) — a rare, harmless false negative. +func maxOffset(sanitized string) int { + max := 0 + consider := func(re *regexp.Regexp) { + for _, m := range re.FindAllStringSubmatch(sanitized, -1) { + if n, err := strconv.Atoi(m[1]); err == nil && n > max { + max = n + } + } + } + consider(fbOffsetRe) + consider(fbLimitOffsetRe) + return max +} + +// parenContent returns the substring between the parenthesis at index open and +// its matching close paren (exclusive of both), and true, or "" and false if +// unbalanced. +func parenContent(s string, open int) (string, bool) { + depth := 0 + for i := open; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return s[open+1 : i], true + } + } + } + return "", false +} + +// fromRegion returns the slice of sanitized SQL between the first top-level +// FROM keyword and the next top-level clause keyword (WHERE, GROUP BY, a set +// operator, ...), or "" when there is no top-level FROM. "Top-level" means at +// parenthesis depth zero, so subquery and function-argument keywords are +// ignored. +func fromRegion(sanitized string) string { + fromEnd := -1 + for _, loc := range fbFromRe.FindAllStringIndex(sanitized, -1) { + if parenDepthBefore(sanitized, loc[0]) == 0 { + fromEnd = loc[1] + break + } + } + if fromEnd == -1 { + return "" + } + region := sanitized[fromEnd:] + for _, loc := range fbFromRegionEndRe.FindAllStringIndex(region, -1) { + if parenDepthBefore(region, loc[0]) == 0 { + return region[:loc[0]] + } + } + return region +} + +// parenDepthBefore returns the net parenthesis nesting depth at index idx +// (count of unmatched '(' in s[:idx]). +func parenDepthBefore(s string, idx int) int { + depth := 0 + for i := range idx { + switch s[i] { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + } + } + return depth +} + +func detectKind(sanitized string) StmtKind { + m := fbLeadKindRe.FindStringSubmatch(sanitized) + if m == nil { + return StmtOther + } + switch strings.ToUpper(m[1]) { + case "SELECT": + return StmtSelect + case "INSERT": + return StmtInsert + case "UPDATE": + return StmtUpdate + case "DELETE": + return StmtDelete + case "WITH": + // A CTE feeds a main statement. Best-effort: if an INSERT/UPDATE/DELETE + // keyword appears anywhere, treat it as that; otherwise a SELECT. + if w := fbDMLWordRe.FindString(sanitized); w != "" { + switch strings.ToUpper(w) { + case "INSERT": + return StmtInsert + case "UPDATE": + return StmtUpdate + case "DELETE": + return StmtDelete + } + } + return StmtSelect + } + return StmtOther +} + +// stripComments removes -- line comments and /* */ block comments, replacing +// each with a single space so token boundaries are preserved. It does not +// remove comment markers that appear inside string literals. +func stripComments(s string) string { + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); { + switch c := s[i]; { + case c == '\'' || c == '"': + i = copyStringLiteral(&b, s, i) + case c == '-' && i+1 < len(s) && s[i+1] == '-': + i = skipLineComment(s, i) + b.WriteByte(' ') + case c == '/' && i+1 < len(s) && s[i+1] == '*': + i = skipBlockComment(s, i) + b.WriteByte(' ') + default: + b.WriteByte(c) + i++ + } + } + return b.String() +} + +// copyStringLiteral writes the string literal that begins at s[i] (a quote +// byte) verbatim, honoring ” / "" doubled-quote escapes, and returns the +// index just past the literal. +func copyStringLiteral(b *strings.Builder, s string, i int) int { + q := s[i] + b.WriteByte(q) + i++ + for i < len(s) { + b.WriteByte(s[i]) + if s[i] == q { + if i+1 < len(s) && s[i+1] == q { // doubled-quote escape + b.WriteByte(s[i+1]) + i += 2 + continue + } + return i + 1 + } + i++ + } + return i +} + +// skipLineComment returns the index of the newline (or end of input) that +// terminates the -- comment starting at i. +func skipLineComment(s string, i int) int { + for i < len(s) && s[i] != '\n' { + i++ + } + return i +} + +// skipBlockComment returns the index just past the */ that closes the block +// comment starting at i (or end of input if unterminated). +func skipBlockComment(s string, i int) int { + i += 2 + for i+1 < len(s) && (s[i] != '*' || s[i+1] != '/') { + i++ + } + return i + 2 +} + +// blankStringLiterals replaces the contents of every string literal with an +// empty literal, so SQL keywords that appear inside string values cannot be +// mistaken for clauses. Input must already be comment-free. +func blankStringLiterals(s string) string { + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); { + c := s[i] + if c == '\'' || c == '"' { + q := c + b.WriteByte(q) + i++ + for i < len(s) { + if s[i] == q { + if i+1 < len(s) && s[i+1] == q { // doubled-quote escape + i += 2 + continue + } + i++ + break + } + i++ + } + b.WriteByte(q) + continue + } + b.WriteByte(c) + i++ + } + return b.String() +} diff --git a/analyzer/fallback_test.go b/analyzer/fallback_test.go new file mode 100644 index 0000000..a6d0476 --- /dev/null +++ b/analyzer/fallback_test.go @@ -0,0 +1,111 @@ +package analyzer + +import "testing" + +// These cover the exact false-positive classes the production-grade review +// flagged for the old raw-regex engine: comments, keyword-like identifiers, +// CTEs, subqueries, multi-statement input, and driver placeholders. The +// fallback parser must not misfire and must never panic or error. + +func TestFallback_CommentsDoNotTriggerRules(t *testing.T) { + a := Default() + tests := []struct { + name string + query string + }{ + {"line comment with DELETE", "SELECT id FROM users WHERE id = 1 -- DELETE FROM users everything"}, + {"block comment with WHERE", "DELETE FROM users /* no WHERE here on purpose */ WHERE id = 1"}, + {"line comment hiding where", "UPDATE users SET active = false WHERE id = 1 -- WHERE"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, r := range a.Analyze(tt.query) { + if r.RuleName == "delete-without-where" || r.RuleName == "update-without-where" { + t.Errorf("%s: unexpected %s on commented query: %s", tt.name, r.RuleName, tt.query) + } + } + }) + } +} + +func TestFallback_CommentedOutClausesAreNotCounted(t *testing.T) { + a := New(CheckDeleteWithoutWhere) + // The only WHERE is inside a comment, so this DELETE is genuinely unsafe. + got := a.Analyze("DELETE FROM users -- WHERE id = 1") + if len(got) != 1 || got[0].RuleName != "delete-without-where" { + t.Errorf("expected delete-without-where when WHERE is only in a comment, got %+v", got) + } +} + +func TestFallback_KeywordLikeIdentifiers(t *testing.T) { + a := Default() + // Column/table names containing keyword substrings must not be parsed + // as clauses. + queries := []string{ + "SELECT id, update_at, where_clause FROM orders WHERE id = 1 LIMIT 1", + "SELECT limited, ordered_by FROM report WHERE k = 1 LIMIT 10", + "UPDATE wherehouse SET stock = 0 WHERE id = 7", + } + for _, q := range queries { + results := a.Analyze(q) + for _, r := range results { + if r.RuleName == "update-without-where" || r.RuleName == "delete-without-where" { + t.Errorf("keyword-like identifier misparsed in %q: got %s", q, r.RuleName) + } + } + } +} + +func TestFallback_CTEAndSubquery(t *testing.T) { + p := NewFallbackParser() + + st, err := p.Parse("WITH recent AS (SELECT id FROM orders WHERE ts > now()) DELETE FROM orders WHERE id IN (SELECT id FROM recent)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if st.Kind != StmtDelete { + t.Errorf("CTE-wrapped DELETE: got kind %v, want StmtDelete", st.Kind) + } + if !st.HasWhere { + t.Error("CTE-wrapped DELETE: WHERE clause not detected") + } + + st, _ = p.Parse("WITH t AS (SELECT 1) SELECT id FROM t WHERE id = 1") + if st.Kind != StmtSelect { + t.Errorf("CTE SELECT: got kind %v, want StmtSelect", st.Kind) + } +} + +func TestFallback_PlaceholdersNeverErrorOrPanic(t *testing.T) { + p := NewFallbackParser() + queries := []string{ + "SELECT * FROM users WHERE id = $1", + "SELECT * FROM users WHERE id = ? AND name = ?", + "SELECT * FROM users WHERE id = :id", + "INSERT INTO t VALUES ($1, $2); DELETE FROM other", + "", + "not even sql", + "SELECT '%' || $1 || '%'", + } + for _, q := range queries { + st, err := p.Parse(q) + if err != nil { + t.Errorf("fallback returned error for %q: %v (it must never error)", q, err) + } + if st == nil { + t.Errorf("fallback returned nil Statement for %q", q) + continue + } + if st.Exact { + t.Errorf("fallback Statement for %q must have Exact=false", q) + } + } +} + +func TestFallback_MultiStatementLeadingKind(t *testing.T) { + p := NewFallbackParser() + st, _ := p.Parse("DELETE FROM a WHERE id = 1; DROP TABLE b") + if st.Kind != StmtDelete { + t.Errorf("multi-statement: got kind %v, want StmtDelete (leading statement)", st.Kind) + } +} diff --git a/analyzer/parser.go b/analyzer/parser.go new file mode 100644 index 0000000..979732c --- /dev/null +++ b/analyzer/parser.go @@ -0,0 +1,18 @@ +package analyzer + +// Parser turns a raw SQL string into sqlguard's normalized Statement. +// +// Implementations: +// +// - FallbackParser (this package): zero-dependency, best-effort, never +// returns an error. +// - parsers/pgparser, parsers/mysqlparser (optional modules): real +// dialect ASTs, exact analysis, fall back to FallbackParser on parse +// failure. +// +// A Parser used on the runtime query path MUST NOT panic and SHOULD avoid +// returning an error for SQL it merely doesn't understand — degrade to a +// best-effort Statement instead, so analysis never breaks db.Query. +type Parser interface { + Parse(sql string) (*Statement, error) +} diff --git a/analyzer/profile_test.go b/analyzer/profile_test.go new file mode 100644 index 0000000..6f3c277 --- /dev/null +++ b/analyzer/profile_test.go @@ -0,0 +1,144 @@ +package analyzer + +import ( + "slices" + "testing" +) + +func TestRuleNamesCoversBuiltins(t *testing.T) { + names := RuleNames() + for _, want := range []string{ + "select-star", "leading-wildcard", "delete-without-where", + "update-without-where", "insert-without-columns", + "select-without-limit", "orderby-without-limit", + } { + if !slices.Contains(names, want) { + t.Errorf("rule %q not registered; got %v", want, names) + } + } +} + +func TestDefaultMatchesRegistry(t *testing.T) { + // Default() must behave exactly as before the registry refactor. + a := Default() + if got := a.Analyze("DELETE FROM users"); len(got) == 0 || got[0].Severity != SeverityCritical { + t.Fatalf("expected critical delete-without-where, got %+v", got) + } + if got := a.Analyze("SELECT id FROM users WHERE id = 1"); len(got) != 0 { + t.Errorf("expected no findings, got %+v", got) + } +} + +func TestProfileDisable(t *testing.T) { + a := DefaultWithProfile(Profile{Disabled: map[string]bool{"select-star": true}}) + for _, r := range a.Analyze("SELECT * FROM users") { + if r.RuleName == "select-star" { + t.Fatal("select-star should be disabled") + } + } +} + +func TestProfileOnlyWhitelist(t *testing.T) { + a := DefaultWithProfile(Profile{Only: map[string]bool{"select-star": true}}) + // delete-without-where must not run; only select-star is whitelisted. + got := a.Analyze("DELETE FROM users") + if len(got) != 0 { + t.Errorf("expected no findings with whitelist, got %+v", got) + } + if got := a.Analyze("SELECT * FROM users"); len(got) != 1 || got[0].RuleName != "select-star" { + t.Errorf("expected only select-star, got %+v", got) + } +} + +func TestProfileSeverityOverride(t *testing.T) { + a := DefaultWithProfile(Profile{Severity: map[string]Severity{"select-star": SeverityInfo}}) + got := a.Analyze("SELECT * FROM users") + if len(got) == 0 || got[0].RuleName != "select-star" || got[0].Severity != SeverityInfo { + t.Fatalf("expected select-star downgraded to INFO, got %+v", got) + } +} + +func TestProfileSettingsLeadingWildcardMinLength(t *testing.T) { + a := DefaultWithProfile(Profile{ + Settings: map[string]Settings{"leading-wildcard": {"min-length": 5}}, + }) + // 2-char term -> below threshold, not flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE x LIKE '%ab%'"), "leading-wildcard"); hits != 0 { + t.Errorf("short pattern should be tolerated with min-length=5") + } + // 6-char term -> flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE x LIKE '%abcdef%'"), "leading-wildcard"); hits != 1 { + t.Errorf("long pattern should still be flagged with min-length=5") + } +} + +func TestProfileSettingsInListMaxLength(t *testing.T) { + a := DefaultWithProfile(Profile{ + Settings: map[string]Settings{"in-list-too-large": {"max-length": 3}}, + }) + // 3 elements -> at threshold, not flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE id IN (1, 2, 3)"), "in-list-too-large"); hits != 0 { + t.Errorf("list at threshold should be tolerated with max-length=3") + } + // 4 elements -> over threshold, flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE id IN (1, 2, 3, 4)"), "in-list-too-large"); hits != 1 { + t.Errorf("list over threshold should be flagged with max-length=3") + } +} + +func TestProfileSettingsLargeOffsetThreshold(t *testing.T) { + a := DefaultWithProfile(Profile{ + Settings: map[string]Settings{"large-offset": {"threshold": 100}}, + }) + // offset 100 -> at threshold, not flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t LIMIT 10 OFFSET 100"), "large-offset"); hits != 0 { + t.Errorf("offset at threshold should be tolerated with threshold=100") + } + // offset 200 -> over threshold, flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t LIMIT 10 OFFSET 200"), "large-offset"); hits != 1 { + t.Errorf("offset over threshold should be flagged with threshold=100") + } +} + +func TestInlineSuppression(t *testing.T) { + a := Default() + + if got := a.Analyze("SELECT * FROM users -- sqlguard:ignore"); len(got) != 0 { + t.Errorf("bare ignore should suppress all, got %+v", got) + } + // Scoped: suppress select-star only; delete-without-where still fires. + q := "DELETE FROM users /* sqlguard:ignore:select-star */" + if got := a.Analyze(q); len(got) != 1 || got[0].RuleName != "delete-without-where" { + t.Errorf("scoped ignore should keep delete-without-where, got %+v", got) + } + if got := a.Analyze("SELECT * FROM users WHERE id = 1 /* sqlguard:ignore:select-star */"); len(got) != 0 { + t.Errorf("select-star should be suppressed, got %+v", got) + } + // The token inside a string literal must NOT suppress (no comment marker). + if got := a.Analyze("SELECT * FROM users WHERE note = 'sqlguard:ignore'"); len(got) == 0 { + t.Error("string-literal text must not act as a suppression directive") + } +} + +func TestParseIgnoreComment(t *testing.T) { + if all, _, found := ParseIgnoreComment("// sqlguard:ignore"); !found || !all { + t.Error("expected bare directive parsed as all") + } + all, rules, found := ParseIgnoreComment("// noise sqlguard:ignore:select-star, leading-wildcard") + if !found || all || !rules["select-star"] || !rules["leading-wildcard"] { + t.Errorf("expected scoped rules, got all=%v rules=%v found=%v", all, rules, found) + } + if _, _, found := ParseIgnoreComment("// just a normal comment"); found { + t.Error("non-directive comment should not be found") + } +} + +func filterByRule(rs []Result, name string) int { + n := 0 + for _, r := range rs { + if r.RuleName == name { + n++ + } + } + return n +} diff --git a/analyzer/redact.go b/analyzer/redact.go new file mode 100644 index 0000000..ea3de97 --- /dev/null +++ b/analyzer/redact.go @@ -0,0 +1,157 @@ +package analyzer + +import ( + "regexp" + "strings" +) + +// Redact returns sql with comments stripped and every single-quoted string +// literal and numeric literal replaced by a single "?" placeholder. Query +// structure, keywords, and identifiers (including double-quoted and +// backtick-quoted identifiers) are preserved, so the result stays readable +// and analyzable but carries no literal values — no emails, tokens, or other +// PII reach a log sink. +// +// It is a zero-dependency lexical pass, not a full parser: it is +// intentionally conservative (e.g. it does not special-case hex/scientific +// forms beyond a simple exponent) and never errors. Use it whenever a query +// is about to leave the process. +func Redact(sql string) string { + s := stripComments(sql) + var b strings.Builder + b.Grow(len(s)) + + var prev byte // last byte written to output, 0 at start + for i := 0; i < len(s); { + c := s[i] + switch { + case c == '\'': + // String literal — the classic PII carrier. Replace its whole + // body (honoring '' escapes) with one placeholder. + i = skipSingleQuoted(s, i) + b.WriteByte('?') + prev = '?' + case c == '"' || c == '`': + // Quoted identifier (ANSI double-quote / MySQL backtick). Copy + // verbatim so a quote-enclosed name or a stray ' inside it does + // not corrupt structure or trip the literal branch. + j := skipQuoted(s, i, c) + b.WriteString(s[i:j]) + prev = s[j-1] + i = j + case isDigit(c) && !suppressesNumber(prev): + j := scanNumber(s, i) + b.WriteByte('?') + prev = '?' + i = j + default: + b.WriteByte(c) + prev = c + i++ + } + } + return b.String() +} + +var fpListRe = regexp.MustCompile(`\(\?(?:, ?\?)+\)`) + +// Fingerprint returns a stable, PII-free identity for sql: it is Redact +// followed by whitespace collapsing and IN/VALUES-list folding +// ("(?, ?, ?)" -> "(?)") so that queries differing only in literal values or +// list length share one fingerprint. A trailing ";" is trimmed. +// +// The result is safe to use as a low-cardinality metric label or log key — +// it is the canonical query identity the runtime, the N+1 tracker, and any +// metrics/observability adapter group on. +func Fingerprint(sql string) string { + r := Redact(sql) + r = strings.Join(strings.Fields(r), " ") + r = fpListRe.ReplaceAllString(r, "(?)") + return strings.TrimRight(r, "; ") +} + +// IsMultiStatement reports whether sql contains more than one SQL statement, +// i.e. a ";" statement separator followed by further non-whitespace content. +// Comments and string-literal bodies are removed first (reusing the same +// comment/literal-aware lexer the parser uses), so the check cannot be +// defeated by a ";" hidden in a -- / /* */ comment or inside a string +// literal — the evasion the brittle strings.Contains(query, ";") check +// allowed. A single trailing ";" is not multi-statement. +func IsMultiStatement(sql string) bool { + s := blankStringLiterals(stripComments(sql)) + if _, rest, found := strings.Cut(s, ";"); found { + return strings.TrimSpace(rest) != "" + } + return false +} + +func isDigit(c byte) bool { return c >= '0' && c <= '9' } + +// suppressesNumber reports whether a digit following prev is part of an +// identifier (col1, int8) or a bind placeholder ($1, @p1) rather than a +// numeric literal, so it must not be redacted. +func suppressesNumber(prev byte) bool { + switch { + case prev >= 'a' && prev <= 'z', prev >= 'A' && prev <= 'Z', + prev >= '0' && prev <= '9': + return true + case prev == '_' || prev == '$' || prev == '@': + return true + } + return false +} + +// scanNumber returns the index just past the numeric literal starting at i +// (digits, an optional decimal point, and an optional e[+-]?digits exponent). +func scanNumber(s string, i int) int { + for i < len(s) && (isDigit(s[i]) || s[i] == '.') { + i++ + } + if i < len(s) && (s[i] == 'e' || s[i] == 'E') { + j := i + 1 + if j < len(s) && (s[j] == '+' || s[j] == '-') { + j++ + } + if j < len(s) && isDigit(s[j]) { + for j < len(s) && isDigit(s[j]) { + j++ + } + i = j + } + } + return i +} + +// skipSingleQuoted returns the index just past the single-quoted string +// literal starting at s[i] == '\”, honoring ” doubled-quote escapes. +func skipSingleQuoted(s string, i int) int { + i++ // opening quote + for i < len(s) { + if s[i] == '\'' { + if i+1 < len(s) && s[i+1] == '\'' { + i += 2 + continue + } + return i + 1 + } + i++ + } + return i +} + +// skipQuoted returns the index just past the quoted run starting at s[i] == q +// (q is '"' or '`'), honoring doubled-quote escapes for the same quote. +func skipQuoted(s string, i int, q byte) int { + i++ // opening quote + for i < len(s) { + if s[i] == q { + if i+1 < len(s) && s[i+1] == q { + i += 2 + continue + } + return i + 1 + } + i++ + } + return i +} diff --git a/analyzer/redact_policy_test.go b/analyzer/redact_policy_test.go new file mode 100644 index 0000000..c456b95 --- /dev/null +++ b/analyzer/redact_policy_test.go @@ -0,0 +1,51 @@ +package analyzer + +import "testing" + +func TestAnalyzeRedactsByDefault(t *testing.T) { + q := `SELECT * FROM users WHERE email = 'alice@acme.com'` + res := Default().Analyze(q) + if len(res) == 0 { + t.Fatal("expected at least one finding (select-star)") + } + for _, r := range res { + if contains(r.Query, "alice@acme.com") { + t.Errorf("default Analyze leaked literal in Query: %q", r.Query) + } + if r.Fingerprint == "" { + t.Error("Fingerprint not populated") + } + if contains(r.Fingerprint, "alice@acme.com") { + t.Errorf("Fingerprint leaked literal: %q", r.Fingerprint) + } + } +} + +func TestWithRawQueryKeepsLiterals(t *testing.T) { + q := `SELECT * FROM users WHERE email = 'alice@acme.com'` + res := Default().WithRawQuery().Analyze(q) + if len(res) == 0 { + t.Fatal("expected a finding") + } + if !contains(res[0].Query, "alice@acme.com") { + t.Errorf("WithRawQuery should keep raw SQL, got %q", res[0].Query) + } + if res[0].Fingerprint == "" { + t.Error("Fingerprint must still be set in raw mode") + } +} + +func TestPrepareQueryPolicy(t *testing.T) { + raw := `SELECT 'x' FROM t WHERE id = 9` + d, fp := Default().PrepareQuery(raw) + if contains(d, "'x'") { + t.Errorf("default PrepareQuery should redact: %q", d) + } + if fp == "" { + t.Error("fingerprint empty") + } + d2, _ := Default().WithRawQuery().PrepareQuery(raw) + if d2 != raw { + t.Errorf("raw PrepareQuery = %q, want %q", d2, raw) + } +} diff --git a/analyzer/redact_test.go b/analyzer/redact_test.go new file mode 100644 index 0000000..ef3b35f --- /dev/null +++ b/analyzer/redact_test.go @@ -0,0 +1,115 @@ +package analyzer + +import "testing" + +func TestRedact(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"string literal", `SELECT * FROM users WHERE email = 'alice@acme.com'`, + `SELECT * FROM users WHERE email = ?`}, + {"numeric literal", `SELECT * FROM t WHERE id = 42 AND age > 18`, + `SELECT * FROM t WHERE id = ? AND age > ?`}, + {"float and exponent", `SELECT * FROM t WHERE x = 3.14 AND y = 1e10`, + `SELECT * FROM t WHERE x = ? AND y = ?`}, + {"identifier with digits kept", `SELECT col1, int8_v FROM t1 WHERE a2 = 5`, + `SELECT col1, int8_v FROM t1 WHERE a2 = ?`}, + {"bind placeholders kept", `SELECT * FROM t WHERE a = $1 AND b = @p2`, + `SELECT * FROM t WHERE a = $1 AND b = @p2`}, + {"quoted identifier preserved", `SELECT "weird;col" FROM t WHERE n = 'x'`, + `SELECT "weird;col" FROM t WHERE n = ?`}, + {"backtick identifier preserved", "SELECT `from` FROM t WHERE n = 'x'", + "SELECT `from` FROM t WHERE n = ?"}, + {"escaped quote in literal", `SELECT * FROM t WHERE s = 'O''Brien'`, + `SELECT * FROM t WHERE s = ?`}, + {"comment stripped", "SELECT a -- secret 'tok'\nFROM t WHERE id = 9", + "SELECT a \nFROM t WHERE id = ?"}, + {"semicolon inside literal not structural", `SELECT * FROM t WHERE s = 'a;b'`, + `SELECT * FROM t WHERE s = ?`}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := Redact(c.in); got != c.want { + t.Errorf("Redact(%q)\n got: %q\nwant: %q", c.in, got, c.want) + } + }) + } +} + +func TestRedactNoPII(t *testing.T) { + pii := []string{"alice@acme.com", "123-45-6789", "4111111111111111", "secret"} + q := `SELECT * FROM users WHERE email='alice@acme.com' AND ssn='123-45-6789' + AND card='4111111111111111' /* secret */ LIMIT 10` + got := Redact(q) + for _, p := range pii { + if contains(got, p) { + t.Errorf("Redact leaked %q: %q", p, got) + } + } +} + +func TestFingerprint(t *testing.T) { + cases := []struct{ name, in, want string }{ + {"collapse whitespace", "SELECT *\n FROM t WHERE id = 1", + "SELECT * FROM t WHERE id = ?"}, + {"fold IN list", `SELECT * FROM t WHERE id IN (1, 2, 3, 4)`, + `SELECT * FROM t WHERE id IN (?)`}, + {"fold VALUES tuple", `INSERT INTO t VALUES ('a', 'b', 'c')`, + `INSERT INTO t VALUES (?)`}, + {"trailing semicolon trimmed", `SELECT 1;`, `SELECT ?`}, + {"differing literals same fp", + `SELECT * FROM t WHERE name = 'bob' AND age = 7`, + `SELECT * FROM t WHERE name = ? AND age = ?`}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := Fingerprint(c.in); got != c.want { + t.Errorf("Fingerprint(%q)\n got: %q\nwant: %q", c.in, got, c.want) + } + }) + } + + // Stability: queries differing only in values/list length share a fp. + a := Fingerprint(`SELECT * FROM t WHERE id IN (1,2,3) AND s = 'x'`) + b := Fingerprint(`SELECT * FROM t WHERE id IN (9,8) AND s = 'zzzzz'`) + if a != b { + t.Errorf("fingerprints should match:\n a=%q\n b=%q", a, b) + } +} + +func TestIsMultiStatement(t *testing.T) { + cases := []struct { + name string + in string + want bool + }{ + {"single", `SELECT * FROM t WHERE id = 1`, false}, + {"trailing semicolon", `SELECT * FROM t;`, false}, + {"trailing semicolon + ws", "SELECT 1; \n\t", false}, + {"stacked", `SELECT 1; DROP TABLE users`, true}, + {"stacked no space", `SELECT 1;DELETE FROM t`, true}, + {"semicolon in line comment", "SELECT 1 -- a; b\n", false}, + {"semicolon in block comment", `SELECT 1 /* a ; b */`, false}, + {"semicolon in string literal", `SELECT * FROM t WHERE s = 'a; DROP'`, false}, + {"comment hides stacking attempt", "SELECT 1 -- ;\nfrom t", false}, + {"real stack after string", `SELECT 'a;b'; DELETE FROM t`, true}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := IsMultiStatement(c.in); got != c.want { + t.Errorf("IsMultiStatement(%q) = %v, want %v", c.in, got, c.want) + } + }) + } +} + +func contains(s, sub string) bool { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/analyzer/registry.go b/analyzer/registry.go new file mode 100644 index 0000000..12fe055 --- /dev/null +++ b/analyzer/registry.go @@ -0,0 +1,151 @@ +package analyzer + +import ( + "sort" + "sync" + "time" +) + +// Settings holds rule-specific configuration as a generic key/value map so +// new tunables can be added without changing this type or the config schema. +// Accessors are nil-safe and fall back to the provided default, so a rule +// can always be constructed even with no settings supplied. +type Settings map[string]any + +// Int returns the setting as an int, or def if missing or not numeric. +// YAML decodes integers as int and JSON as float64, so both are accepted. +func (s Settings) Int(key string, def int) int { + if s == nil { + return def + } + switch v := s[key].(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + default: + return def + } +} + +// Bool returns the setting as a bool, or def if missing or not a bool. +func (s Settings) Bool(key string, def bool) bool { + if s == nil { + return def + } + if v, ok := s[key].(bool); ok { + return v + } + return def +} + +// String returns the setting as a string, or def if missing or not a string. +func (s Settings) String(key, def string) string { + if s == nil { + return def + } + if v, ok := s[key].(string); ok { + return v + } + return def +} + +// Duration returns the setting parsed as a time.Duration. It accepts a +// duration string ("200ms") or a number interpreted as milliseconds. Returns +// def if missing or unparseable. +func (s Settings) Duration(key string, def time.Duration) time.Duration { + if s == nil { + return def + } + switch v := s[key].(type) { + case string: + if d, err := time.ParseDuration(v); err == nil { + return d + } + case int: + return time.Duration(v) * time.Millisecond + case int64: + return time.Duration(v) * time.Millisecond + case float64: + return time.Duration(v) * time.Millisecond + } + return def +} + +// RuleSpec describes a built-in rule: its stable name (used in config, +// suppressions and reports), its default severity, and a factory that builds +// the rule from its settings. Keeping construction behind a factory is what +// makes per-rule settings work uniformly for every present and future rule. +type RuleSpec struct { + Name string + DefaultSeverity Severity + Factory func(Settings) Rule +} + +var ( + registryMu sync.RWMutex + registry = map[string]RuleSpec{} +) + +// Register adds a rule to the global registry. Built-in rules call this from +// init(); third-party rules may call it too. A duplicate name overwrites the +// previous spec, so a custom rule can replace a built-in one by name. +func Register(spec RuleSpec) { + registryMu.Lock() + defer registryMu.Unlock() + registry[spec.Name] = spec +} + +// RuleNames returns all registered rule names, sorted. Used by the config +// loader to validate rule references and by tooling to list rules. +func RuleNames() []string { + registryMu.RLock() + names := make([]string, 0, len(registry)) + for n := range registry { + names = append(names, n) + } + registryMu.RUnlock() + sort.Strings(names) + return names +} + +// specs returns all registered specs sorted by name, for deterministic +// analyzer construction and stable report ordering. +func specs() []RuleSpec { + registryMu.RLock() + out := make([]RuleSpec, 0, len(registry)) + for _, s := range registry { + out = append(out, s) + } + registryMu.RUnlock() + sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name }) + return out +} + +// Profile is the resolved, parser-independent view of configuration applied +// to an Analyzer at construction time. The config package builds it from +// .sqlguard.yml; analyzer never imports config or YAML. All maps are keyed +// by rule name. Resolution happens once here, never on the per-query path. +type Profile struct { + // Disabled rules are not constructed or run. + Disabled map[string]bool + // Only, when non-empty, is a whitelist: only these rules run. + Only map[string]bool + // Severity overrides a rule's reported severity. + Severity map[string]Severity + // Settings holds per-rule tunables. + Settings map[string]Settings + // RawQuery, when true, disables Result.Query redaction (literals are + // left in the reported SQL). Default (false) redacts — see + // Analyzer.WithRawQuery. + RawQuery bool +} + +func (p Profile) skip(name string) bool { + if len(p.Only) > 0 && !p.Only[name] { + return true + } + return p.Disabled[name] +} diff --git a/analyzer/result.go b/analyzer/result.go new file mode 100644 index 0000000..30a8645 --- /dev/null +++ b/analyzer/result.go @@ -0,0 +1,19 @@ +package analyzer + +// Result represents a single finding from query analysis. +type Result struct { + RuleName string + Severity Severity + // Query is the offending SQL as surfaced to reporters. By default it is + // redacted (string/numeric literals replaced with "?") so literal values + // never reach a log sink; an Analyzer built WithRawQuery leaves it raw. + Query string + // Fingerprint is the redacted, whitespace-collapsed, list-folded query + // identity (see analyzer.Fingerprint). It is always set, never carries + // PII, and is safe as a metric label or log key. + Fingerprint string + Message string + Suggestion string + File string // populated only in static analysis mode + Line int // populated only in static analysis mode +} diff --git a/analyzer/rules.go b/analyzer/rules.go new file mode 100644 index 0000000..872df4e --- /dev/null +++ b/analyzer/rules.go @@ -0,0 +1,269 @@ +package analyzer + +import "fmt" + +// Built-in rules self-register so they are addressable by name for config +// (enable/disable, severity overrides, settings) and suppressions. Adding a +// new rule is just another Register call here — no other plumbing changes. +func init() { + Register(RuleSpec{Name: "select-star", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckSelectStar }}) + Register(RuleSpec{Name: "leading-wildcard", DefaultSeverity: SeverityWarning, + Factory: func(s Settings) Rule { return leadingWildcardRule(s.Int("min-length", 0)) }}) + Register(RuleSpec{Name: "delete-without-where", DefaultSeverity: SeverityCritical, + Factory: func(Settings) Rule { return CheckDeleteWithoutWhere }}) + Register(RuleSpec{Name: "update-without-where", DefaultSeverity: SeverityCritical, + Factory: func(Settings) Rule { return CheckUpdateWithoutWhere }}) + Register(RuleSpec{Name: "insert-without-columns", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckInsertWithoutColumns }}) + Register(RuleSpec{Name: "select-without-limit", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckSelectWithoutLimit }}) + Register(RuleSpec{Name: "orderby-without-limit", DefaultSeverity: SeverityInfo, + Factory: func(Settings) Rule { return CheckOrderByWithoutLimit }}) + Register(RuleSpec{Name: "non-sargable-predicate", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckNonSargablePredicate }}) + Register(RuleSpec{Name: "add-not-null-without-default", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckAddNotNullWithoutDefault }}) + Register(RuleSpec{Name: "implicit-join", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckImplicitJoin }}) + Register(RuleSpec{Name: "cartesian-join", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckCartesianJoin }}) + Register(RuleSpec{Name: "in-list-too-large", DefaultSeverity: SeverityWarning, + Factory: func(s Settings) Rule { return inListRule(s.Int("max-length", 100)) }}) + Register(RuleSpec{Name: "large-offset", DefaultSeverity: SeverityWarning, + Factory: func(s Settings) Rule { return largeOffsetRule(s.Int("threshold", 1000)) }}) + Register(RuleSpec{Name: "select-distinct", DefaultSeverity: SeverityInfo, + Factory: func(Settings) Rule { return CheckSelectDistinct }}) +} + +// CheckSelectStar detects SELECT * usage. +func CheckSelectStar(s *Statement) (Result, bool) { + if s.SelectStar { + return Result{ + RuleName: "select-star", + Query: s.Raw, + Message: "SELECT * detected. Selecting all columns can hurt performance.", + Suggestion: "Select only the columns you need.", + }, true + } + return Result{}, false +} + +// CheckLeadingWildcard detects LIKE patterns with leading wildcards, using +// the rule's default settings (no minimum term length). +func CheckLeadingWildcard(s *Statement) (Result, bool) { + return leadingWildcardRule(0)(s) +} + +// leadingWildcardRule builds the leading-wildcard rule. When minLen > 0, a +// leading-wildcard LIKE is flagged only if its searchable term is at least +// minLen characters long, so short patterns like LIKE '%x%' can be tolerated. +// A statement whose term length is unknown (0, e.g. produced by a real +// parser that did not compute it) is still flagged, to avoid false negatives. +func leadingWildcardRule(minLen int) Rule { + return func(s *Statement) (Result, bool) { + if !s.LeadingWildcardLike { + return Result{}, false + } + if minLen > 0 && s.LeadingWildcardTermLen > 0 && s.LeadingWildcardTermLen < minLen { + return Result{}, false + } + return Result{ + RuleName: "leading-wildcard", + Query: s.Raw, + Message: "LIKE with leading wildcard detected. Index cannot be used.", + Suggestion: "Use prefix search or a full-text index.", + }, true + } +} + +// CheckDeleteWithoutWhere detects DELETE statements without a WHERE clause. +func CheckDeleteWithoutWhere(s *Statement) (Result, bool) { + if s.Kind == StmtDelete && !s.HasWhere { + return Result{ + RuleName: "delete-without-where", + Query: s.Raw, + Message: "DELETE without WHERE clause detected. This will delete all rows.", + Suggestion: "Add a WHERE clause to limit the scope of the delete.", + }, true + } + return Result{}, false +} + +// CheckUpdateWithoutWhere detects UPDATE statements without a WHERE clause. +func CheckUpdateWithoutWhere(s *Statement) (Result, bool) { + if s.Kind == StmtUpdate && !s.HasWhere { + return Result{ + RuleName: "update-without-where", + Query: s.Raw, + Message: "UPDATE without WHERE clause detected. This will update all rows.", + Suggestion: "Add a WHERE clause to limit the scope of the update.", + }, true + } + return Result{}, false +} + +// CheckInsertWithoutColumns detects INSERT statements without an explicit +// column list. +func CheckInsertWithoutColumns(s *Statement) (Result, bool) { + if s.Kind == StmtInsert && !s.InsertColumnsListed { + return Result{ + RuleName: "insert-without-columns", + Query: s.Raw, + Message: "INSERT without explicit column list. This breaks if table schema changes.", + Suggestion: "Specify columns explicitly: INSERT INTO table (col1, col2) VALUES (...).", + }, true + } + return Result{}, false +} + +// CheckSelectWithoutLimit detects SELECT statements without a LIMIT clause. +// Only flags queries that have a FROM clause (to skip SELECT 1, SELECT +// version(), etc.) and don't have WHERE, to reduce noise. +func CheckSelectWithoutLimit(s *Statement) (Result, bool) { + if s.Kind == StmtSelect && s.HasFrom && !s.HasLimit && !s.HasWhere { + return Result{ + RuleName: "select-without-limit", + Query: s.Raw, + Message: "SELECT without LIMIT or WHERE clause. May return excessive rows.", + Suggestion: "Add a LIMIT clause or WHERE filter to restrict results.", + }, true + } + return Result{}, false +} + +// CheckNonSargablePredicate detects a function or cast applied to a column on +// the column side of a WHERE comparison (e.g. WHERE LOWER(email) = ...), which +// prevents an ordinary index on that column from being used. +func CheckNonSargablePredicate(s *Statement) (Result, bool) { + if s.NonSargablePredicate { + return Result{ + RuleName: "non-sargable-predicate", + Query: s.Raw, + Message: "Function applied to a column in WHERE prevents index use.", + Suggestion: "Compare the bare column instead, or add a matching expression/function index.", + }, true + } + return Result{}, false +} + +// CheckAddNotNullWithoutDefault detects an ALTER TABLE that adds a NOT NULL +// column with no DEFAULT, which errors or forces a full table rewrite on a +// populated table. +func CheckAddNotNullWithoutDefault(s *Statement) (Result, bool) { + if s.AddNotNullNoDefault { + return Result{ + RuleName: "add-not-null-without-default", + Query: s.Raw, + Message: "ADD COLUMN ... NOT NULL without DEFAULT fails or rewrites the table on a populated table.", + Suggestion: "Add a DEFAULT, or split into: add the column nullable, backfill, then SET NOT NULL.", + }, true + } + return Result{}, false +} + +// CheckInListTooLarge detects an IN (...) value list with more elements than +// the default threshold (100). Use the registry / config to tune max-length. +func CheckInListTooLarge(s *Statement) (Result, bool) { + return inListRule(100)(s) +} + +// inListRule builds the in-list-too-large rule. It flags a statement whose +// largest IN (...) value list has more than maxLen elements. A maxLen of 0 +// flags any value-list IN; subquery INs are never counted (MaxInListLen +// excludes them). +func inListRule(maxLen int) Rule { + return func(s *Statement) (Result, bool) { + if s.MaxInListLen <= maxLen { + return Result{}, false + } + return Result{ + RuleName: "in-list-too-large", + Query: s.Raw, + Message: fmt.Sprintf("IN list has %d elements (threshold %d). Large IN lists hurt query planning.", s.MaxInListLen, maxLen), + Suggestion: "Use a JOIN against a temp table / VALUES list, or a parameterized array such as = ANY($1).", + }, true + } +} + +// CheckSelectDistinct detects a select-level DISTINCT, which is often added to +// hide duplicate rows produced by an unintended join fan-out rather than to +// express a genuine need for distinct results. INFO by default. +func CheckSelectDistinct(s *Statement) (Result, bool) { + if s.SelectDistinct { + return Result{ + RuleName: "select-distinct", + Query: s.Raw, + Message: "SELECT DISTINCT detected. It often masks duplicate rows from an unintended join.", + Suggestion: "Confirm the duplicates aren't a join fan-out; prefer fixing the join or using EXISTS/GROUP BY.", + }, true + } + return Result{}, false +} + +// CheckLargeOffset detects a literal OFFSET larger than the default threshold +// (1000). Use the registry / config to tune threshold. +func CheckLargeOffset(s *Statement) (Result, bool) { + return largeOffsetRule(1000)(s) +} + +// largeOffsetRule builds the large-offset rule. It flags a statement whose +// literal OFFSET exceeds threshold — deep pagination, where the database scans +// and discards every skipped row. Parameterized offsets (OffsetValue == 0) are +// never flagged. +func largeOffsetRule(threshold int) Rule { + return func(s *Statement) (Result, bool) { + if s.OffsetValue <= threshold { + return Result{}, false + } + return Result{ + RuleName: "large-offset", + Query: s.Raw, + Message: fmt.Sprintf("OFFSET %d exceeds %d. Deep pagination scans and discards all skipped rows.", s.OffsetValue, threshold), + Suggestion: "Use keyset (cursor) pagination: WHERE id > $last ORDER BY id LIMIT n.", + }, true + } +} + +// CheckCartesianJoin detects a multi-table FROM with no join condition and no +// WHERE filter — an unconditioned cartesian product (incl. CROSS JOIN). +func CheckCartesianJoin(s *Statement) (Result, bool) { + if s.CartesianJoin { + return Result{ + RuleName: "cartesian-join", + Query: s.Raw, + Message: "Cartesian product: multiple tables joined with no join condition or WHERE filter.", + Suggestion: "Add a JOIN ... ON condition (or a WHERE clause relating the tables).", + }, true + } + return Result{}, false +} + +// CheckImplicitJoin detects a FROM clause that joins tables with commas +// (FROM a, b) instead of explicit JOIN syntax — error-prone because a +// forgotten join condition silently yields a cartesian product. +func CheckImplicitJoin(s *Statement) (Result, bool) { + if s.ImplicitCommaJoin { + return Result{ + RuleName: "implicit-join", + Query: s.Raw, + Message: "Implicit comma join in FROM. A missing join condition silently becomes a cartesian product.", + Suggestion: "Use explicit JOIN ... ON syntax.", + }, true + } + return Result{}, false +} + +// CheckOrderByWithoutLimit detects ORDER BY without LIMIT, which sorts the +// entire result set. +func CheckOrderByWithoutLimit(s *Statement) (Result, bool) { + if s.HasOrderBy && !s.HasLimit { + return Result{ + RuleName: "orderby-without-limit", + Query: s.Raw, + Message: "ORDER BY without LIMIT sorts the entire result set.", + Suggestion: "Add a LIMIT clause if you only need a subset of rows.", + }, true + } + return Result{}, false +} diff --git a/analyzer/severity.go b/analyzer/severity.go new file mode 100644 index 0000000..cde1e06 --- /dev/null +++ b/analyzer/severity.go @@ -0,0 +1,26 @@ +package analyzer + +// Severity represents the importance level of an analysis finding. +type Severity int + +const ( + // SeverityInfo is an advisory finding worth noting but not necessarily acting on. + SeverityInfo Severity = iota + // SeverityWarning is a likely problem that should be reviewed. + SeverityWarning + // SeverityCritical is a serious problem likely to cause incorrect or destructive behavior. + SeverityCritical +) + +func (s Severity) String() string { + switch s { + case SeverityInfo: + return "INFO" + case SeverityWarning: + return "WARNING" + case SeverityCritical: + return "CRITICAL" + default: + return "UNKNOWN" + } +} diff --git a/analyzer/statement.go b/analyzer/statement.go new file mode 100644 index 0000000..3b54401 --- /dev/null +++ b/analyzer/statement.go @@ -0,0 +1,138 @@ +package analyzer + +// StmtKind is the top-level kind of a SQL statement. +type StmtKind int + +const ( + // StmtUnknown means the parser could not determine the statement kind. + StmtUnknown StmtKind = iota + // StmtSelect is a SELECT (or WITH ... SELECT) query. + StmtSelect + // StmtInsert is an INSERT statement. + StmtInsert + // StmtUpdate is an UPDATE statement. + StmtUpdate + // StmtDelete is a DELETE statement. + StmtDelete + // StmtOther is a recognized statement that none of the rules target + // (DDL, transaction control, etc.). + StmtOther +) + +// Statement is sqlguard's normalized, dialect-agnostic view of a single SQL +// statement. It carries only the semantic facts the rules need — not a full +// AST. Every Parser (the zero-dependency fallback and the optional real +// dialect parsers) populates this same struct, so rules never depend on a +// particular parser or dialect. +// +// Boolean fields are best-effort: a fallback-produced Statement may leave a +// field false when it genuinely cannot tell. Rules must treat "false" as +// "not detected", never as "proven absent", to avoid false positives. +type Statement struct { + // Raw is the original, untouched SQL string. Reported back to users. + Raw string + + // Kind is the statement's top-level kind. + Kind StmtKind + + // HasWhere reports whether the statement has a WHERE clause. + HasWhere bool + + // HasLimit reports whether the statement has a LIMIT clause. + HasLimit bool + + // HasOrderBy reports whether the statement has an ORDER BY clause. + HasOrderBy bool + + // HasFrom reports whether a SELECT has a FROM clause. Distinguishes + // "SELECT * FROM t" from "SELECT 1" / "SELECT version()". + HasFrom bool + + // SelectStar reports an unqualified "SELECT *" / "SELECT t.*" of columns. + // It is false for aggregate forms like COUNT(*). + SelectStar bool + + // SelectDistinct reports a select-level DISTINCT (SELECT DISTINCT ..., + // incl. Postgres DISTINCT ON and MySQL DISTINCTROW). It is false for an + // aggregate-level DISTINCT such as COUNT(DISTINCT col), which is unrelated. + // The dialect parsers compute it from the AST; the fallback approximates it + // lexically. + SelectDistinct bool + + // InsertColumnsListed reports whether an INSERT names its target columns + // explicitly: INSERT INTO t (a, b) VALUES (...). Only meaningful when + // Kind == StmtInsert. + InsertColumnsListed bool + + // LeadingWildcardLike reports a LIKE pattern beginning with a wildcard + // (e.g. LIKE '%foo'), which prevents index use. + LeadingWildcardLike bool + + // NonSargablePredicate reports a function or cast applied to a column on + // the column side of a WHERE comparison (e.g. WHERE LOWER(email) = ...), + // which prevents the use of an ordinary index on that column. Like the + // LIKE fields, this is a literal/text-level heuristic the real parsers' + // ASTs discard, so it is computed by the fallback lexer and preserved by + // the dialect parsers rather than recomputed structurally. + NonSargablePredicate bool + + // AddNotNullNoDefault reports an ALTER TABLE that adds a NOT NULL column + // with no DEFAULT (e.g. ALTER TABLE t ADD COLUMN c int NOT NULL), which + // fails or forces a table rewrite on a populated table. Like the other + // text-level fields above, it is computed by the fallback lexer and + // preserved by the dialect parsers. + AddNotNullNoDefault bool + + // ImplicitCommaJoin reports a FROM clause that lists multiple tables + // separated by top-level commas (FROM a, b) instead of explicit JOIN + // syntax — the old-style join that silently produces a cartesian product + // when its join condition is forgotten. Computed by the fallback lexer and + // preserved (not recomputed from the AST) by the dialect parsers, so it + // stays a best-effort heuristic even when Exact is true. + ImplicitCommaJoin bool + + // CartesianJoin reports a multi-table FROM (comma join, CROSS JOIN, or a + // bare JOIN) with no join condition (ON/USING/NATURAL) and no top-level + // WHERE filter — an unconditioned cartesian product. It is the high- + // confidence subset of ImplicitCommaJoin and also covers CROSS/bare JOIN. + // Like ImplicitCommaJoin, it is a fallback-lexer heuristic preserved by the + // dialect parsers, so it stays best-effort even when Exact is true. + CartesianJoin bool + + // MaxInListLen is the largest element count among the statement's IN (...) + // value lists (IN (SELECT ...) subqueries are excluded). It powers the + // in-list-too-large rule's max-length threshold. Zero means no value-list + // IN was found. Like the other counts, rules read it, never raw SQL. It is a + // fallback-lexer heuristic preserved by the dialect parsers (the AST discards + // the literal list it counts), so it stays best-effort even when Exact is true. + MaxInListLen int + + // OffsetValue is the largest literal OFFSET seen (standard OFFSET n or + // MySQL's LIMIT offset, count), powering the large-offset rule. Zero means + // no offset, OFFSET 0, or a parameterized offset (OFFSET $1 / ?), which + // cannot be evaluated statically and is therefore never flagged. The dialect + // parsers read it from the AST's limit clause; the fallback scans for it. + OffsetValue int + + // LeadingWildcardTermLen is the length of the longest searchable term + // (the literal with surrounding % wildcards trimmed) across all + // leading-wildcard LIKE patterns in the statement. It powers the + // leading-wildcard rule's min-length setting. Zero means "unknown" + // (e.g. produced by a real parser that did not compute it); rules must + // treat zero as unknown and not as "short", to avoid false negatives. + LeadingWildcardTermLen int + + // Exact is true when the Statement was produced by a real SQL parser + // (structural analysis), false when produced by the regex fallback + // (best-effort). Rules may use this to suppress lower-confidence findings. + // + // "Exact" covers the structural facts the dialect parsers derive from the + // AST: Kind, HasWhere/HasLimit/HasOrderBy/HasFrom, SelectStar, + // SelectDistinct, OffsetValue, and InsertColumnsListed. A few facts stay + // lexical heuristics even when Exact is true — MaxInListLen, + // ImplicitCommaJoin, CartesianJoin, and the literal/text-level fields + // (LeadingWildcard*, NonSargablePredicate, AddNotNullNoDefault) — because + // they read literal values the AST discards or are intentionally text-level. + // Each such field documents this. + Exact bool +} diff --git a/analyzer/suppress.go b/analyzer/suppress.go new file mode 100644 index 0000000..5502116 --- /dev/null +++ b/analyzer/suppress.go @@ -0,0 +1,69 @@ +package analyzer + +import ( + "regexp" + "strings" +) + +// ignoreDirectiveRe matches a sqlguard:ignore directive inside a SQL or Go +// comment. The leading comment marker (--, /*, #, //) anchors it so the +// token is honored only in comment context, not when the literal text +// happens to appear inside a string. An optional `:rule-a, rule-b` list +// scopes the suppression to specific rules; without it, all rules are +// suppressed for the statement. +var ignoreDirectiveRe = regexp.MustCompile(`(?i)(?:--|/\*|#|//)[^\n]*?sqlguard:ignore(?::\s*([a-z0-9_,\s-]+))?`) + +// ignoreTokenRe matches the bare directive in text that is already known to +// be a comment (e.g. go/ast comment text with the marker stripped). No +// comment marker is required here because the whole string is comment +// context. +var ignoreTokenRe = regexp.MustCompile(`(?i)sqlguard:ignore(?::\s*([a-z0-9_,\s-]+))?`) + +// parseIgnoreDirective scans raw SQL for `sqlguard:ignore` directives. +// It returns ignoreAll=true if any directive has no rule list, otherwise a +// set of rule names to suppress. The result is empty when no directive is +// present, so the common path allocates nothing. +func parseIgnoreDirective(sql string) (ignoreAll bool, ignored map[string]bool) { + if !strings.Contains(strings.ToLower(sql), "sqlguard:ignore") { + return false, nil + } + for _, m := range ignoreDirectiveRe.FindAllStringSubmatch(sql, -1) { + list := strings.TrimSpace(m[1]) + if list == "" { + return true, nil + } + if ignored == nil { + ignored = make(map[string]bool) + } + for name := range strings.SplitSeq(list, ",") { + if name = strings.TrimSpace(name); name != "" { + ignored[name] = true + } + } + } + return false, ignored +} + +// ParseIgnoreComment parses the text of a single comment for a +// sqlguard:ignore directive. It is used by the static scanner to honor +// `// sqlguard:ignore` / `// sqlguard:ignore:rule-a,rule-b` annotations in Go +// source. found reports whether a directive was present; all is true for a +// bare directive (suppress every rule); rules holds the named rules +// otherwise. +func ParseIgnoreComment(text string) (all bool, rules map[string]bool, found bool) { + m := ignoreTokenRe.FindStringSubmatch(text) + if m == nil { + return false, nil, false + } + list := strings.TrimSpace(m[1]) + if list == "" { + return true, nil, true + } + rules = make(map[string]bool) + for name := range strings.SplitSeq(list, ",") { + if name = strings.TrimSpace(name); name != "" { + rules[name] = true + } + } + return false, rules, true +} diff --git a/cmd/sqlguard/db.go b/cmd/sqlguard/db.go new file mode 100644 index 0000000..649a0ef --- /dev/null +++ b/cmd/sqlguard/db.go @@ -0,0 +1,31 @@ +package main + +import ( + "database/sql" + "fmt" +) + +// openDB opens a database connection using the appropriate driver. +func openDB(dialect, dsn string) (*sql.DB, error) { + var driverName string + switch dialect { + case "postgres": + driverName = "postgres" + case "mysql": + driverName = "mysql" + default: + return nil, fmt.Errorf("unsupported dialect: %s", dialect) + } + + db, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + + if err := db.Ping(); err != nil { + _ = db.Close() + return nil, fmt.Errorf("cannot reach database: %w", err) + } + + return db, nil +} diff --git a/cmd/sqlguard/explain.go b/cmd/sqlguard/explain.go new file mode 100644 index 0000000..8379b73 --- /dev/null +++ b/cmd/sqlguard/explain.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/KARTIKrocks/sqlguard/explain" + "github.com/KARTIKrocks/sqlguard/reporter" + "github.com/spf13/cobra" +) + +var ( + explainDSN string + explainDialect string + explainFormat string + explainAllowDML bool +) + +var explainCmd = &cobra.Command{ + Use: `explain "SQL QUERY"`, + Short: "Run EXPLAIN on a query against a live database", + Long: "Connects to a database and runs EXPLAIN to detect performance issues like sequential scans and missing indexes.", + Args: cobra.ExactArgs(1), + RunE: runExplain, +} + +func init() { + explainCmd.Flags().StringVar(&explainDSN, "db", "", "Database connection string (required)") + explainCmd.Flags().StringVar(&explainDialect, "dialect", "postgres", "Database dialect: postgres or mysql") + explainCmd.Flags().StringVar(&explainFormat, "format", "console", "Output format: console or json") + explainCmd.Flags().BoolVar(&explainAllowDML, "allow-dml", false, "Allow EXPLAIN on INSERT/UPDATE/DELETE (still run in an always-rolled-back transaction); refused by default") + _ = explainCmd.MarkFlagRequired("db") +} + +func runExplain(cmd *cobra.Command, args []string) error { + // Args are valid past this point; don't dump usage for runtime errors or + // the errIssuesFound sentinel. (Arg-parse errors still show usage.) + cmd.SilenceUsage = true + + query := args[0] + + db, err := openDB(explainDialect, explainDSN) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + defer func() { _ = db.Close() }() + + var explainOpts []explain.Option + if explainAllowDML { + explainOpts = append(explainOpts, explain.WithAllowDML()) + } + analyzer, err := explain.New(db, explainDialect, explainOpts...) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := analyzer.Analyze(ctx, query) + if err != nil { + return err + } + + var rep reporter.Reporter + switch explainFormat { + case "json": + rep = reporter.NewJSONReporter() + default: + rep = reporter.NewConsoleReporter() + } + + if len(result.Issues) > 0 { + rep.Report(result.Issues) + if explainFormat != "json" { + fmt.Fprintf(os.Stderr, "\n%d issue(s) found in query plan\n", len(result.Issues)) + } + return errIssuesFound + } + + if explainFormat != "json" { + fmt.Fprintln(os.Stderr, "No issues found in query plan") + } + return nil +} diff --git a/cmd/sqlguard/main.go b/cmd/sqlguard/main.go new file mode 100644 index 0000000..927a276 --- /dev/null +++ b/cmd/sqlguard/main.go @@ -0,0 +1,18 @@ +package main + +import ( + "errors" + "fmt" + "os" +) + +func main() { + if err := rootCmd.Execute(); err != nil { + // If issues were found, exit with code 1 silently (already reported). + if errors.Is(err, errIssuesFound) { + os.Exit(1) + } + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/cmd/sqlguard/resolve_test.go b/cmd/sqlguard/resolve_test.go new file mode 100644 index 0000000..749010a --- /dev/null +++ b/cmd/sqlguard/resolve_test.go @@ -0,0 +1,178 @@ +package main + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +// createModule turns dir into a loadable Go module so the type-aware +// (go/packages) scan path runs instead of the AST fallback. +func createModule(t *testing.T, dir, modPath string) { + t.Helper() + createTestFile(t, dir, "go.mod", "module "+modPath+"\n\ngo 1.26\n") +} + +func createFileInSubdir(t *testing.T, dir, rel, content string) { + t.Helper() + full := filepath.Join(dir, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(full, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func TestScan_ResolvesSamePackageConst(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const userQuery = "SELECT * FROM users WHERE id = 1" + +func f(db *sql.DB) { + db.Query(userQuery) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from a resolved const, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from resolved const, got:\n%s", out) + } +} + +func TestScan_ResolvesConstConcatenation(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const ( + cols = "*" + q = "SELECT " + cols + " FROM users WHERE id = 1" +) + +func f(db *sql.DB) { + db.Query(q) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from folded concatenation, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from concatenated const, got:\n%s", out) + } +} + +func TestScan_ResolvesCrossPackageConst(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createFileInSubdir(t, dir, "queries/queries.go", `package queries + +const GetUser = "SELECT * FROM users WHERE id = 1" +`) + createTestFile(t, dir, "main.go", `package example +import ( + "database/sql" + + "example.com/m/queries" +) + +func f(db *sql.DB) { + db.Query(queries.GetUser) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from cross-package const, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from cross-package const, got:\n%s", out) + } +} + +func TestScan_ResolvesSprintfFormat(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import ( + "database/sql" + "fmt" +) + +func f(db *sql.DB, table string) { + db.Query(fmt.Sprintf("SELECT * FROM %s WHERE id = %d", table, 1)) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from Sprintf format string, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from Sprintf format, got:\n%s", out) + } +} + +// A safe query held in a constant must stay clean — proves resolution does not +// introduce false positives. +func TestScan_ResolvedConstNoFalsePositive(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const safe = "SELECT id, name FROM users WHERE id = ? LIMIT 10" + +func f(db *sql.DB) { + db.Query(safe, 1) +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Fatalf("expected clean exit for safe resolved const, got %v\n%s", err, out) + } + if strings.Contains(out, "SQLGUARD") { + t.Errorf("expected no findings for safe const, got:\n%s", out) + } +} + +// Inline suppression must still apply when the query is a resolved const. +func TestScan_SuppressionWithResolvedConst(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const userQuery = "SELECT * FROM users WHERE id = 1" + +func f(db *sql.DB) { + db.Query(userQuery) // sqlguard:ignore:select-star +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Fatalf("expected clean exit, finding suppressed, got %v\n%s", err, out) + } + if strings.Contains(out, "select-star") { + t.Errorf("inline directive should suppress finding on resolved const, got:\n%s", out) + } +} diff --git a/cmd/sqlguard/root.go b/cmd/sqlguard/root.go new file mode 100644 index 0000000..ff44072 --- /dev/null +++ b/cmd/sqlguard/root.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "os" + + "github.com/KARTIKrocks/sqlguard/config" + "github.com/spf13/cobra" +) + +var ( + configPathFlag string + noConfigFlag bool +) + +var rootCmd = &cobra.Command{ + Use: "sqlguard", + Short: "Production-safe SQL query analyzer for Go applications", + Long: "sqlguard detects slow queries, dangerous SQL patterns, and performance issues in Go applications.", + // main() owns error printing and exit codes. Without this, cobra prints + // "Error: issues found" for the errIssuesFound sentinel, which is a normal + // outcome (issues were already reported), not a CLI error. + SilenceErrors: true, +} + +func init() { + rootCmd.PersistentFlags().StringVar(&configPathFlag, "config", "", "path to .sqlguard.yml (default: auto-discover)") + rootCmd.PersistentFlags().BoolVar(&noConfigFlag, "no-config", false, "ignore any .sqlguard.yml and use built-in defaults") + rootCmd.AddCommand(scanCmd) + rootCmd.AddCommand(explainCmd) +} + +// resolveConfig loads configuration honoring --config / --no-config, falling +// back to discovery from startDir. Warnings are printed to stderr; a load +// error is returned to abort the command. +func resolveConfig(startDir string) (*config.Config, error) { + switch { + case noConfigFlag: + return config.Default(), nil + case configPathFlag != "": + return config.Load(configPathFlag) + default: + c, path, err := config.Discover(startDir) + if err != nil { + return nil, err + } + if path != "" { + _, _ = fmt.Fprintf(os.Stderr, "Using config %s\n", path) + } + return c, nil + } +} + +func printConfigWarnings(c *config.Config) { + for _, w := range c.Warnings() { + _, _ = fmt.Fprintf(os.Stderr, "sqlguard: config warning: %s\n", w) + } +} diff --git a/cmd/sqlguard/scan.go b/cmd/sqlguard/scan.go new file mode 100644 index 0000000..23f94ef --- /dev/null +++ b/cmd/sqlguard/scan.go @@ -0,0 +1,398 @@ +package main + +import ( + "errors" + "fmt" + "go/ast" + "go/constant" + "go/parser" + "go/token" + "go/types" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/reporter" + "github.com/spf13/cobra" + "golang.org/x/tools/go/packages" +) + +// SQL method names we look for on any receiver. +var sqlMethods = map[string]bool{ + "Query": true, + "QueryContext": true, + "QueryRow": true, + "QueryRowContext": true, + "Exec": true, + "ExecContext": true, + "Prepare": true, + "PrepareContext": true, +} + +var formatFlag string + +var scanCmd = &cobra.Command{ + Use: "scan [path]", + Short: "Scan Go source files for SQL query issues", + Long: "Statically analyzes Go source files to find SQL queries and check them for common issues.", + Args: cobra.MaximumNArgs(1), + RunE: runScan, +} + +func init() { + scanCmd.Flags().StringVar(&formatFlag, "format", "console", "Output format: console or json") +} + +// errIssuesFound is returned when the scan finds issues, to signal a non-zero exit code. +var errIssuesFound = errors.New("issues found") + +func runScan(cmd *cobra.Command, args []string) error { + // Args are valid past this point; don't dump usage for runtime errors or + // the errIssuesFound sentinel. (Arg-parse errors still show usage.) + cmd.SilenceUsage = true + + dir := "." + if len(args) > 0 { + dir = args[0] + } + + rep, err := newReporter(formatFlag) + if err != nil { + return err + } + + cfg, err := resolveConfig(dir) + if err != nil { + return err + } + a, err := cfg.Analyzer() + if err != nil { + return err + } + printConfigWarnings(cfg) + exclude, err := cfg.ExcludeMatcher() + if err != nil { + return err + } + + allResults, totalFiles, err := scanDir(dir, a, exclude) + if err != nil { + return fmt.Errorf("scan failed: %w", err) + } + + if len(allResults) > 0 { + rep.Report(allResults) + if formatFlag != "json" { + _, _ = fmt.Fprintf(os.Stderr, "\n%d issue(s) found (%d file(s) scanned)\n", len(allResults), totalFiles) + } + return errIssuesFound + } + + if formatFlag != "json" { + _, _ = fmt.Fprintf(os.Stderr, "No issues found (%d file(s) scanned)\n", totalFiles) + } + return nil +} + +func newReporter(format string) (reporter.Reporter, error) { + switch format { + case "json": + return reporter.NewJSONReporter(), nil + case "console", "": + return reporter.NewConsoleReporter(), nil + default: + return nil, fmt.Errorf("unknown format %q: use 'console' or 'json'", format) + } +} + +// scanDir type-checks the target with golang.org/x/tools/go/packages so query +// arguments that are constants, cross-package constants, constant +// concatenations, or fmt.Sprintf literal format strings all resolve. If the +// target is not a loadable module (no go.mod, ad-hoc files), it degrades to a +// dependency-free go/parser walk that still handles inline string literals, so +// a broken or module-less tree is never silently skipped. +func scanDir(dir string, a *analyzer.Analyzer, exclude func(string) bool) ([]analyzer.Result, int, error) { + absDir, _ := filepath.Abs(dir) + + if results, n, ok := scanViaPackages(absDir, a, exclude); ok { + return results, n, nil + } + results, n, err := scanViaAST(dir, absDir, a, exclude) + return results, n, err +} + +// scanViaPackages is the primary, type-aware path. ok is false when the target +// cannot be loaded as a module at all (caller then falls back to the AST walk); +// individual packages with type errors are still scanned, degrading per-file to +// literal-only resolution. +func scanViaPackages(absDir string, a *analyzer.Analyzer, exclude func(string) bool) (results []analyzer.Result, totalFiles int, ok bool) { + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | + packages.NeedImports | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, + Dir: absDir, + Tests: false, + } + pkgs, err := packages.Load(cfg, "./...") + if err != nil || len(pkgs) == 0 { + return nil, 0, false + } + + seen := map[string]struct{}{} + scannedAny := false + for _, pkg := range pkgs { + if len(pkg.Syntax) == 0 { + continue + } + scannedAny = true + // Degraded package (type errors): TypesInfo may be partial or nil; + // constString falls back to *ast.BasicLit when info lacks the value. + info := pkg.TypesInfo + for _, file := range pkg.Syntax { + path := pkg.Fset.Position(file.Pos()).Filename + if !keepFile(path, absDir, exclude) { + continue + } + if _, dup := seen[path]; dup { + continue + } + seen[path] = struct{}{} + totalFiles++ + results = append(results, scanASTFile(pkg.Fset, file, info, a)...) + } + } + if !scannedAny { + return nil, 0, false + } + return results, totalFiles, true +} + +// scanViaAST is the dependency-free fallback for module-less / unbuildable +// trees: parse each file in isolation and resolve only inline string literals +// (info is nil, so scanASTFile degrades accordingly). +func scanViaAST(dir, absDir string, a *analyzer.Analyzer, exclude func(string) bool) ([]analyzer.Result, int, error) { + fset := token.NewFileSet() + var results []analyzer.Result + totalFiles := 0 + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return shouldSkipDir(path, absDir) + } + if !keepFile(path, absDir, exclude) { + return nil + } + f, perr := parser.ParseFile(fset, path, nil, parser.ParseComments) + if perr != nil { + return nil + } + totalFiles++ + results = append(results, scanASTFile(fset, f, nil, a)...) + return nil + }) + + return results, totalFiles, err +} + +// keepFile reports whether a .go file should be analyzed: skip non-Go and +// _test.go files, then apply the configured exclude matcher against the path +// relative to the scan root (so regexes behave identically whether the path +// came from go list (absolute) or the walk (relative)). +func keepFile(path, absDir string, exclude func(string) bool) bool { + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return false + } + if exclude != nil { + rel := path + if abs, err := filepath.Abs(path); err == nil { + if r, rerr := filepath.Rel(absDir, abs); rerr == nil { + rel = r + } + } + if exclude(filepath.ToSlash(rel)) { + return false + } + } + return true +} + +func shouldSkipDir(path, absDir string) error { + absPath, _ := filepath.Abs(path) + if absPath != absDir { + base := filepath.Base(path) + if strings.HasPrefix(base, ".") || base == "vendor" || base == "node_modules" { + return filepath.SkipDir + } + } + return nil +} + +// scanASTFile walks one parsed file for SQL-method calls and resolves each +// query argument via resolveQuery. info may be nil (fallback / degraded +// package), in which case resolution is limited to inline string literals. +func scanASTFile(fset *token.FileSet, f *ast.File, info *types.Info, a *analyzer.Analyzer) []analyzer.Result { + suppress := buildSuppressor(fset, f) + + var results []analyzer.Result + ast.Inspect(f, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || !sqlMethods[sel.Sel.Name] { + return true + } + + arg := queryArgExpr(sel.Sel.Name, call.Args) + if arg == nil { + return true + } + query := resolveQuery(info, arg) + if query == "" { + return true + } + + found := a.Analyze(query) + pos := fset.Position(call.Pos()) + all, rules := suppress(pos.Line) + for _, r := range found { + if all || rules[r.RuleName] { + continue + } + r.File = pos.Filename + r.Line = pos.Line + results = append(results, r) + } + return true + }) + + return results +} + +// buildSuppressor returns a lookup that, for a given source line, reports +// whether a `// sqlguard:ignore` directive applies — either trailing on that +// line or on the line directly above the call. This is the static-analysis +// counterpart to the in-SQL directive the analyzer handles at runtime. +func buildSuppressor(fset *token.FileSet, f *ast.File) func(line int) (bool, map[string]bool) { + type directive struct { + all bool + rules map[string]bool + } + byLine := map[int]directive{} + for _, cg := range f.Comments { + all, rules, found := analyzer.ParseIgnoreComment(cg.Text()) + if !found { + continue + } + end := fset.Position(cg.End()).Line + // Apply to the comment's own line (trailing) and the next line + // (comment sitting directly above the call). + byLine[end] = directive{all, rules} + byLine[end+1] = directive{all, rules} + } + return func(line int) (bool, map[string]bool) { + d, ok := byLine[line] + if !ok { + return false, nil + } + return d.all, d.rules + } +} + +// queryArgExpr returns the expression holding the SQL string for a given SQL +// method (the first arg, or the second for *Context variants). +func queryArgExpr(methodName string, args []ast.Expr) ast.Expr { + argIdx := 0 + if strings.HasSuffix(methodName, "Context") { + argIdx = 1 + } + if argIdx >= len(args) { + return nil + } + return args[argIdx] +} + +// resolveQuery turns a query-argument expression into SQL text. The single +// go/constant lookup in constString already covers inline literals, +// same-package constants, cross-package constants, and constant concatenation +// (the type checker folded them). fmt.Sprintf with a constant format string is +// resolved by neutralizing its verbs so the SQL stays structurally analyzable. +func resolveQuery(info *types.Info, e ast.Expr) string { + if s, ok := constString(info, e); ok { + return s + } + if ce, ok := e.(*ast.CallExpr); ok { + if fa, ok := sprintfFormatArg(info, ce); ok { + if f, ok := constString(info, fa); ok { + return neutralizeFormat(f) + } + } + } + return "" +} + +// constString resolves any constant string-valued expression. With type info +// this is one map lookup that the compiler already folded; without it (nil +// info or value absent) it degrades to a raw string literal. +func constString(info *types.Info, e ast.Expr) (string, bool) { + if info != nil { + if tv, ok := info.Types[e]; ok && tv.Value != nil && tv.Value.Kind() == constant.String { + return constant.StringVal(tv.Value), true + } + } + if bl, ok := e.(*ast.BasicLit); ok && bl.Kind == token.STRING { + if s, err := strconv.Unquote(bl.Value); err == nil { + return s, true + } + } + return "", false +} + +// sprintfFormatArg returns the format-string argument if ce is a call to +// fmt.Sprintf. With type info the callee is verified to be package "fmt"; +// without it, a conservative `fmt.Sprintf` selector-name heuristic is used. +func sprintfFormatArg(info *types.Info, ce *ast.CallExpr) (ast.Expr, bool) { + sel, ok := ce.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "Sprintf" || len(ce.Args) == 0 { + return nil, false + } + if info != nil { + if obj := info.Uses[sel.Sel]; obj != nil { + fn, ok := obj.(*types.Func) + if !ok || fn.Pkg() == nil || fn.Pkg().Path() != "fmt" { + return nil, false + } + return ce.Args[0], true + } + } + if id, ok := sel.X.(*ast.Ident); ok && id.Name == "fmt" { + return ce.Args[0], true + } + return nil, false +} + +var formatVerb = regexp.MustCompile(`%[-+# 0]*[\d.*]*[a-zA-Z%]`) + +// neutralizeFormat replaces fmt verbs in a constant format string with benign +// placeholders so the remaining SQL keeps its structure for the rule engine. +// Numeric verbs become 0; everything else becomes a harmless identifier; %% +// collapses to a literal %. +func neutralizeFormat(format string) string { + return formatVerb.ReplaceAllStringFunc(format, func(v string) string { + switch v[len(v)-1] { + case '%': + return "%" + case 'b', 'c', 'd', 'o', 'O', 'x', 'X', 'U', 'e', 'E', 'f', 'F', 'g', 'G', 'p': + return "0" + default: + return "sqlguard" + } + }) +} diff --git a/cmd/sqlguard/scan_test.go b/cmd/sqlguard/scan_test.go new file mode 100644 index 0000000..7cbafc8 --- /dev/null +++ b/cmd/sqlguard/scan_test.go @@ -0,0 +1,399 @@ +package main + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +func createTestFile(t *testing.T, dir, name, content string) { + t.Helper() + err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0644) + if err != nil { + t.Fatalf("failed to create test file: %v", err) + } +} + +func TestScan_DetectsSelectStar(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit (errIssuesFound)") + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star warning, got:\n%s", out) + } +} + +func TestScan_DetectsDeleteWithoutWhere(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Exec("DELETE FROM users") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "delete-without-where") { + t.Errorf("expected delete-without-where warning, got:\n%s", out) + } + if !strings.Contains(out, "CRITICAL") { + t.Errorf("expected CRITICAL severity, got:\n%s", out) + } +} + +func TestScan_DetectsLeadingWildcard(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id FROM users WHERE name LIKE '%test%'") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "leading-wildcard") { + t.Errorf("expected leading-wildcard warning, got:\n%s", out) + } +} + +func TestScan_DetectsUpdateWithoutWhere(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Exec("UPDATE users SET name = 'test'") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "update-without-where") { + t.Errorf("expected update-without-where warning, got:\n%s", out) + } +} + +func TestScan_DetectsInsertWithoutColumns(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Exec("INSERT INTO users VALUES ('alice', 'alice@test.com')") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "insert-without-columns") { + t.Errorf("expected insert-without-columns warning, got:\n%s", out) + } +} + +func TestScan_DetectsSelectWithoutLimit(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id, name FROM users") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "select-without-limit") { + t.Errorf("expected select-without-limit warning, got:\n%s", out) + } +} + +func TestScan_DetectsOrderByWithoutLimit(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id FROM users WHERE active = true ORDER BY name") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "orderby-without-limit") { + t.Errorf("expected orderby-without-limit warning, got:\n%s", out) + } +} + +func TestScan_NoWarningsForSafeQuery(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "good.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id, name FROM users WHERE id = ? LIMIT 10", 1) +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected nil error for safe query, got: %v", err) + } + if strings.Contains(out, "SQLGUARD") { + t.Errorf("expected no warnings for safe query, got:\n%s", out) + } + if !strings.Contains(out, "No issues found (") { + t.Errorf("expected 'No issues found' message, got:\n%s", out) + } +} + +func TestScan_SkipsTestFiles(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad_test.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if strings.Contains(out, "select-star") { + t.Errorf("should skip _test.go files, got:\n%s", out) + } +} + +func TestScan_SkipsVendorDir(t *testing.T) { + vendorDir := filepath.Join(t.TempDir(), "vendor") + os.MkdirAll(vendorDir, 0755) + createTestFile(t, vendorDir, "bad.go", `package vendor +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, filepath.Dir(vendorDir)) + + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if strings.Contains(out, "select-star") { + t.Errorf("should skip vendor directory, got:\n%s", out) + } +} + +func TestScan_HandlesContextMethods(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "ctx.go", `package example +import ( + "context" + "database/sql" +) +func f(db *sql.DB) { + db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star for QueryContext, got:\n%s", out) + } +} + +func TestScan_MultipleIssuesInOneFile(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "multi.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE email LIKE '%test%'") + db.Exec("DELETE FROM orders") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "select-star") { + t.Error("expected select-star warning") + } + if !strings.Contains(out, "leading-wildcard") { + t.Error("expected leading-wildcard warning") + } + if !strings.Contains(out, "delete-without-where") { + t.Error("expected delete-without-where warning") + } +} + +// boundDir creates a temp dir with a .git marker so config.Discover does not +// escape it while walking parents. +func boundDir(t *testing.T) string { + t.Helper() + dir := t.TempDir() + if err := os.Mkdir(filepath.Join(dir, ".git"), 0o755); err != nil { + t.Fatal(err) + } + return dir +} + +func TestScan_ConfigDisablesRule(t *testing.T) { + dir := boundDir(t) + createTestFile(t, dir, ".sqlguard.yml", "rules:\n disable: [select-star]\n") + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected clean exit when rule disabled by config, got %v\n%s", err, out) + } + if strings.Contains(out, "select-star") { + t.Errorf("select-star should be disabled via .sqlguard.yml, got:\n%s", out) + } +} + +func TestScan_InlineSuppressionComment(t *testing.T) { + dir := boundDir(t) + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + // sqlguard:ignore + db.Exec("DELETE FROM users") + db.Query("SELECT * FROM users WHERE id = 1") // sqlguard:ignore:select-star +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected clean exit, all findings suppressed, got %v\n%s", err, out) + } + if strings.Contains(out, "delete-without-where") || strings.Contains(out, "select-star") { + t.Errorf("inline directives should suppress findings, got:\n%s", out) + } +} + +func TestScan_ExitCodeZeroWhenClean(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "clean.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id, name FROM users WHERE id = ? LIMIT 10", 1) +} +`) + + _, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected exit code 0 for clean code, got error: %v", err) + } +} + +// captureScanOutput runs the scan command and captures stderr output. +// Returns the output and the error (errIssuesFound if issues were found). +func captureScanOutput(t *testing.T, dir string) (string, error) { + t.Helper() + + // Reset format flag to default for each test + formatFlag = "console" + + old := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + err := runScan(&cobra.Command{}, []string{dir}) + + w.Close() + os.Stderr = old + + var buf bytes.Buffer + buf.ReadFrom(r) + + // Only fail on unexpected errors, not errIssuesFound + if err != nil && !errors.Is(err, errIssuesFound) { + t.Fatalf("scan failed unexpectedly: %v", err) + } + + return buf.String(), err +} + +// TestScanCommand_NoUsageDumpOnIssues runs the real command tree (rootCmd.Execute) +// and asserts that a normal "issues found" outcome does NOT print cobra's usage +// text. Regression guard for the SilenceErrors/SilenceUsage wiring: without it, +// returning errIssuesFound from RunE makes cobra dump "Error: issues found" +// followed by the full usage, which looks like a CLI misuse. +func TestScanCommand_NoUsageDumpOnIssues(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", + "package bad\n\nimport \"database/sql\"\n\nfunc r(d *sql.DB) { d.Exec(\"DELETE FROM x\") }\n") + + formatFlag = "console" + noConfigFlag = true + t.Cleanup(func() { noConfigFlag = false }) + + old := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + rootCmd.SetArgs([]string{"scan", "--no-config", dir}) + err := rootCmd.Execute() + + w.Close() + os.Stderr = old + var buf bytes.Buffer + _, _ = buf.ReadFrom(r) + out := buf.String() + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected errIssuesFound, got %v", err) + } + if strings.Contains(out, "Usage:") { + t.Errorf("scan dumped usage text on an issues-found result:\n%s", out) + } + if !strings.Contains(out, "delete-without-where") { + t.Errorf("expected the finding in output, got:\n%s", out) + } +} diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..cce2abc --- /dev/null +++ b/codecov.yml @@ -0,0 +1,31 @@ +# Codecov configuration for sqlguard. +# Docs: https://docs.codecov.com/docs/codecov-yaml +# Coverage is produced by `make coverage` (merged across all nine modules) and +# uploaded from the CI "coverage" job. + +codecov: + require_ci_to_pass: true + +coverage: + precision: 2 + round: down + range: "70...100" + status: + project: + default: + target: auto + threshold: 5% + patch: + default: + target: auto + threshold: 5% + +comment: + layout: "reach,diff,flags,files" + behavior: default + require_changes: true + +ignore: + - "examples/**" + - "**/*_test.go" + - "cmd/sqlguard/main.go" diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..4319da1 --- /dev/null +++ b/config/config.go @@ -0,0 +1,317 @@ +// Package config loads and applies .sqlguard.yml configuration. +// +// It is the only package that depends on a YAML library. Importing +// sqlguard/analyzer or sqlguard/middleware does NOT pull YAML in; only code +// that opts into file-based configuration through this package does. The +// analyzer stays parser- and config-agnostic: config translates a Config +// into an analyzer.Profile, which the analyzer applies once at construction. +package config + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "gopkg.in/yaml.v3" +) + +// ConfigFileNames are the file names Discover looks for, in order. +var ConfigFileNames = []string{".sqlguard.yml", ".sqlguard.yaml"} + +// Config mirrors the .sqlguard.yml schema. The Version field is reserved for +// forward compatibility: older binaries reading a newer config degrade with +// warnings rather than failing, unless Strict is set. +type Config struct { + Version int `yaml:"version"` + Strict bool `yaml:"strict"` + Rules RulesConfig `yaml:"rules"` + SlowQuery SlowQueryConfig `yaml:"slow-query"` + Dedup DedupConfig `yaml:"dedup"` + Scan ScanConfig `yaml:"scan"` + // Redact controls Result.Query literal redaction. Pointer so an unset + // key means "use the safe default" (redact). Set `redact: false` only + // when the query text is trusted (local debugging). + Redact *bool `yaml:"redact"` + + warnings []string +} + +// RulesConfig configures which rules run, their severity, and per-rule +// settings. +type RulesConfig struct { + // Disable turns off the named rules. + Disable []string `yaml:"disable"` + // Only, when non-empty, is a whitelist: only these rules run. + Only []string `yaml:"only"` + // Severity overrides per rule: info | warning | critical | off + // ("off" is equivalent to disabling the rule). + Severity map[string]string `yaml:"severity"` + // Settings holds per-rule tunables, e.g. leading-wildcard.min-length. + Settings map[string]map[string]any `yaml:"settings"` +} + +// SlowQueryConfig configures the middleware slow-query threshold. +type SlowQueryConfig struct { + // Threshold is a Go duration string, e.g. "200ms". + Threshold string `yaml:"threshold"` +} + +// DedupConfig configures runtime suppression of repeated static findings. +type DedupConfig struct { + // Window is a Go duration string, e.g. "1m". The same finding (rule + + // query fingerprint) is reported at most once per window. "0" disables + // dedup (report every occurrence). Unset keeps the middleware default. + Window string `yaml:"window"` +} + +// ScanConfig holds settings that apply only to the static scanner. +type ScanConfig struct { + // ExcludePaths is a list of regular expressions matched against scanned + // file paths; matching files are skipped. + ExcludePaths []string `yaml:"exclude-paths"` +} + +// Default returns an empty configuration: every rule enabled at its default +// severity and settings. Used when no .sqlguard.yml is found. +func Default() *Config { return &Config{Version: 1} } + +// Load reads and parses the config at path. Parsing is lenient by default so +// a config written for a newer sqlguard still loads on an older binary; +// unknown top-level keys become warnings. If the file sets `strict: true`, +// unknown keys are a hard error instead. +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("sqlguard config: %w", err) + } + + var c Config + if err := yaml.Unmarshal(data, &c); err != nil { + return nil, fmt.Errorf("sqlguard config %s: %w", path, err) + } + + // Detect unknown fields with a second strict decode. yaml.v3 surfaces the + // first unknown field as an error; we treat it as fatal only in strict + // mode, otherwise as a warning so forward-compatible configs still work. + if strictErr := strictDecode(data); strictErr != nil { + if c.Strict { + return nil, fmt.Errorf("sqlguard config %s (strict): %w", path, strictErr) + } + c.warnings = append(c.warnings, strictErr.Error()) + } + return &c, nil +} + +func strictDecode(data []byte) error { + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + var probe Config + if err := dec.Decode(&probe); err != nil && !errors.Is(err, io.EOF) { + return err + } + return nil +} + +// Discover walks startDir and its parents looking for a config file. It stops +// at a directory containing a .git entry (project root) after checking that +// directory, or at the filesystem root. It returns Default() and an empty +// path when no config file is found. +func Discover(startDir string) (cfg *Config, path string, err error) { + dir, err := filepath.Abs(startDir) + if err != nil { + return nil, "", err + } + for { + for _, name := range ConfigFileNames { + p := filepath.Join(dir, name) + if st, statErr := os.Stat(p); statErr == nil && !st.IsDir() { + c, loadErr := Load(p) + return c, p, loadErr + } + } + if isProjectRoot(dir) { + break + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return Default(), "", nil +} + +func isProjectRoot(dir string) bool { + _, err := os.Stat(filepath.Join(dir, ".git")) + return err == nil +} + +// Warnings returns non-fatal issues collected while loading or resolving the +// config (unknown keys in lenient mode, unknown rule names, bad severities). +// Callers should surface these to the user. +func (c *Config) Warnings() []string { return c.warnings } + +// Profile resolves the config into an analyzer.Profile. Unknown rule names +// and unparseable severities are warnings (or errors if Strict). A severity +// of "off" disables the rule. The returned Profile is ready to pass to +// analyzer.DefaultWithProfile. +func (c *Config) Profile() (analyzer.Profile, error) { + known := make(map[string]bool) + for _, n := range analyzer.RuleNames() { + known[n] = true + } + + p := analyzer.Profile{ + Disabled: map[string]bool{}, + Only: map[string]bool{}, + Severity: map[string]analyzer.Severity{}, + Settings: map[string]analyzer.Settings{}, + RawQuery: c.rawQuery(), + } + + warn := func(format string, args ...any) error { + msg := fmt.Sprintf(format, args...) + if c.Strict { + return errors.New(msg) + } + c.warnings = append(c.warnings, msg) + return nil + } + + checkName := func(name string) error { + if !known[name] { + return warn("unknown rule %q (known: %s)", name, strings.Join(analyzer.RuleNames(), ", ")) + } + return nil + } + + for _, name := range c.Rules.Disable { + if err := checkName(name); err != nil { + return p, err + } + p.Disabled[name] = true + } + for _, name := range c.Rules.Only { + if err := checkName(name); err != nil { + return p, err + } + p.Only[name] = true + } + for name, sevStr := range c.Rules.Severity { + if err := checkName(name); err != nil { + return p, err + } + sev, off, ok := parseSeverity(sevStr) + if !ok { + if err := warn("rule %q: invalid severity %q", name, sevStr); err != nil { + return p, err + } + continue + } + if off { + p.Disabled[name] = true + continue + } + p.Severity[name] = sev + } + for name, kv := range c.Rules.Settings { + if err := checkName(name); err != nil { + return p, err + } + p.Settings[name] = analyzer.Settings(kv) + } + return p, nil +} + +// rawQuery reports whether Result.Query redaction is disabled. Redaction is +// the default (PII-safe); only an explicit `redact: false` turns it off. +func (c *Config) rawQuery() bool { return c.Redact != nil && !*c.Redact } + +// Analyzer is a convenience that builds an analyzer from the config's +// Profile using the fallback parser. Callers wanting a real dialect parser +// should take the Profile and combine with analyzer.DefaultWithProfile + +// WithParser themselves. +func (c *Config) Analyzer() (*analyzer.Analyzer, error) { + p, err := c.Profile() + if err != nil { + return nil, err + } + return analyzer.DefaultWithProfile(p), nil +} + +// SlowQueryThreshold returns the configured slow-query threshold. ok is false +// when unset, in which case the caller keeps its own default. +func (c *Config) SlowQueryThreshold() (d time.Duration, ok bool, err error) { + s := strings.TrimSpace(c.SlowQuery.Threshold) + if s == "" { + return 0, false, nil + } + d, err = time.ParseDuration(s) + if err != nil { + return 0, false, fmt.Errorf("sqlguard config: slow-query.threshold %q: %w", s, err) + } + return d, true, nil +} + +// DedupWindow returns the configured static-finding dedup window. ok is false +// when unset, in which case the middleware keeps its own default. A configured +// "0" returns ok=true with d=0, which disables dedup (report every occurrence). +func (c *Config) DedupWindow() (d time.Duration, ok bool, err error) { + s := strings.TrimSpace(c.Dedup.Window) + if s == "" { + return 0, false, nil + } + d, err = time.ParseDuration(s) + if err != nil { + return 0, false, fmt.Errorf("sqlguard config: dedup.window %q: %w", s, err) + } + return d, true, nil +} + +// ExcludeMatcher compiles Scan.ExcludePaths into a single predicate. It +// returns a nil func (never excludes) when no patterns are configured. +func (c *Config) ExcludeMatcher() (func(path string) bool, error) { + if len(c.Scan.ExcludePaths) == 0 { + return nil, nil + } + res := make([]*regexp.Regexp, 0, len(c.Scan.ExcludePaths)) + for _, pat := range c.Scan.ExcludePaths { + re, err := regexp.Compile(pat) + if err != nil { + return nil, fmt.Errorf("sqlguard config: scan.exclude-paths %q: %w", pat, err) + } + res = append(res, re) + } + return func(path string) bool { + for _, re := range res { + if re.MatchString(path) { + return true + } + } + return false + }, nil +} + +// parseSeverity maps a config severity string to an analyzer.Severity. +// "off" / "none" / "disabled" report off=true (disable the rule). +func parseSeverity(s string) (sev analyzer.Severity, off bool, ok bool) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "info": + return analyzer.SeverityInfo, false, true + case "warning", "warn": + return analyzer.SeverityWarning, false, true + case "critical", "error": + return analyzer.SeverityCritical, false, true + case "off", "none", "disabled": + return 0, true, true + default: + return 0, false, false + } +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..42ecdda --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,197 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func writeConfig(t *testing.T, dir, body string) string { + t.Helper() + p := filepath.Join(dir, ".sqlguard.yml") + if err := os.WriteFile(p, []byte(body), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + return p +} + +func TestLoadAndProfile(t *testing.T) { + dir := t.TempDir() + p := writeConfig(t, dir, ` +version: 1 +rules: + disable: [orderby-without-limit] + severity: + select-star: info + select-without-limit: "off" + settings: + leading-wildcard: + min-length: 4 +slow-query: + threshold: 350ms +`) + c, err := Load(p) + if err != nil { + t.Fatalf("Load: %v", err) + } + + prof, err := c.Profile() + if err != nil { + t.Fatalf("Profile: %v", err) + } + if !prof.Disabled["orderby-without-limit"] { + t.Error("orderby-without-limit should be disabled") + } + if !prof.Disabled["select-without-limit"] { + t.Error(`severity "off" should disable select-without-limit`) + } + if prof.Severity["select-star"] != analyzer.SeverityInfo { + t.Errorf("select-star severity = %v, want INFO", prof.Severity["select-star"]) + } + if prof.Settings["leading-wildcard"].Int("min-length", 0) != 4 { + t.Error("min-length setting not carried into profile") + } + + d, ok, err := c.SlowQueryThreshold() + if err != nil || !ok || d != 350*time.Millisecond { + t.Errorf("SlowQueryThreshold = %v, %v, %v; want 350ms,true,nil", d, ok, err) + } + + // End-to-end: the built analyzer respects the profile. + a := analyzer.DefaultWithProfile(prof) + got := a.Analyze("SELECT * FROM users") + if len(got) != 1 || got[0].RuleName != "select-star" || got[0].Severity != analyzer.SeverityInfo { + t.Errorf("expected single INFO select-star, got %+v", got) + } +} + +func TestDedupWindow(t *testing.T) { + t.Run("set", func(t *testing.T) { + c := &Config{Dedup: DedupConfig{Window: "30s"}} + d, ok, err := c.DedupWindow() + if err != nil || !ok || d != 30*time.Second { + t.Errorf("DedupWindow = %v, %v, %v; want 30s,true,nil", d, ok, err) + } + }) + t.Run("unset keeps default", func(t *testing.T) { + c := &Config{} + if d, ok, err := c.DedupWindow(); err != nil || ok || d != 0 { + t.Errorf("DedupWindow = %v, %v, %v; want 0,false,nil", d, ok, err) + } + }) + t.Run("zero disables", func(t *testing.T) { + c := &Config{Dedup: DedupConfig{Window: "0"}} + if d, ok, err := c.DedupWindow(); err != nil || !ok || d != 0 { + t.Errorf("DedupWindow = %v, %v, %v; want 0,true,nil (explicit disable)", d, ok, err) + } + }) + t.Run("invalid errors", func(t *testing.T) { + c := &Config{Dedup: DedupConfig{Window: "soon"}} + if _, _, err := c.DedupWindow(); err == nil { + t.Error("expected error for invalid dedup.window") + } + }) +} + +func TestUnknownRuleLenientVsStrict(t *testing.T) { + dir := t.TempDir() + body := "rules:\n disable: [no-such-rule]\n" + + c, err := Load(writeConfig(t, dir, body)) + if err != nil { + t.Fatalf("Load: %v", err) + } + if _, err := c.Profile(); err != nil { + t.Fatalf("lenient Profile should not error: %v", err) + } + if len(c.Warnings()) == 0 { + t.Error("expected a warning for unknown rule in lenient mode") + } + + strict := &Config{Strict: true, Rules: RulesConfig{Disable: []string{"no-such-rule"}}} + if _, err := strict.Profile(); err == nil { + t.Error("expected error for unknown rule in strict mode") + } +} + +func TestUnknownKeyLenientWarnsStrictFails(t *testing.T) { + dir := t.TempDir() + + c, err := Load(writeConfig(t, dir, "bananas: true\n")) + if err != nil { + t.Fatalf("lenient load should succeed: %v", err) + } + if len(c.Warnings()) == 0 { + t.Error("expected warning for unknown top-level key") + } + + if _, err := Load(writeConfig(t, dir, "strict: true\nbananas: true\n")); err == nil { + t.Error("expected strict load to fail on unknown key") + } +} + +func TestDiscoverWalksUpAndStopsAtGitRoot(t *testing.T) { + root := t.TempDir() + if err := os.Mkdir(filepath.Join(root, ".git"), 0o755); err != nil { + t.Fatal(err) + } + writeConfig(t, root, "rules:\n disable: [select-star]\n") + deep := filepath.Join(root, "a", "b", "c") + if err := os.MkdirAll(deep, 0o755); err != nil { + t.Fatal(err) + } + + c, path, err := Discover(deep) + if err != nil { + t.Fatalf("Discover: %v", err) + } + if path == "" { + t.Fatal("expected to find config by walking up") + } + prof, _ := c.Profile() + if !prof.Disabled["select-star"] { + t.Error("discovered config not applied") + } +} + +func TestDiscoverNoConfigReturnsDefault(t *testing.T) { + dir := t.TempDir() + // .git marks the boundary so Discover does not escape the temp dir. + _ = os.Mkdir(filepath.Join(dir, ".git"), 0o755) + + c, path, err := Discover(dir) + if err != nil { + t.Fatalf("Discover: %v", err) + } + if path != "" { + t.Errorf("expected no config path, got %q", path) + } + if _, err := c.Profile(); err != nil { + t.Errorf("default profile should be valid: %v", err) + } +} + +func TestExcludeMatcher(t *testing.T) { + c := &Config{Scan: ScanConfig{ExcludePaths: []string{`(^|/)legacy/`, `_gen\.go$`}}} + m, err := c.ExcludeMatcher() + if err != nil { + t.Fatalf("ExcludeMatcher: %v", err) + } + if !m("pkg/legacy/old.go") || !m("api/types_gen.go") { + t.Error("expected matches for excluded paths") + } + if m("pkg/service/user.go") { + t.Error("did not expect match for normal path") + } + + none, err := (&Config{}).ExcludeMatcher() + if err != nil { + t.Errorf("no patterns should not error: %v", err) + } + if none != nil { + t.Error("no patterns should yield a nil matcher") + } +} diff --git a/config/middleware.go b/config/middleware.go new file mode 100644 index 0000000..439d91d --- /dev/null +++ b/config/middleware.go @@ -0,0 +1,61 @@ +package config + +import ( + "github.com/KARTIKrocks/sqlguard/middleware" +) + +// MiddlewareOptions translates this config into middleware options: an +// analyzer built from the rule Profile, and the slow-query threshold when +// configured. Combine with other middleware options as needed, e.g.: +// +// opts, _ := cfg.MiddlewareOptions() +// opts = append(opts, middleware.WithParser(pgparser.New())) +// sqlguard.Register("sqlguard-pg", "pgx", opts...) +// +// Keeping this in the config package (not middleware) keeps YAML out of the +// middleware import graph for users who do not use file configuration. +func (c *Config) MiddlewareOptions() ([]middleware.Option, error) { + a, err := c.Analyzer() + if err != nil { + return nil, err + } + opts := []middleware.Option{middleware.WithAnalyzer(a)} + + d, ok, err := c.SlowQueryThreshold() + if err != nil { + return nil, err + } + if ok { + opts = append(opts, middleware.WithSlowQueryThreshold(d)) + } + + dw, ok, err := c.DedupWindow() + if err != nil { + return nil, err + } + if ok { + opts = append(opts, middleware.WithFindingDedup(dw)) + } + return opts, nil +} + +// Middleware loads configuration and returns ready-to-use middleware +// options. If path is non-empty it is loaded directly; otherwise config is +// discovered by walking up from startDir (use "." for the working +// directory). A missing config is not an error — it yields options +// equivalent to the built-in defaults. +func Middleware(path, startDir string) ([]middleware.Option, error) { + var ( + c *Config + err error + ) + if path != "" { + c, err = Load(path) + } else { + c, _, err = Discover(startDir) + } + if err != nil { + return nil, err + } + return c.MiddlewareOptions() +} diff --git a/config/middleware_test.go b/config/middleware_test.go new file mode 100644 index 0000000..6caf642 --- /dev/null +++ b/config/middleware_test.go @@ -0,0 +1,50 @@ +package config + +import ( + "database/sql" + "path/filepath" + "strings" + "testing" + + "github.com/KARTIKrocks/sqlguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/KARTIKrocks/sqlguard/reporter" + + _ "github.com/mattn/go-sqlite3" +) + +func TestMiddlewareOptionsAppliesProfile(t *testing.T) { + dir := t.TempDir() + writeConfig(t, dir, "rules:\n disable: [select-star]\n") + + opts, err := Middleware("", dir) + if err != nil { + t.Fatalf("Middleware: %v", err) + } + + var buf strings.Builder + opts = append(opts, middleware.WithReporter(&reporter.ConsoleReporter{Out: &buf})) + + name := "sqlguard-cfg-test" + if err := sqlguard.Register(name, "sqlite3", opts...); err != nil { + t.Fatalf("Register: %v", err) + } + db, err := sql.Open(name, filepath.Join(dir, "t.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + if _, err := db.Exec("CREATE TABLE u (id INTEGER, name TEXT)"); err != nil { + t.Fatalf("create: %v", err) + } + + rows, err := db.Query("SELECT * FROM u WHERE id = 1") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if strings.Contains(buf.String(), "select-star") { + t.Errorf("select-star should be disabled via config, got:\n%s", buf.String()) + } +} diff --git a/explain/explain.go b/explain/explain.go new file mode 100644 index 0000000..86a0a00 --- /dev/null +++ b/explain/explain.go @@ -0,0 +1,295 @@ +// Package explain provides SQL EXPLAIN plan analysis. +// It connects to a live database to run EXPLAIN on queries and detect +// performance issues like sequential scans and high-cost operations. +package explain + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// PlanAnalyzer runs EXPLAIN on queries against a live database. +type PlanAnalyzer struct { + db *sql.DB + dialect string // "postgres" or "mysql" + allowDML bool +} + +// Option configures a PlanAnalyzer. +type Option func(*PlanAnalyzer) + +// WithAllowDML permits EXPLAIN on INSERT/UPDATE/DELETE statements. It is OFF +// by default: only SELECT/WITH are explained, because feeding DML to a prod +// database — even under plain EXPLAIN — is a footgun (and EXPLAIN ANALYZE +// would execute it). When enabled, DML EXPLAINs still run inside a +// transaction that is always rolled back (see analyzePostgres/analyzeMySQL), +// so nothing is committed regardless. +func WithAllowDML() Option { + return func(p *PlanAnalyzer) { p.allowDML = true } +} + +// New creates a PlanAnalyzer for the given database connection. +// dialect must be "postgres" or "mysql". +func New(db *sql.DB, dialect string, opts ...Option) (*PlanAnalyzer, error) { + if db == nil { + return nil, fmt.Errorf("explain: db is nil") + } + dialect = strings.ToLower(dialect) + if dialect != "postgres" && dialect != "mysql" { + return nil, fmt.Errorf("explain: unsupported dialect %q (use 'postgres' or 'mysql')", dialect) + } + p := &PlanAnalyzer{db: db, dialect: dialect} + for _, o := range opts { + o(p) + } + return p, nil +} + +// Result holds the parsed EXPLAIN output and any detected issues. +type Result struct { + Query string + RawPlan string + Issues []analyzer.Result +} + +// Analyze runs EXPLAIN on the given query and returns detected issues. The +// query is validated (see validate) and the EXPLAIN is run inside an +// always-rolled-back transaction, so a query passed here cannot mutate the +// target database. +func (p *PlanAnalyzer) Analyze(ctx context.Context, query string) (*Result, error) { + safe, err := p.validate(query) + if err != nil { + return nil, err + } + + var res *Result + switch p.dialect { + case "postgres": + res, err = p.analyzePostgres(ctx, safe) + case "mysql": + res, err = p.analyzeMySQL(ctx, safe) + default: + return nil, fmt.Errorf("explain: unsupported dialect %q", p.dialect) + } + if res != nil { + fp := analyzer.Fingerprint(query) + for i := range res.Issues { + res.Issues[i].Fingerprint = fp + } + } + return res, err +} + +// validate enforces the EXPLAIN safety policy and returns the single, +// terminator-stripped statement that is safe to concatenate into an EXPLAIN +// prefix. +// +// EXPLAIN cannot take bind parameters, so the query is necessarily +// string-concatenated; the defense is therefore strict input validation, not +// parameterization: +// +// - Reject empty input. +// - Reject multi-statement input using a comment- and string-literal-aware +// check (analyzer.IsMultiStatement). The previous +// strings.Contains(query, ";") check was defeated by a ";" inside a +// -- / /* */ comment or a string literal, and over-rejected a harmless +// trailing ";". +// - Classify the statement via the same parser the analyzer uses. Only +// SELECT/WITH are allowed by default; INSERT/UPDATE/DELETE require +// WithAllowDML; DDL/SET/other is always refused. +func (p *PlanAnalyzer) validate(query string) (string, error) { + q := strings.TrimSpace(query) + if q == "" { + return "", fmt.Errorf("explain: refusing to explain an empty query") + } + if analyzer.IsMultiStatement(q) { + return "", fmt.Errorf("explain: refusing to explain multi-statement input") + } + q = strings.TrimRight(q, "; \t\r\n") + + st, _ := analyzer.NewFallbackParser().Parse(q) + switch st.Kind { + case analyzer.StmtSelect: + return q, nil + case analyzer.StmtInsert, analyzer.StmtUpdate, analyzer.StmtDelete: + if !p.allowDML { + return "", fmt.Errorf("explain: refusing to EXPLAIN a data-modifying statement by default; construct the analyzer with explain.WithAllowDML to opt in") + } + return q, nil + default: + return "", fmt.Errorf("explain: refusing to explain a non-SELECT/WITH/DML statement (DDL, SET, transaction control, or unrecognized)") + } +} + +// PostgreSQL EXPLAIN JSON structures +type pgPlan struct { + Plan pgPlanNode `json:"Plan"` +} + +type pgPlanNode struct { + NodeType string `json:"Node Type"` + TotalCost float64 `json:"Total Cost"` + PlanRows int64 `json:"Plan Rows"` + Plans []pgPlanNode `json:"Plans"` +} + +func (p *PlanAnalyzer) analyzePostgres(ctx context.Context, query string) (*Result, error) { + // query is the validated, single, terminator-free statement from + // validate(). EXPLAIN takes no bind parameters, so concatenation is + // unavoidable; safety comes from validate() plus the rolled-back, + // read-only transaction below. We never use EXPLAIN ANALYZE, so the + // statement is planned, not executed. + explainQuery := "EXPLAIN (FORMAT JSON) " + query + + tx, err := p.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return nil, fmt.Errorf("explain: failed to begin read-only transaction: %w", err) + } + // Always roll back: an EXPLAIN must never commit anything. + defer func() { _ = tx.Rollback() }() + + var rawJSON string + if err := tx.QueryRowContext(ctx, explainQuery).Scan(&rawJSON); err != nil { + return nil, fmt.Errorf("explain: failed to run EXPLAIN: %w", err) + } + + result := &Result{ + Query: query, + RawPlan: rawJSON, + } + + var plans []pgPlan + if err := json.Unmarshal([]byte(rawJSON), &plans); err != nil { + return result, fmt.Errorf("explain: failed to parse EXPLAIN JSON: %w", err) + } + + if len(plans) > 0 { + p.walkPgPlan(&plans[0].Plan, query, &result.Issues) + } + + return result, nil +} + +func (p *PlanAnalyzer) walkPgPlan(node *pgPlanNode, query string, issues *[]analyzer.Result) { + if node == nil { + return + } + + // Detect sequential scans + if node.NodeType == "Seq Scan" { + severity := analyzer.SeverityInfo + if node.PlanRows > 1000 { + severity = analyzer.SeverityWarning + } + *issues = append(*issues, analyzer.Result{ + RuleName: "seq-scan", + Severity: severity, + Query: query, + Message: fmt.Sprintf("Sequential scan detected (estimated %d rows, cost %.1f)", node.PlanRows, node.TotalCost), + Suggestion: "Consider adding an index to avoid full table scan.", + }) + } + + // Detect high cost operations + if node.TotalCost > 10000 { + *issues = append(*issues, analyzer.Result{ + RuleName: "high-cost", + Severity: analyzer.SeverityWarning, + Query: query, + Message: fmt.Sprintf("High cost operation: %s (cost %.1f)", node.NodeType, node.TotalCost), + Suggestion: "Review query plan and consider optimization.", + }) + } + + // Recurse into child plans + for i := range node.Plans { + p.walkPgPlan(&node.Plans[i], query, issues) + } +} + +func (p *PlanAnalyzer) analyzeMySQL(ctx context.Context, query string) (*Result, error) { + // See analyzePostgres: validated single statement, no ANALYZE, run in an + // always-rolled-back read-only transaction so EXPLAIN cannot mutate data. + explainQuery := "EXPLAIN " + query + + tx, err := p.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return nil, fmt.Errorf("explain: failed to begin read-only transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + rows, err := tx.QueryContext(ctx, explainQuery) + if err != nil { + return nil, fmt.Errorf("explain: failed to run EXPLAIN: %w", err) + } + defer func() { _ = rows.Close() }() + + result := &Result{ + Query: query, + } + + for rows.Next() { + var ( + id int + selectType string + table sql.NullString + partitions sql.NullString + accessType sql.NullString + possibleKeys sql.NullString + key sql.NullString + keyLen sql.NullString + ref sql.NullString + rowCount sql.NullInt64 + filtered sql.NullFloat64 + extra sql.NullString + ) + + if err := rows.Scan(&id, &selectType, &table, &partitions, &accessType, &possibleKeys, &key, &keyLen, &ref, &rowCount, &filtered, &extra); err != nil { + return result, fmt.Errorf("explain: failed to scan row: %w", err) + } + + // Detect full table scans (type = ALL) + if accessType.Valid && accessType.String == "ALL" { + result.Issues = append(result.Issues, analyzer.Result{ + RuleName: "full-table-scan", + Severity: analyzer.SeverityWarning, + Query: query, + Message: fmt.Sprintf("Full table scan on %s (estimated %d rows)", table.String, rowCount.Int64), + Suggestion: "Consider adding an index to avoid full table scan.", + }) + } + + // Detect missing indexes + if (!key.Valid || key.String == "") && (!possibleKeys.Valid || possibleKeys.String == "") && table.Valid && table.String != "" { + result.Issues = append(result.Issues, analyzer.Result{ + RuleName: "no-index-used", + Severity: analyzer.SeverityWarning, + Query: query, + Message: fmt.Sprintf("No index used on table %s", table.String), + Suggestion: "Consider adding an index on the filtered/joined columns.", + }) + } + + // Detect filesort + if strings.Contains(extra.String, "Using filesort") { + result.Issues = append(result.Issues, analyzer.Result{ + RuleName: "filesort", + Severity: analyzer.SeverityInfo, + Query: query, + Message: fmt.Sprintf("Filesort detected on table %s", table.String), + Suggestion: "Consider adding an index that covers the ORDER BY columns.", + }) + } + } + + if err := rows.Err(); err != nil { + return result, fmt.Errorf("explain: error reading rows: %w", err) + } + + return result, nil +} diff --git a/explain/explain_test.go b/explain/explain_test.go new file mode 100644 index 0000000..d251bcb --- /dev/null +++ b/explain/explain_test.go @@ -0,0 +1,52 @@ +package explain + +import ( + "strings" + "testing" +) + +func TestValidate(t *testing.T) { + cases := []struct { + name string + query string + allowDML bool + wantErr string // substring; "" means no error + wantSafe string // expected returned statement when no error + }{ + {"select ok", `SELECT * FROM t WHERE id = 1`, false, "", `SELECT * FROM t WHERE id = 1`}, + {"with ok", `WITH c AS (SELECT 1) SELECT * FROM c`, false, "", `WITH c AS (SELECT 1) SELECT * FROM c`}, + {"trailing semicolon trimmed", `SELECT 1;`, false, "", `SELECT 1`}, + {"empty", ` `, false, "empty", ""}, + {"stacked statements", `SELECT 1; DROP TABLE users`, false, "multi-statement", ""}, + {"semicolon in comment is fine", "SELECT 1 -- ; DROP\n", false, "", "SELECT 1 -- ; DROP"}, + {"semicolon in string is fine", `SELECT * FROM t WHERE s = 'a;b'`, false, "", `SELECT * FROM t WHERE s = 'a;b'`}, + {"stack hidden after string", `SELECT 'a;b'; DELETE FROM t`, false, "multi-statement", ""}, + {"dml refused by default", `DELETE FROM t WHERE id = 1`, false, "data-modifying", ""}, + {"update refused by default", `UPDATE t SET a = 1`, false, "data-modifying", ""}, + {"dml allowed with opt-in", `DELETE FROM t WHERE id = 1`, true, "", `DELETE FROM t WHERE id = 1`}, + {"ddl always refused", `DROP TABLE users`, true, "non-SELECT", ""}, + {"set always refused", `SET search_path = x`, true, "non-SELECT", ""}, + {"truncate refused", `TRUNCATE t`, true, "non-SELECT", ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := &PlanAnalyzer{dialect: "postgres", allowDML: c.allowDML} + safe, err := p.validate(c.query) + if c.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if safe != c.wantSafe { + t.Errorf("safe = %q, want %q", safe, c.wantSafe) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", c.wantErr) + } + if !strings.Contains(err.Error(), c.wantErr) { + t.Errorf("error %q does not contain %q", err, c.wantErr) + } + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b178be2 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/KARTIKrocks/sqlguard + +go 1.26 + +require ( + github.com/mattn/go-sqlite3 v1.14.45 + github.com/spf13/cobra v1.10.2 + golang.org/x/tools v0.45.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + golang.org/x/mod v0.36.0 // indirect + golang.org/x/sync v0.20.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..37b285f --- /dev/null +++ b/go.sum @@ -0,0 +1,24 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= +golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= +golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/bunguard/bunguard.go b/integrations/bunguard/bunguard.go new file mode 100644 index 0000000..dc3a5bf --- /dev/null +++ b/integrations/bunguard/bunguard.go @@ -0,0 +1,75 @@ +// Package bunguard integrates sqlguard with bun (github.com/uptrace/bun). +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to the +// database/sql driver wrapper, pgxguard and gormguard. There is no parallel +// option surface — configure with the standard middleware options: +// +// sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) +// db := bun.NewDB(sqldb, pgdialect.New()) +// db.AddQueryHook(bunguard.New( +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// )) +// +// bun exposes the final rendered SQL and a start timestamp on the QueryEvent +// in its AfterQuery hook, so this uses the explicit Check+CheckLatency pair +// (matching gormguard) rather than middleware.Guard.Observe: static rules run +// on every query, latency is reported only on success. +package bunguard + +import ( + "context" + "time" + + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/uptrace/bun" +) + +// QueryHook implements bun.QueryHook and drives every traced statement +// through the shared sqlguard analysis core. +type QueryHook struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy bun.QueryHook. +var _ bun.QueryHook = (*QueryHook)(nil) + +// New creates a new sqlguard bun query hook. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set the database/sql +// driver wrapper, pgxguard and gormguard use, so there is no parallel +// configuration surface to drift. +func New(opts ...middleware.Option) *QueryHook { + return &QueryHook{g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to New. +func (h *QueryHook) ResetN1() { h.g.ResetN1() } + +// BeforeQuery implements bun.QueryHook. bun stamps event.StartTime itself +// before invoking the hook, so there is nothing to stash here. +func (h *QueryHook) BeforeQuery(ctx context.Context, _ *bun.QueryEvent) context.Context { + return ctx +} + +// AfterQuery implements bun.QueryHook. event.Query holds the rendered SQL. +func (h *QueryHook) AfterQuery(_ context.Context, event *bun.QueryEvent) { + sql := event.Query + if sql == "" { + return + } + + // Static rules + N+1 run on every call (matches Observe semantics). + h.g.Check(sql) + + // Latency is reported only on success — a failed query's duration is + // meaningless. This mirrors middleware.Guard.Observe. + if event.Err != nil { + return + } + h.g.CheckLatency(sql, time.Since(event.StartTime)) +} diff --git a/integrations/bunguard/bunguard_test.go b/integrations/bunguard/bunguard_test.go new file mode 100644 index 0000000..5b5fbd5 --- /dev/null +++ b/integrations/bunguard/bunguard_test.go @@ -0,0 +1,185 @@ +package bunguard + +import ( + "context" + "database/sql" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + _ "github.com/mattn/go-sqlite3" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +type user struct { + bun.BaseModel `bun:"table:users"` + ID int64 `bun:"id,pk"` + Email string `bun:"email"` +} + +// newDBWithCapture spins up an in-memory sqlite-backed *bun.DB with the +// sqlguard hook registered, so the integration runs end-to-end (QueryHook +// seam → driver round trip) rather than mocked. +func newDBWithCapture(t *testing.T, opts ...middleware.Option) (*bun.DB, *capture, *QueryHook) { + t.Helper() + sqldb, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = sqldb.Close() }) + db := bun.NewDB(sqldb, sqlitedialect.New()) + + ctx := context.Background() + if _, err := db.NewCreateTable().Model((*user)(nil)).Exec(ctx); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := db.NewInsert().Model(&user{ID: 1, Email: "leak@example.com"}).Exec(ctx); err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + hook := New(opts...) + db.AddQueryHook(hook) + // Hook registered after seeding, so capture starts clean — every test + // asserts only on findings from its own queries. + return db, cap, hook +} + +func TestHook_DetectsRawSelectStar(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.NewRaw("SELECT * FROM users").Scan(context.Background(), &us); err != nil { + t.Fatalf("Raw: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestHook_RedactsLiteralsByDefault asserts the headline redaction guarantee: +// single-quoted literals never reach Result.Query and Fingerprint is always +// populated. +func TestHook_RedactsLiteralsByDefault(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.NewRaw("SELECT * FROM users WHERE email = 'leak@example.com'").Scan(context.Background(), &us); err != nil { + t.Fatalf("Raw: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestHook_SlowQueryReportedOnSuccess(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := db.NewSelect().Model(&u).Where("id = ?", 1).Scan(context.Background()); err != nil { + t.Fatalf("Select: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestHook_SlowQuerySuppressedOnError(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var dst int + err := db.NewRaw("SELECT id FROM no_such_table_xyz WHERE id = 1").Scan(context.Background(), &dst) + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestHook_NPlusOneAcrossCalls(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 3 { + if err := db.NewRaw("SELECT id FROM users WHERE id = 1").Scan(context.Background(), &u); err != nil { + t.Fatalf("Raw: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestHook_ResetN1ClearsState(t *testing.T) { + db, cap, hook := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 2 { + if err := db.NewRaw("SELECT id FROM users WHERE id = 1").Scan(context.Background(), &u); err != nil { + t.Fatalf("Raw: %v", err) + } + } + hook.ResetN1() + if err := db.NewRaw("SELECT id FROM users WHERE id = 1").Scan(context.Background(), &u); err != nil { + t.Fatalf("Raw: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves UPDATE / DELETE statements also flow through Guard. +func TestHook_UpdateAndDeleteAnalyzed(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + ctx := context.Background() + if _, err := db.NewRaw("UPDATE users SET email = 'x'").Exec(ctx); err != nil { + t.Fatalf("UPDATE: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where, got %+v", cap.snapshot()) + } + if _, err := db.NewRaw("DELETE FROM users").Exec(ctx); err != nil { + t.Fatalf("DELETE: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where, got %+v", cap.snapshot()) + } +} diff --git a/integrations/bunguard/go.mod b/integrations/bunguard/go.mod new file mode 100644 index 0000000..c089ec0 --- /dev/null +++ b/integrations/bunguard/go.mod @@ -0,0 +1,21 @@ +module github.com/KARTIKrocks/sqlguard/integrations/bunguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/mattn/go-sqlite3 v1.14.45 + github.com/uptrace/bun v1.2.18 + github.com/uptrace/bun/dialect/sqlitedialect v1.2.18 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + golang.org/x/sys v0.41.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/bunguard/go.sum b/integrations/bunguard/go.sum new file mode 100644 index 0000000..f284f00 --- /dev/null +++ b/integrations/bunguard/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= +github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/uptrace/bun v1.2.18 h1:3HnRcMfS6OBPMG1eSOzlbFJ/X/AyMEJb7rMxE6VQvDU= +github.com/uptrace/bun v1.2.18/go.mod h1:wNltaKJk4JtOt4SG5I5zmA7v0/Mzjh1+/S906Rayd3Y= +github.com/uptrace/bun/dialect/sqlitedialect v1.2.18 h1:Z33SY/U++XK9uGWqS4h8OZVxfCXguIG+sU9cYq2PGFQ= +github.com/uptrace/bun/dialect/sqlitedialect v1.2.18/go.mod h1:1MVOS/Ncy4FZbkJcgUFH6OqYoQinYNjkEwsmNQEXz2A= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/entguard/entguard.go b/integrations/entguard/entguard.go new file mode 100644 index 0000000..68191a1 --- /dev/null +++ b/integrations/entguard/entguard.go @@ -0,0 +1,123 @@ +// Package entguard integrates sqlguard with ent (entgo.io/ent). +// +// ent runs on database/sql, so the simplest coverage is already available by +// pointing entsql at a *sql.DB obtained from sqlguard.Register / OpenDB. This +// package is the dedicated alternative: it decorates ent's own +// dialect.Driver seam, so it works regardless of how the underlying *sql.DB +// was opened (including ent's dialect.DebugDriver chain) and mirrors ent's +// built-in dialect.Debug wrapper. +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to every +// other sqlguard surface. There is no parallel option surface — configure +// with the standard middleware options: +// +// drv, _ := entsql.Open(dialect.Postgres, dsn) +// guarded := entguard.Wrap(drv, +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +// client := ent.NewClient(ent.Driver(guarded)) +// +// Every Exec/Query — on the driver and on transactions it opens — flows +// through middleware.Guard.Observe: static rules run on every call, latency +// is recorded only on success. +package entguard + +import ( + "context" + "database/sql" + + "entgo.io/ent/dialect" + "github.com/KARTIKrocks/sqlguard/middleware" +) + +// Driver decorates an ent dialect.Driver, routing every statement through the +// shared sqlguard analysis core. +type Driver struct { + dialect.Driver + g *middleware.Guard +} + +// Compile-time proof we still satisfy ent's driver contract. +var _ dialect.Driver = (*Driver)(nil) + +// Wrap decorates an ent dialect.Driver. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set every other sqlguard +// surface uses, so there is no parallel configuration surface to drift. +func Wrap(d dialect.Driver, opts ...middleware.Option) *Driver { + return &Driver{Driver: d, g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to Wrap. +func (d *Driver) ResetN1() { d.g.ResetN1() } + +// Exec implements dialect.ExecQuerier. +func (d *Driver) Exec(ctx context.Context, query string, args, v any) error { + done := d.g.Observe(query) + err := d.Driver.Exec(ctx, query, args, v) + done(err) + return err +} + +// Query implements dialect.ExecQuerier. +func (d *Driver) Query(ctx context.Context, query string, args, v any) error { + done := d.g.Observe(query) + err := d.Driver.Query(ctx, query, args, v) + done(err) + return err +} + +// Tx wraps the transaction so statements executed inside it are analyzed too. +func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) { + t, err := d.Driver.Tx(ctx) + if err != nil { + return nil, err + } + return &tx{Tx: t, g: d.g}, nil +} + +// BeginTx forwards to the wrapped driver's BeginTx when it implements one +// (entsql.Driver does — this is how ent honours read-only / isolation +// options), and wraps the resulting transaction. It degrades to Tx when the +// base driver has no BeginTx, matching ent's own fallback. +func (d *Driver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) { + bt, ok := d.Driver.(interface { + BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error) + }) + if !ok { + return d.Tx(ctx) + } + t, err := bt.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &tx{Tx: t, g: d.g}, nil +} + +// tx decorates a dialect.Tx so in-transaction Exec/Query are analyzed. +// Commit/Rollback are inherited from the embedded transaction unchanged. +type tx struct { + dialect.Tx + g *middleware.Guard +} + +// Exec implements dialect.ExecQuerier. +func (t *tx) Exec(ctx context.Context, query string, args, v any) error { + done := t.g.Observe(query) + err := t.Tx.Exec(ctx, query, args, v) + done(err) + return err +} + +// Query implements dialect.ExecQuerier. +func (t *tx) Query(ctx context.Context, query string, args, v any) error { + done := t.g.Observe(query) + err := t.Tx.Query(ctx, query, args, v) + done(err) + return err +} diff --git a/integrations/entguard/entguard_test.go b/integrations/entguard/entguard_test.go new file mode 100644 index 0000000..aa135b8 --- /dev/null +++ b/integrations/entguard/entguard_test.go @@ -0,0 +1,193 @@ +package entguard + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + entsql "entgo.io/ent/dialect/sql" + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + _ "github.com/mattn/go-sqlite3" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// newDriverWithCapture opens a real sqlite-backed ent dialect.Driver, seeds it +// through the *unwrapped* driver (so the capture starts clean), then wraps it +// with the sqlguard decorator. The integration thus runs end-to-end +// (dialect.Driver seam → database/sql round trip) rather than mocked. +func newDriverWithCapture(t *testing.T, opts ...middleware.Option) (*Driver, *capture) { + t.Helper() + drv, err := entsql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("entsql.Open: %v", err) + } + t.Cleanup(func() { _ = drv.Close() }) + + ctx := context.Background() + if err := drv.Exec(ctx, "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)", []any{}, nil); err != nil { + t.Fatalf("create table: %v", err) + } + if err := drv.Exec(ctx, "INSERT INTO users (id, email) VALUES (?, ?)", []any{1, "leak@example.com"}, nil); err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + return Wrap(drv, opts...), cap +} + +func query(t *testing.T, ctx context.Context, q interface { + Query(context.Context, string, any, any) error +}, sqlText string) error { + t.Helper() + var rows entsql.Rows + err := q.Query(ctx, sqlText, []any{}, &rows) + if err == nil { + _ = rows.Close() + } + return err +} + +func TestDriver_DetectsSelectStar(t *testing.T) { + drv, cap := newDriverWithCapture(t) + if err := query(t, context.Background(), drv, "SELECT * FROM users"); err != nil { + t.Fatalf("Query: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestDriver_RedactsLiteralsByDefault asserts the headline redaction +// guarantee: single-quoted literals never reach Result.Query and Fingerprint +// is always populated. +func TestDriver_RedactsLiteralsByDefault(t *testing.T) { + drv, cap := newDriverWithCapture(t) + if err := query(t, context.Background(), drv, "SELECT * FROM users WHERE email = 'leak@example.com'"); err != nil { + t.Fatalf("Query: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestDriver_SlowQueryReportedOnSuccess(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithSlowQueryThreshold(0)) + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestDriver_SlowQuerySuppressedOnError(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithSlowQueryThreshold(0)) + err := query(t, context.Background(), drv, "SELECT id FROM no_such_table_xyz WHERE id = 1") + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestDriver_NPlusOneAcrossCalls(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 3 { + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestDriver_ResetN1ClearsState(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 2 { + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + } + drv.ResetN1() + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +func TestDriver_ExecUpdateAnalyzed(t *testing.T) { + drv, cap := newDriverWithCapture(t) + if err := drv.Exec(context.Background(), "UPDATE users SET email = 'x'", []any{}, nil); err != nil { + t.Fatalf("Exec: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where, got %+v", cap.snapshot()) + } +} + +// TestDriver_TxQueriesAnalyzed proves the transaction wrapper also routes +// in-tx statements through Guard — a query class the database/sql-only path +// would miss if Tx weren't decorated. +func TestDriver_TxQueriesAnalyzed(t *testing.T) { + drv, cap := newDriverWithCapture(t) + ctx := context.Background() + tx, err := drv.Tx(ctx) + if err != nil { + t.Fatalf("Tx: %v", err) + } + if err := query(t, ctx, tx, "SELECT * FROM users"); err != nil { + _ = tx.Rollback() + t.Fatalf("tx Query: %v", err) + } + if err := tx.Commit(); err != nil { + t.Fatalf("Commit: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding from in-tx query, got %+v", cap.snapshot()) + } +} diff --git a/integrations/entguard/go.mod b/integrations/entguard/go.mod new file mode 100644 index 0000000..a58e275 --- /dev/null +++ b/integrations/entguard/go.mod @@ -0,0 +1,13 @@ +module github.com/KARTIKrocks/sqlguard/integrations/entguard + +go 1.26 + +require ( + entgo.io/ent v0.14.6 + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/mattn/go-sqlite3 v1.14.45 +) + +require github.com/google/uuid v1.3.0 // indirect + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/entguard/go.sum b/integrations/entguard/go.sum new file mode 100644 index 0000000..90db7ad --- /dev/null +++ b/integrations/entguard/go.sum @@ -0,0 +1,16 @@ +entgo.io/ent v0.14.6 h1:/f2696BpwuWAEEG6PVGWflg6+Inrpq4pRWuNlWz/Skk= +entgo.io/ent v0.14.6/go.mod h1:z46QBUdGC+BATwsedbDuREfSS0oSCV+csdEYlL4p73s= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/gormguard/go.mod b/integrations/gormguard/go.mod new file mode 100644 index 0000000..5fd0221 --- /dev/null +++ b/integrations/gormguard/go.mod @@ -0,0 +1,18 @@ +module github.com/KARTIKrocks/sqlguard/integrations/gormguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.45 // indirect + golang.org/x/text v0.20.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/gormguard/go.sum b/integrations/gormguard/go.sum new file mode 100644 index 0000000..3bb1e0e --- /dev/null +++ b/integrations/gormguard/go.sum @@ -0,0 +1,12 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/integrations/gormguard/gormguard.go b/integrations/gormguard/gormguard.go new file mode 100644 index 0000000..29237d2 --- /dev/null +++ b/integrations/gormguard/gormguard.go @@ -0,0 +1,134 @@ +// Package gormguard integrates sqlguard with GORM. +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to the +// database/sql driver wrapper and to pgxguard. There is no parallel option +// surface — configure with the standard middleware options: +// +// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{}) +// gormguard.Register(gormDB, +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +// +// GORM only exposes the final built SQL in its after-callback (it has not +// been generated when the before-callback fires), so this plugin uses the +// explicit Check+CheckLatency pair rather than middleware.Guard.Observe. +// Behaviour matches Observe semantically: static rules run on every call, +// latency is reported only on success. +package gormguard + +import ( + "time" + + "github.com/KARTIKrocks/sqlguard/middleware" + "gorm.io/gorm" +) + +// Plugin implements gorm.Plugin and drives every traced statement through +// the shared sqlguard analysis core. +type Plugin struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy gorm.Plugin. +var _ gorm.Plugin = (*Plugin)(nil) + +// New creates a new sqlguard GORM plugin. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set the database/sql +// driver wrapper and pgxguard use, so there is no parallel configuration +// surface to drift. +func New(opts ...middleware.Option) *Plugin { + return &Plugin{g: middleware.NewGuard(opts...)} +} + +// Name implements gorm.Plugin. +func (p *Plugin) Name() string { return "sqlguard" } + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to New / Register. +func (p *Plugin) ResetN1() { p.g.ResetN1() } + +// Initialize registers before/after callbacks on every GORM callback chain. +// +// GORM v2 routes operations through six distinct callback chains: +// - Create/Update/Delete — ORM-style mutating operations +// - Query — ORM-style reads (First/Find/Take/…) +// - Row — raw SQL that returns rows (db.Raw().Scan / .Row) +// - Raw — raw SQL without rows (db.Exec) +// +// Missing any chain silently uncovers a query class — pre-rewrite, only +// Create/Query/Update/Delete were hooked, so every db.Raw and db.Exec +// bypassed analysis (and there were no tests to catch it). All six chains +// are now registered. +// +// SQL is analyzed in the after-callback because GORM has not yet rendered +// db.Statement.SQL when the before-callback fires for the ORM chains. +func (p *Plugin) Initialize(db *gorm.DB) error { + cb := db.Callback() + registrations := []struct { + before, after func(name string, fn func(*gorm.DB)) error + chain string + }{ + {before: cb.Create().Before("gorm:create").Register, after: cb.Create().After("gorm:create").Register, chain: "create"}, + {before: cb.Query().Before("gorm:query").Register, after: cb.Query().After("gorm:query").Register, chain: "query"}, + {before: cb.Update().Before("gorm:update").Register, after: cb.Update().After("gorm:update").Register, chain: "update"}, + {before: cb.Delete().Before("gorm:delete").Register, after: cb.Delete().After("gorm:delete").Register, chain: "delete"}, + {before: cb.Row().Before("gorm:row").Register, after: cb.Row().After("gorm:row").Register, chain: "row"}, + {before: cb.Raw().Before("gorm:raw").Register, after: cb.Raw().After("gorm:raw").Register, chain: "raw"}, + } + for _, r := range registrations { + if err := r.before("sqlguard:before_"+r.chain, p.before); err != nil { + return err + } + if err := r.after("sqlguard:after_"+r.chain, p.after); err != nil { + return err + } + } + return nil +} + +// startTimeKey is the per-statement context key under which the before +// callback stashes the start timestamp. Unexported so it can't collide with +// keys set by other plugins. +const startTimeKey = "sqlguard:start_time" + +func (p *Plugin) before(db *gorm.DB) { + db.Set(startTimeKey, time.Now()) +} + +func (p *Plugin) after(db *gorm.DB) { + if db.Statement == nil { + return + } + sql := db.Statement.SQL.String() + if sql == "" { + return + } + + // Static rules + N+1 run on every call (matches Observe semantics). + p.g.Check(sql) + + // Latency is reported only on success — a failed query's duration is + // meaningless. This mirrors middleware.Guard.Observe. + if db.Error != nil { + return + } + val, ok := db.Get(startTimeKey) + if !ok { + return + } + start, ok := val.(time.Time) + if !ok { + return + } + p.g.CheckLatency(sql, time.Since(start)) +} + +// Register is a convenience function to create and register the plugin. +func Register(db *gorm.DB, opts ...middleware.Option) error { + return db.Use(New(opts...)) +} diff --git a/integrations/gormguard/gormguard_test.go b/integrations/gormguard/gormguard_test.go new file mode 100644 index 0000000..c0e03cf --- /dev/null +++ b/integrations/gormguard/gormguard_test.go @@ -0,0 +1,200 @@ +package gormguard + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +type user struct { + ID int64 `gorm:"primaryKey"` + Email string +} + +// newDBWithCapture spins up an in-memory sqlite-backed *gorm.DB with the +// sqlguard plugin registered, so the integration runs end-to-end (callback +// seam → driver round trip) rather than mocked. +func newDBWithCapture(t *testing.T, opts ...middleware.Option) (*gorm.DB, *capture, *Plugin) { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("gorm.Open: %v", err) + } + if err := db.AutoMigrate(&user{}); err != nil { + t.Fatalf("AutoMigrate: %v", err) + } + if err := db.Create(&user{ID: 1, Email: "leak@example.com"}).Error; err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + plugin := New(opts...) + if err := db.Use(plugin); err != nil { + t.Fatalf("db.Use: %v", err) + } + // Reset capture so the seed INSERT's findings don't pollute test + // assertions — every test wants only the findings from its own queries. + cap.mu.Lock() + cap.r = nil + cap.mu.Unlock() + return db, cap, plugin +} + +func TestPlugin_DetectsRawSelectStar(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.Raw("SELECT * FROM users").Scan(&us).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestPlugin_RedactsLiteralsByDefault is the headline 11.1 regression: +// the old hand-rolled after() set Result.Query to the raw SQL, so single- +// quoted literals leaked into log sinks. After the Guard rewrite Query +// must be the redacted form and Fingerprint must always be populated. +func TestPlugin_RedactsLiteralsByDefault(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.Raw("SELECT * FROM users WHERE email = 'leak@example.com'").Scan(&us).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +// TestPlugin_SlowQueryReportedOnSuccess uses a zero threshold so any +// successful query trips the slow-query path. Threshold arithmetic is +// covered by middleware.Guard's own tests — here we only assert that the +// integration's after-callback drives CheckLatency on success. +func TestPlugin_SlowQueryReportedOnSuccess(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := db.First(&u, 1).Error; err != nil { + t.Fatalf("First: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestPlugin_SlowQuerySuppressedOnError(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + // Force a SQL error: SELECT from a missing table via Raw so we hit the + // after-callback with db.Error != nil. + var dst int + err := db.Raw("SELECT id FROM no_such_table_xyz WHERE id = 1").Scan(&dst).Error + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestPlugin_NPlusOneAcrossCalls(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 3 { + if err := db.Raw("SELECT id FROM users WHERE id = 1").Scan(&u).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestPlugin_ResetN1ClearsState(t *testing.T) { + db, cap, plugin := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 2 { + if err := db.Raw("SELECT id FROM users WHERE id = 1").Scan(&u).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + } + plugin.ResetN1() + if err := db.Raw("SELECT id FROM users WHERE id = 1").Scan(&u).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves the UPDATE / DELETE callbacks also flow through Guard. +func TestPlugin_UpdateAndDeleteCallbacksAnalyzed(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + if err := db.WithContext(context.Background()).Exec("UPDATE users SET email = 'x'").Error; err != nil { + t.Fatalf("UPDATE: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where from update callback, got %+v", cap.snapshot()) + } + + if err := db.Exec("DELETE FROM users").Error; err != nil { + t.Fatalf("DELETE: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where from delete callback, got %+v", cap.snapshot()) + } +} + +func TestRegister_ReturnsNoError(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("gorm.Open: %v", err) + } + if err := Register(db); err != nil { + t.Fatalf("Register: %v", err) + } +} diff --git a/integrations/pgxguard/go.mod b/integrations/pgxguard/go.mod new file mode 100644 index 0000000..136572a --- /dev/null +++ b/integrations/pgxguard/go.mod @@ -0,0 +1,19 @@ +module github.com/KARTIKrocks/sqlguard/integrations/pgxguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/jackc/pgx/v5 v5.7.6 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.27.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/pgxguard/go.sum b/integrations/pgxguard/go.sum new file mode 100644 index 0000000..f06d480 --- /dev/null +++ b/integrations/pgxguard/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/pgxguard/pgxguard.go b/integrations/pgxguard/pgxguard.go new file mode 100644 index 0000000..426110a --- /dev/null +++ b/integrations/pgxguard/pgxguard.go @@ -0,0 +1,145 @@ +// Package pgxguard integrates sqlguard with pgx/v5 — the native, dominant +// PostgreSQL driver for Go (pgx/pgxpool, not the database/sql shim). +// +// It hooks pgx's own tracer seam (pgx.QueryTracer + pgx.BatchTracer), which +// is the idiomatic extension point every pgx ecosystem tool uses, so every +// Query/QueryRow/Exec and every SendBatch is analyzed without a method list +// or a wrapper type. +// +// Composability is a first-class concern: pgx allows exactly one Tracer per +// config, and production services usually already set one (otelpgx). Apply +// and ApplyPool therefore *compose* with any existing tracer via pgx's own +// multitracer rather than overwriting it. +// +// Usage with a pool: +// +// cfg, _ := pgxpool.ParseConfig(dsn) +// pgxguard.ApplyPool(cfg) // composes with cfg.ConnConfig.Tracer if set +// pool, _ := pgxpool.NewWithConfig(ctx, cfg) +// +// Usage with a single connection: +// +// cfg, _ := pgx.ParseConfig(dsn) +// pgxguard.Apply(cfg) +// conn, _ := pgx.ConnectConfig(ctx, cfg) +// +// Analysis is driven by the single sqlguard core (middleware.Guard), so +// redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query and N+1 detection all behave identically to the +// database/sql driver wrapper. Configure with the standard middleware +// options: +// +// pgxguard.NewTracer( +// middleware.WithSlowQueryThreshold(50*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +package pgxguard + +import ( + "context" + + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/multitracer" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Tracer implements pgx.QueryTracer and pgx.BatchTracer, driving every +// traced statement through the shared sqlguard analysis core. +// +// It deliberately does not implement pgx.PrepareTracer: prepared statements +// are still analyzed when executed (execution routes through QueryTracer), +// so tracing Prepare as well would double-report findings and inflate N+1 +// counts. CopyFrom carries no SQL and is out of scope by nature. +type Tracer struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy the pgx tracer interfaces we claim. +var ( + _ pgx.QueryTracer = (*Tracer)(nil) + _ pgx.BatchTracer = (*Tracer)(nil) +) + +// NewTracer builds a Tracer. It accepts the standard sqlguard middleware +// options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, WithParser, +// WithN1Detection, …) — the same option set the database/sql driver wrapper +// uses, so there is no parallel configuration surface to drift. +func NewTracer(opts ...middleware.Option) *Tracer { + return &Tracer{g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to NewTracer. +func (t *Tracer) ResetN1() { t.g.ResetN1() } + +// ctxKey is unexported so the stashed latency closure can't collide with +// any other package's context values. +type ctxKey struct{} + +// TraceQueryStart runs static analysis + N+1 tracking and starts the latency +// timer, stashing the end closure in the returned context. +func (t *Tracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + done := t.g.Observe(data.SQL) + return context.WithValue(ctx, ctxKey{}, done) +} + +// TraceQueryEnd closes the latency window. Latency is recorded only on +// success (Guard.Observe drops it when data.Err != nil — a failed query's +// duration is meaningless). +func (t *Tracer) TraceQueryEnd(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryEndData) { + if done, ok := ctx.Value(ctxKey{}).(func(error)); ok { + done(data.Err) + } +} + +// TraceBatchStart is a no-op: the batch's SQL is only known per-query, in +// TraceBatchQuery. +func (t *Tracer) TraceBatchStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceBatchStartData) context.Context { + return ctx +} + +// TraceBatchQuery analyzes each statement in a batch (static rules + N+1). +// Per-statement latency is not exposed by pgx's batch tracer — only the +// whole-batch round trip — so slow-query timing is intentionally not +// reported here rather than reported wrongly. +func (t *Tracer) TraceBatchQuery(_ context.Context, _ *pgx.Conn, data pgx.TraceBatchQueryData) { + t.g.Check(data.SQL) +} + +// TraceBatchEnd is a no-op (per-statement analysis happens in TraceBatchQuery). +func (t *Tracer) TraceBatchEnd(_ context.Context, _ *pgx.Conn, _ pgx.TraceBatchEndData) {} + +// Apply installs a sqlguard Tracer on a *pgx.ConnConfig, composing with any +// tracer already configured (via pgx's multitracer) instead of overwriting +// it — so it coexists with otelpgx and friends. opts are the standard +// middleware options. Returns the same cfg for chaining. +func Apply(cfg *pgx.ConnConfig, opts ...middleware.Option) *pgx.ConnConfig { + if cfg == nil { + panic("pgxguard: Apply called with nil *pgx.ConnConfig") + } + cfg.Tracer = compose(cfg.Tracer, NewTracer(opts...)) + return cfg +} + +// ApplyPool installs a sqlguard Tracer on a *pgxpool.Config (delegating to +// Apply on the embedded ConnConfig), composing with any existing tracer. +// Returns the same cfg for chaining. +func ApplyPool(cfg *pgxpool.Config, opts ...middleware.Option) *pgxpool.Config { + if cfg == nil { + panic("pgxguard: ApplyPool called with nil *pgxpool.Config") + } + Apply(cfg.ConnConfig, opts...) + return cfg +} + +// compose merges an existing tracer with ours. multitracer.New fans each +// call out to every wrapped tracer and routes by interface type-assertion, +// so the existing tracer keeps receiving exactly the events it did before. +func compose(existing pgx.QueryTracer, ours pgx.QueryTracer) pgx.QueryTracer { + if existing == nil { + return ours + } + return multitracer.New(existing, ours) +} diff --git a/integrations/pgxguard/pgxguard_test.go b/integrations/pgxguard/pgxguard_test.go new file mode 100644 index 0000000..6a2120f --- /dev/null +++ b/integrations/pgxguard/pgxguard_test.go @@ -0,0 +1,247 @@ +package pgxguard + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/multitracer" + "github.com/jackc/pgx/v5/pgxpool" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// stubTracer is a fake existing pgx.QueryTracer used to prove Apply composes +// instead of clobbering. +type stubTracer struct { + mu sync.Mutex + starts int + ends int +} + +func (s *stubTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryStartData) context.Context { + s.mu.Lock() + s.starts++ + s.mu.Unlock() + return ctx +} + +func (s *stubTracer) TraceQueryEnd(_ context.Context, _ *pgx.Conn, _ pgx.TraceQueryEndData) { + s.mu.Lock() + s.ends++ + s.mu.Unlock() +} + +func newTracerWithCapture(t *testing.T, opts ...middleware.Option) (*Tracer, *capture) { + t.Helper() + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + return NewTracer(opts...), cap +} + +// driveQuery runs a full Start→End round trip with no error. +func driveQuery(tr *Tracer, sql string, err error) { + ctx := tr.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: sql}) + tr.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{Err: err}) +} + +func TestTracer_DetectsSelectStarOnQueryStart(t *testing.T) { + tr, cap := newTracerWithCapture(t) + driveQuery(tr, "SELECT * FROM users", nil) + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +func TestTracer_RedactsLiteralsByDefault(t *testing.T) { + tr, cap := newTracerWithCapture(t) + driveQuery(tr, "SELECT * FROM users WHERE email = 'leak@example.com'", nil) + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q", r.Query) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestTracer_SlowQueryReportedOnEnd(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithSlowQueryThreshold(1*time.Millisecond)) + ctx := tr.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: "SELECT id FROM users WHERE id = 1"}) + time.Sleep(5 * time.Millisecond) + tr.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{Err: nil}) + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding, got %+v", cap.snapshot()) + } +} + +func TestTracer_SlowQuerySuppressedOnError(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithSlowQueryThreshold(1*time.Millisecond)) + ctx := tr.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: "SELECT id FROM users WHERE id = 1"}) + time.Sleep(5 * time.Millisecond) + tr.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{Err: errors.New("boom")}) + if cap.has("slow-query") { + t.Fatalf("slow-query should not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestTracer_NPlusOneAcrossCalls(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 3 { + driveQuery(tr, "SELECT id FROM users WHERE id = 1", nil) + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestTracer_ResetN1ClearsState(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 2 { + driveQuery(tr, "SELECT id FROM users WHERE id = 1", nil) + } + tr.ResetN1() + driveQuery(tr, "SELECT id FROM users WHERE id = 1", nil) + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — reset zeroed the counter; got %+v", cap.snapshot()) + } +} + +func TestTracer_BatchQueryAnalyzed(t *testing.T) { + tr, cap := newTracerWithCapture(t) + ctx := tr.TraceBatchStart(context.Background(), nil, pgx.TraceBatchStartData{}) + tr.TraceBatchQuery(ctx, nil, pgx.TraceBatchQueryData{SQL: "SELECT * FROM users"}) + tr.TraceBatchEnd(ctx, nil, pgx.TraceBatchEndData{}) + if !cap.has("select-star") { + t.Fatalf("expected select-star finding from batch path, got %+v", cap.snapshot()) + } +} + +func TestApply_NilExistingSetsOursDirectly(t *testing.T) { + cfg, err := pgx.ParseConfig("postgres://u:p@localhost:5432/db") + if err != nil { + t.Fatalf("ParseConfig: %v", err) + } + if cfg.Tracer != nil { + t.Fatalf("baseline assumption broken: ParseConfig set a tracer (%T)", cfg.Tracer) + } + Apply(cfg) + if _, ok := cfg.Tracer.(*Tracer); !ok { + t.Fatalf("expected *pgxguard.Tracer, got %T", cfg.Tracer) + } +} + +// TestApply_ComposesWithExistingTracer is the headline community-fitness +// guarantee: if the user has already wired e.g. otelpgx, Apply must NOT +// silently overwrite it. +func TestApply_ComposesWithExistingTracer(t *testing.T) { + cfg, err := pgx.ParseConfig("postgres://u:p@localhost:5432/db") + if err != nil { + t.Fatalf("ParseConfig: %v", err) + } + stub := &stubTracer{} + cfg.Tracer = stub + + Apply(cfg) + + mt, ok := cfg.Tracer.(*multitracer.Tracer) + if !ok { + t.Fatalf("expected *multitracer.Tracer after composition, got %T", cfg.Tracer) + } + + var sawStub, sawOurs bool + for _, qt := range mt.QueryTracers { + switch qt.(type) { + case *stubTracer: + sawStub = true + case *Tracer: + sawOurs = true + } + } + if !sawStub { + t.Error("existing tracer was dropped by Apply — community-fitness contract violated") + } + if !sawOurs { + t.Error("our tracer was not installed by Apply") + } + + // And drive it: the existing stub must still receive Start/End events. + ctx := cfg.Tracer.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: "SELECT 1"}) + cfg.Tracer.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{}) + stub.mu.Lock() + defer stub.mu.Unlock() + if stub.starts != 1 || stub.ends != 1 { + t.Errorf("existing tracer not driven through composition: starts=%d ends=%d", stub.starts, stub.ends) + } +} + +func TestApplyPool_DelegatesAndComposes(t *testing.T) { + cfg, err := pgxpool.ParseConfig("postgres://u:p@localhost:5432/db") + if err != nil { + t.Fatalf("pgxpool.ParseConfig: %v", err) + } + stub := &stubTracer{} + cfg.ConnConfig.Tracer = stub + + ApplyPool(cfg) + + if _, ok := cfg.ConnConfig.Tracer.(*multitracer.Tracer); !ok { + t.Fatalf("ApplyPool did not compose: got %T", cfg.ConnConfig.Tracer) + } +} + +func TestApply_NilConfigPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic on nil *pgx.ConnConfig") + } + }() + Apply(nil) +} + +func TestApplyPool_NilConfigPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic on nil *pgxpool.Config") + } + }() + ApplyPool(nil) +} diff --git a/integrations/sqlxguard/go.mod b/integrations/sqlxguard/go.mod new file mode 100644 index 0000000..8fb1885 --- /dev/null +++ b/integrations/sqlxguard/go.mod @@ -0,0 +1,12 @@ +module github.com/KARTIKrocks/sqlguard/integrations/sqlxguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/jmoiron/sqlx v1.4.0 +) + +require github.com/mattn/go-sqlite3 v1.14.45 + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/sqlxguard/go.sum b/integrations/sqlxguard/go.sum new file mode 100644 index 0000000..af18b98 --- /dev/null +++ b/integrations/sqlxguard/go.sum @@ -0,0 +1,11 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= diff --git a/integrations/sqlxguard/sqlxguard.go b/integrations/sqlxguard/sqlxguard.go new file mode 100644 index 0000000..6ddcdbb --- /dev/null +++ b/integrations/sqlxguard/sqlxguard.go @@ -0,0 +1,157 @@ +// Package sqlxguard integrates sqlguard with sqlx. +// +// Every wrapped method routes through the shared sqlguard analysis core +// (middleware.Guard), so redaction-by-default, stable fingerprints, the +// pluggable real-grammar parser, slow-query timing and N+1 detection behave +// identically to the database/sql driver wrapper and to pgxguard. There is +// no parallel option surface — configure with the standard middleware +// options: +// +// db := sqlxguard.WrapSqlx(sqlxDB, +// middleware.WithSlowQueryThreshold(50*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +// +// Coverage note: WrappedDB exposes the sqlx-specific extension methods +// (Select/Get/Queryx/NamedExec and their *Context variants, plus Query/Exec +// passthrough). For full surface coverage — including QueryRow*, NamedQuery, +// MustExec and the transaction helpers — layer sqlx on top of the sqlguard +// driver chain instead: +// +// sqlguard.Register("sqlguard-pgx", pq.Driver{}, opts...) +// sqlDB, _ := sql.Open("sqlguard-pgx", dsn) +// db := sqlx.NewDb(sqlDB, "postgres") +// +// That path covers every sqlx method automatically because interception +// happens at the database/sql driver layer. +package sqlxguard + +import ( + "context" + "database/sql" + + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jmoiron/sqlx" +) + +// WrappedDB wraps a *sqlx.DB with sqlguard analysis. Every analysis-bearing +// method drives the shared middleware.Guard, so behavior matches pgxguard +// and the database/sql driver chain exactly. +type WrappedDB struct { + db *sqlx.DB + g *middleware.Guard +} + +// WrapSqlx creates a new WrappedDB around the given sqlx connection. +// It accepts the standard sqlguard middleware options (WithAnalyzer, +// WithReporter, WithSlowQueryThreshold, WithParser, WithN1Detection, …) — +// the same option set the database/sql driver wrapper and pgxguard use, so +// there is no parallel configuration surface to drift. +func WrapSqlx(db *sqlx.DB, opts ...middleware.Option) *WrappedDB { + if db == nil { + panic("sqlxguard: WrapSqlx called with nil *sqlx.DB") + } + return &WrappedDB{db: db, g: middleware.NewGuard(opts...)} +} + +// DB returns the underlying *sqlx.DB. +func (w *WrappedDB) DB() *sqlx.DB { return w.db } + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to WrapSqlx. +func (w *WrappedDB) ResetN1() { w.g.ResetN1() } + +// Select executes a query and scans the results into dest. +func (w *WrappedDB) Select(dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.Select(dest, query, args...) + done(err) + return err +} + +// SelectContext executes a query with context and scans the results into dest. +func (w *WrappedDB) SelectContext(ctx context.Context, dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.SelectContext(ctx, dest, query, args...) + done(err) + return err +} + +// Get executes a query and scans a single row into dest. +func (w *WrappedDB) Get(dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.Get(dest, query, args...) + done(err) + return err +} + +// GetContext executes a query with context and scans a single row into dest. +func (w *WrappedDB) GetContext(ctx context.Context, dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.GetContext(ctx, dest, query, args...) + done(err) + return err +} + +// Query executes a query that returns rows. +func (w *WrappedDB) Query(query string, args ...any) (*sql.Rows, error) { + done := w.g.Observe(query) + rows, err := w.db.Query(query, args...) + done(err) + return rows, err +} + +// QueryContext executes a query with context that returns rows. +func (w *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + done := w.g.Observe(query) + rows, err := w.db.QueryContext(ctx, query, args...) + done(err) + return rows, err +} + +// Queryx executes a query that returns sqlx.Rows. +func (w *WrappedDB) Queryx(query string, args ...any) (*sqlx.Rows, error) { + done := w.g.Observe(query) + rows, err := w.db.Queryx(query, args...) + done(err) + return rows, err +} + +// Exec executes a query without returning rows. +func (w *WrappedDB) Exec(query string, args ...any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.Exec(query, args...) + done(err) + return result, err +} + +// ExecContext executes a query with context without returning rows. +func (w *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.ExecContext(ctx, query, args...) + done(err) + return result, err +} + +// NamedExec executes a named query. +func (w *WrappedDB) NamedExec(query string, arg any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.NamedExec(query, arg) + done(err) + return result, err +} + +// NamedExecContext executes a named query with context. +func (w *WrappedDB) NamedExecContext(ctx context.Context, query string, arg any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.NamedExecContext(ctx, query, arg) + done(err) + return result, err +} + +// Ping verifies the database connection. +func (w *WrappedDB) Ping() error { return w.db.Ping() } + +// Close closes the database connection. +func (w *WrappedDB) Close() error { return w.db.Close() } diff --git a/integrations/sqlxguard/sqlxguard_test.go b/integrations/sqlxguard/sqlxguard_test.go new file mode 100644 index 0000000..8f692ce --- /dev/null +++ b/integrations/sqlxguard/sqlxguard_test.go @@ -0,0 +1,187 @@ +package sqlxguard + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// newWrappedWithCapture spins up an in-memory sqlite-backed *sqlx.DB so the +// integration is exercised end-to-end (sqlx extension method → database/sql → +// real driver round trip) rather than mocked. +func newWrappedWithCapture(t *testing.T, opts ...middleware.Option) (*WrappedDB, *capture) { + t.Helper() + sqlxDB, err := sqlx.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("sqlx.Open: %v", err) + } + t.Cleanup(func() { _ = sqlxDB.Close() }) + if _, err := sqlxDB.Exec(`CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)`); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := sqlxDB.Exec(`INSERT INTO users (id, email) VALUES (1, 'leak@example.com')`); err != nil { + t.Fatalf("seed: %v", err) + } + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + return WrapSqlx(sqlxDB, opts...), cap +} + +type user struct { + ID int64 `db:"id"` + Email string `db:"email"` +} + +func TestWrappedDB_DetectsSelectStar(t *testing.T) { + w, cap := newWrappedWithCapture(t) + var us []user + if err := w.Select(&us, "SELECT * FROM users"); err != nil { + t.Fatalf("Select: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestWrappedDB_RedactsLiteralsByDefault is the headline 11.1 regression: +// the old hand-rolled check() set Result.Query to the raw SQL, so single- +// quoted literals leaked into log sinks. After the Guard rewrite Query +// must be the redacted form and Fingerprint must always be populated. +func TestWrappedDB_RedactsLiteralsByDefault(t *testing.T) { + w, cap := newWrappedWithCapture(t) + var us []user + if err := w.Select(&us, "SELECT * FROM users WHERE email = 'leak@example.com'"); err != nil { + t.Fatalf("Select: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +// TestWrappedDB_SlowQueryReportedOnSuccess uses a zero threshold so any +// successful round trip trips the slow-query path. The integration-level +// claim under test is "slow-query check runs on success", not the threshold +// arithmetic itself (that lives in middleware.Guard's own tests). +func TestWrappedDB_SlowQueryReportedOnSuccess(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestWrappedDB_SlowQuerySuppressedOnError(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := w.Get(&u, "SELECT id FROM no_such_table_xyz WHERE id = 1"); err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestWrappedDB_NPlusOneAcrossCalls(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 3 { + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestWrappedDB_ResetN1ClearsState(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 2 { + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + } + w.ResetN1() + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves the non-SELECT and *Context paths also flow through Guard. +func TestWrappedDB_ExecAndContextVariantsAnalyzed(t *testing.T) { + w, cap := newWrappedWithCapture(t) + if _, err := w.Exec("DELETE FROM users"); err != nil { + t.Fatalf("Exec: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where from Exec path, got %+v", cap.snapshot()) + } + + if _, err := w.ExecContext(context.Background(), "UPDATE users SET email = 'x'"); err != nil { + t.Fatalf("ExecContext: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where from ExecContext path, got %+v", cap.snapshot()) + } +} + +func TestWrapSqlx_NilPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic on nil *sqlx.DB") + } + }() + WrapSqlx(nil) +} diff --git a/integrations/xormguard/go.mod b/integrations/xormguard/go.mod new file mode 100644 index 0000000..8cfd74c --- /dev/null +++ b/integrations/xormguard/go.mod @@ -0,0 +1,18 @@ +module github.com/KARTIKrocks/sqlguard/integrations/xormguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/mattn/go-sqlite3 v1.14.45 + xorm.io/xorm v1.3.11 +) + +require ( + github.com/goccy/go-json v0.10.5 // indirect + github.com/golang/snappy v0.0.4 // indirect + github.com/syndtr/goleveldb v1.0.0 // indirect + xorm.io/builder v0.3.13 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/xormguard/go.sum b/integrations/xormguard/go.sum new file mode 100644 index 0000000..eb82c3b --- /dev/null +++ b/integrations/xormguard/go.sum @@ -0,0 +1,94 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= +github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= +golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= +golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= +golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= +modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= +modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= +modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.20.4 h1:J8+m2trkN+KKoE7jglyHYYYiaq5xmz2HoHJIiBlRzbE= +modernc.org/sqlite v1.20.4/go.mod h1:zKcGyrICaxNTMEHSr1HQ2GUraP0j+845GYw37+EyT6A= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo= +xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +xorm.io/xorm v1.3.11 h1:i4tlVUASogb0ZZFJHA7dZqoRU2pUpUsutnNdaOlFyMI= +xorm.io/xorm v1.3.11/go.mod h1:cs0ePc8O4a0jD78cNvD+0VFwhqotTvLQZv372QsDw7Q= diff --git a/integrations/xormguard/xormguard.go b/integrations/xormguard/xormguard.go new file mode 100644 index 0000000..047bdbc --- /dev/null +++ b/integrations/xormguard/xormguard.go @@ -0,0 +1,74 @@ +// Package xormguard integrates sqlguard with xorm (xorm.io/xorm). +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to the +// database/sql driver wrapper, pgxguard, gormguard and bunguard. There is no +// parallel option surface — configure with the standard middleware options: +// +// engine, _ := xorm.NewEngine("postgres", dsn) +// engine.AddHook(xormguard.New( +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// )) +// +// xorm's contexts.Hook exposes the rendered SQL and the measured execution +// time on the ContextHook in AfterProcess, so this uses the explicit +// Check+CheckLatency pair (matching gormguard): static rules run on every +// query, latency is reported only on success. +package xormguard + +import ( + "context" + + "github.com/KARTIKrocks/sqlguard/middleware" + "xorm.io/xorm/contexts" +) + +// Hook implements xorm's contexts.Hook and drives every traced statement +// through the shared sqlguard analysis core. +type Hook struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy contexts.Hook. +var _ contexts.Hook = (*Hook)(nil) + +// New creates a new sqlguard xorm hook. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set every other sqlguard +// surface uses, so there is no parallel configuration surface to drift. +func New(opts ...middleware.Option) *Hook { + return &Hook{g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to New. +func (h *Hook) ResetN1() { h.g.ResetN1() } + +// BeforeProcess implements contexts.Hook. xorm stamps the start time itself +// and reports the elapsed duration as ContextHook.ExecuteTime in +// AfterProcess, so there is nothing to do here but pass the context through. +func (h *Hook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { + return c.Ctx, nil +} + +// AfterProcess implements contexts.Hook. c.SQL holds the rendered SQL, +// c.ExecuteTime the measured latency, and c.Err the query error (which is +// returned unchanged so the hook never swallows it). +func (h *Hook) AfterProcess(c *contexts.ContextHook) error { + if c.SQL == "" { + return c.Err + } + + // Static rules + N+1 run on every call (matches Observe semantics). + h.g.Check(c.SQL) + + // Latency is reported only on success — a failed query's duration is + // meaningless. This mirrors middleware.Guard.Observe. + if c.Err == nil { + h.g.CheckLatency(c.SQL, c.ExecuteTime) + } + return c.Err +} diff --git a/integrations/xormguard/xormguard_test.go b/integrations/xormguard/xormguard_test.go new file mode 100644 index 0000000..c9b28f6 --- /dev/null +++ b/integrations/xormguard/xormguard_test.go @@ -0,0 +1,166 @@ +package xormguard + +import ( + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + _ "github.com/mattn/go-sqlite3" + "xorm.io/xorm" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// newEngineWithCapture spins up an in-memory sqlite-backed *xorm.Engine with +// the sqlguard hook registered, so the integration runs end-to-end +// (contexts.Hook seam → driver round trip) rather than mocked. The hook is +// added after seeding so the capture starts clean. +func newEngineWithCapture(t *testing.T, opts ...middleware.Option) (*xorm.Engine, *capture, *Hook) { + t.Helper() + engine, err := xorm.NewEngine("sqlite3", ":memory:") + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + t.Cleanup(func() { _ = engine.Close() }) + + if _, err := engine.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)"); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := engine.Exec("INSERT INTO users (id, email) VALUES (?, ?)", 1, "leak@example.com"); err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + hook := New(opts...) + engine.AddHook(hook) + return engine, cap, hook +} + +func TestHook_DetectsRawSelectStar(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t) + if _, err := engine.QueryString("SELECT * FROM users"); err != nil { + t.Fatalf("QueryString: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestHook_RedactsLiteralsByDefault asserts the headline redaction guarantee: +// single-quoted literals never reach Result.Query and Fingerprint is always +// populated. +func TestHook_RedactsLiteralsByDefault(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t) + if _, err := engine.QueryString("SELECT * FROM users WHERE email = 'leak@example.com'"); err != nil { + t.Fatalf("QueryString: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestHook_SlowQueryReportedOnSuccess(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t, middleware.WithSlowQueryThreshold(0)) + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestHook_SlowQuerySuppressedOnError(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t, middleware.WithSlowQueryThreshold(0)) + _, err := engine.QueryString("SELECT id FROM no_such_table_xyz WHERE id = 1") + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestHook_NPlusOneAcrossCalls(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 3 { + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestHook_ResetN1ClearsState(t *testing.T) { + engine, cap, hook := newEngineWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 2 { + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + } + hook.ResetN1() + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves UPDATE / DELETE statements also flow through Guard. +func TestHook_UpdateAndDeleteAnalyzed(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t) + if _, err := engine.Exec("UPDATE users SET email = 'x'"); err != nil { + t.Fatalf("UPDATE: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where, got %+v", cap.snapshot()) + } + if _, err := engine.Exec("DELETE FROM users"); err != nil { + t.Fatalf("DELETE: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where, got %+v", cap.snapshot()) + } +} diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000..2389f1c --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,84 @@ +package middleware + +import ( + "container/list" + "sync" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// analysisCache memoizes analyzer.Analyze results so each distinct query is +// parsed and rule-checked once instead of on every execution. It is a bounded +// LRU keyed on the **exact** query string. +// +// Why the exact string and not the fingerprint: the fingerprint folds away +// literal values, but a few rules read literal-derived facts the fingerprint +// discards — large-offset (OffsetValue), in-list-too-large (MaxInListLen), and +// leading-wildcard min-length (LeadingWildcardTermLen). Two queries can share a +// fingerprint yet warrant different findings, so fingerprint-keying would cache +// a wrong verdict. Identical query strings always analyze identically, which +// makes the exact string the only fully-correct key — and an effective one, +// since parameterized queries (the common case) and repeated identical queries +// hit while varying-literal queries miss (and those need re-analysis anyway). +// +// Cached result slices are shared and must be treated as read-only by callers +// (Guard.report and the reporters do). An analysisCache is safe for concurrent +// use. +type analysisCache struct { + mu sync.Mutex + ll *list.List + items map[string]*list.Element + capacity int +} + +type cacheEntry struct { + key string + results []analyzer.Result +} + +func newAnalysisCache(capacity int) *analysisCache { + return &analysisCache{ + ll: list.New(), + items: make(map[string]*list.Element), + capacity: capacity, + } +} + +// get returns the cached results for query and true, or nil and false on a +// miss. A cached empty/nil slice is a hit (the query was analyzed and produced +// no findings) — exactly the common case worth memoizing. +func (c *analysisCache) get(query string) ([]analyzer.Result, bool) { + c.mu.Lock() + defer c.mu.Unlock() + if el, ok := c.items[query]; ok { + c.ll.MoveToFront(el) + return el.Value.(*cacheEntry).results, true + } + return nil, false +} + +// put stores results for query, evicting the least-recently-used entry when the +// cache exceeds its capacity. +func (c *analysisCache) put(query string, results []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + if el, ok := c.items[query]; ok { + c.ll.MoveToFront(el) + el.Value.(*cacheEntry).results = results + return + } + el := c.ll.PushFront(&cacheEntry{key: query, results: results}) + c.items[query] = el + if c.ll.Len() > c.capacity { + if oldest := c.ll.Back(); oldest != nil { + c.ll.Remove(oldest) + delete(c.items, oldest.Value.(*cacheEntry).key) + } + } +} + +func (c *analysisCache) len() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.ll.Len() +} diff --git a/middleware/cache_test.go b/middleware/cache_test.go new file mode 100644 index 0000000..2ee4634 --- /dev/null +++ b/middleware/cache_test.go @@ -0,0 +1,125 @@ +package middleware + +import ( + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestAnalysisCache_HitMissAndStore(t *testing.T) { + c := newAnalysisCache(4) + + if _, ok := c.get("q"); ok { + t.Fatal("empty cache should miss") + } + + res := []analyzer.Result{{RuleName: "select-star"}} + c.put("q", res) + + got, ok := c.get("q") + if !ok { + t.Fatal("expected a hit after put") + } + if len(got) != 1 || got[0].RuleName != "select-star" { + t.Errorf("cached results mismatch: %+v", got) + } +} + +func TestAnalysisCache_CachesEmptyResults(t *testing.T) { + c := newAnalysisCache(4) + c.put("clean", nil) // a query that produced no findings is still worth caching + + got, ok := c.get("clean") + if !ok { + t.Fatal("a cached no-findings query must be a hit, not a miss") + } + if len(got) != 0 { + t.Errorf("expected zero findings, got %d", len(got)) + } +} + +func TestAnalysisCache_LRUEviction(t *testing.T) { + c := newAnalysisCache(2) + c.put("a", nil) + c.put("b", nil) + // Touch "a" so "b" becomes least-recently-used. + if _, ok := c.get("a"); !ok { + t.Fatal("a should still be present") + } + c.put("c", nil) // exceeds capacity -> evict LRU ("b") + + if _, ok := c.get("b"); ok { + t.Error("b should have been evicted as least-recently-used") + } + if _, ok := c.get("a"); !ok { + t.Error("a should survive (recently used)") + } + if _, ok := c.get("c"); !ok { + t.Error("c should be present (just added)") + } + if c.len() > 2 { + t.Errorf("cache exceeded capacity: %d", c.len()) + } +} + +// The cache must not change which findings are produced, including for the +// literal-sensitive rules whose verdict the fingerprint would have folded away. +func TestGuard_CacheCorrectForLiteralSensitiveRules(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep), WithFindingDedup(0)) // dedup off to count every finding + + // Same fingerprint ("... OFFSET ?"), different OffsetValue: only the first + // crosses the large-offset threshold (default 1000). WHERE + LIMIT keep the + // only finding large-offset. If the cache keyed on fingerprint, the second + // would wrongly inherit the first's finding. + g.Check("SELECT id FROM users WHERE tenant = ? ORDER BY id LIMIT 10 OFFSET 5000") + if rep.count() != 1 { + t.Fatalf("expected exactly large-offset on OFFSET 5000, got %d findings", rep.count()) + } + g.Check("SELECT id FROM users WHERE tenant = ? ORDER BY id LIMIT 10 OFFSET 10") + if rep.count() != 1 { + t.Errorf("OFFSET 10 must not inherit a cached large-offset finding; total findings = %d", rep.count()) + } +} + +func TestGuard_CacheReturnsConsistentFindingsOnRepeat(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep), WithFindingDedup(0)) + + for range 5 { + g.Check("DELETE FROM accounts") + } + // Cache must not swallow findings: dedup is off, so all 5 are reported. + if got := rep.count(); got != 5 { + t.Errorf("expected 5 findings across 5 identical calls, got %d", got) + } +} + +func TestGuard_CacheDisabled(t *testing.T) { + g := NewGuard(WithAnalysisCacheSize(0)) + if g.cache != nil { + t.Error("cache size 0 should leave the cache nil (disabled)") + } + // Still functions without a cache. + g.Check("DELETE FROM accounts") +} + +// benchQuery is a clean, parameterized query: representative of the prod-common +// case and produces no findings, so Check reduces to the analyze path. +const benchQuery = "SELECT id, name FROM users WHERE id = ? AND tenant = ?" + +func BenchmarkGuardCheck_Cached(b *testing.B) { + g := NewGuard(WithReporter(&countingReporter{}), WithFindingDedup(0)) + b.ReportAllocs() + for b.Loop() { + g.Check(benchQuery) + } +} + +func BenchmarkGuardCheck_Uncached(b *testing.B) { + g := NewGuard(WithReporter(&countingReporter{}), WithFindingDedup(0), WithAnalysisCacheSize(0)) + b.ReportAllocs() + for b.Loop() { + g.Check(benchQuery) + } +} diff --git a/middleware/dedup.go b/middleware/dedup.go new file mode 100644 index 0000000..f63f275 --- /dev/null +++ b/middleware/dedup.go @@ -0,0 +1,74 @@ +package middleware + +import ( + "sync" + "time" +) + +// deduper suppresses repeat emission of the same finding within a time window. +// A finding's identity is (fingerprint, ruleName): the same rule firing on the +// same canonical query shape. Without it, Guard.Check would re-emit every +// static finding on every execution of a recurring query (or every Exec of a +// prepared statement) and flood the log sink. The N+1 detector already +// self-dedups and slow-query is intentionally per-execution; this covers the +// per-query static rules. It reuses the QueryTracker windowing shape. +// +// A deduper is safe for concurrent use. +type deduper struct { + mu sync.Mutex + seen map[string]time.Time // key -> time the finding was last allowed + window time.Duration + maxKeys int +} + +func newDeduper(window time.Duration) *deduper { + return &deduper{ + seen: make(map[string]time.Time), + window: window, + maxKeys: 10000, + } +} + +// allow reports whether a finding identified by (fingerprint, rule) should be +// emitted at time now. It returns true the first time the finding is seen and +// again only after window has elapsed since it was last allowed. A window <= 0 +// disables dedup, so every call returns true (the legacy report-every-time +// behavior). +func (d *deduper) allow(fingerprint, rule string, now time.Time) bool { + if d.window <= 0 { + return true + } + key := fingerprint + "\x00" + rule + + d.mu.Lock() + defer d.mu.Unlock() + + if len(d.seen) >= d.maxKeys { + d.evictExpired(now) + // If eviction freed nothing (every entry still in-window) and this is a + // new key, drop the finding rather than grow the map without bound. A + // key already present is still updated below — never lose dedup state + // for a finding we're actively tracking. + if len(d.seen) >= d.maxKeys { + if _, ok := d.seen[key]; !ok { + return false + } + } + } + + last, ok := d.seen[key] + if !ok || now.Sub(last) > d.window { + d.seen[key] = now + return true + } + return false +} + +// evictExpired removes entries whose window has elapsed. Caller holds the lock. +func (d *deduper) evictExpired(now time.Time) { + for k, t := range d.seen { + if now.Sub(t) > d.window { + delete(d.seen, k) + } + } +} diff --git a/middleware/dedup_test.go b/middleware/dedup_test.go new file mode 100644 index 0000000..b400563 --- /dev/null +++ b/middleware/dedup_test.go @@ -0,0 +1,163 @@ +package middleware + +import ( + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// countingReporter records every Result it is handed, concurrency-safe. +type countingReporter struct { + mu sync.Mutex + results []analyzer.Result +} + +func (c *countingReporter) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.results = append(c.results, rs...) +} + +func (c *countingReporter) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.results) +} + +func TestDeduper_AllowsFirstSuppressesRepeatThenReReportsAfterWindow(t *testing.T) { + d := newDeduper(time.Minute) + now := time.Now() + + if !d.allow("fp", "select-star", now) { + t.Fatal("first occurrence should be allowed") + } + if d.allow("fp", "select-star", now) { + t.Error("repeat within window should be suppressed") + } + if !d.allow("fp", "select-star", now.Add(2*time.Minute)) { + t.Error("occurrence after window elapsed should be allowed again") + } +} + +func TestDeduper_DistinctIdentitiesIndependent(t *testing.T) { + d := newDeduper(time.Minute) + now := time.Now() + + // Different rule, same fingerprint. + if !d.allow("fp", "select-star", now) || !d.allow("fp", "select-without-limit", now) { + t.Error("distinct rules on the same fingerprint should each be allowed once") + } + // Different fingerprint, same rule. + if !d.allow("fp2", "select-star", now) { + t.Error("same rule on a distinct fingerprint should be allowed") + } +} + +func TestDeduper_DisabledWindowAlwaysAllows(t *testing.T) { + d := newDeduper(0) + now := time.Now() + for i := range 5 { + if !d.allow("fp", "select-star", now) { + t.Errorf("window<=0 disables dedup; call %d should be allowed", i) + } + } +} + +func TestDeduper_BoundedAtMaxKeys(t *testing.T) { + // All entries stay in-window, so eviction frees nothing: new keys past the + // cap must be dropped rather than grow the map without bound. + d := &deduper{seen: map[string]time.Time{}, window: time.Hour, maxKeys: 2} + now := time.Now() + + if !d.allow("a", "r", now) || !d.allow("b", "r", now) { + t.Fatal("first two distinct keys should be allowed") + } + if d.allow("c", "r", now) { + t.Error("a new key past maxKeys with nothing to evict should be dropped") + } + if len(d.seen) > d.maxKeys { + t.Errorf("map grew past maxKeys: %d", len(d.seen)) + } + // An already-tracked key is still served (dedup state is not lost). + if d.allow("a", "r", now) { + t.Error("an in-window tracked key should remain suppressed, not re-reported") + } +} + +func TestGuard_DedupSuppressesRepeatStaticFindings(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep)) // default dedup window = 1m + + // DELETE without WHERE triggers exactly one rule (delete-without-where). + for range 10 { + g.Check("DELETE FROM accounts") + } + + if got := rep.count(); got != 1 { + t.Errorf("expected 1 static finding for a repeated query, got %d", got) + } +} + +func TestGuard_DedupDisabledReportsEveryTime(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep), WithFindingDedup(0)) + + for range 5 { + g.Check("DELETE FROM accounts") + } + + if got := rep.count(); got != 5 { + t.Errorf("with dedup disabled expected 5 findings, got %d", got) + } +} + +func TestGuard_DedupIsPerIdentityNotPerQuery(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep)) + + // Two literal variants share one fingerprint -> one select-star finding. + g.Check("SELECT * FROM users WHERE id = 1") + g.Check("SELECT * FROM users WHERE id = 2") + // A genuinely different flagged query is reported independently. + g.Check("DELETE FROM accounts") + + if got := rep.count(); got != 2 { + t.Errorf("expected 2 findings (one per identity), got %d", got) + } +} + +func TestGuard_DedupConcurrent(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep)) + + var wg sync.WaitGroup + for range 100 { + wg.Go(func() { + g.Check("DELETE FROM accounts") + }) + } + wg.Wait() + + if got := rep.count(); got != 1 { + t.Errorf("expected exactly 1 finding under concurrency, got %d", got) + } +} + +func TestDriver_DedupRepeatedStaticFinding(t *testing.T) { + db, buf := guardedWithBuffer(t) + + for range 10 { + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + } + + if n := strings.Count(buf.String(), "select-star"); n != 1 { + t.Errorf("expected select-star reported once across 10 executions, got %d", n) + } +} diff --git a/middleware/driver.go b/middleware/driver.go new file mode 100644 index 0000000..8be10ea --- /dev/null +++ b/middleware/driver.go @@ -0,0 +1,390 @@ +package middleware + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" +) + +// This file implements the standard database/sql driver-wrapping pattern +// (the approach used by ngrok/sqlmw, luna-duclos/instrumentedsql and +// OpenTelemetry's otelsql), hand-written with zero dependencies. +// +// Wrapping at the driver.Driver layer means every query — including those +// issued by ORMs and query builders through database/sql internals — flows +// through the analyzer automatically. There is no method list to keep in +// sync with database/sql, and the result is a real *sql.DB that composes +// with sqlc, ent, sqlx, gorm, pgx-stdlib and anything else. +// +// Optional driver interfaces (QueryerContext, Pinger, SessionResetter, …) +// are forwarded only when the wrapped driver implements them. Because the +// wrapper type structurally implements every optional interface, database/sql +// will always call them; the wrapper returns driver.ErrSkip (or the documented +// no-op) when the base does not support an operation, so database/sql falls +// back exactly as it would for the bare driver. This preserves the base +// driver's behavior without the combinatorial type-switch other libraries use. + +// Register wraps the database/sql driver currently registered under +// baseDriver and registers the analyzed result under name. Afterwards +// sql.Open(name, dsn) yields a *sql.DB whose every query is analyzed. +// +// middleware.Register("sqlguard-sqlite", "sqlite3") +// db, _ := sql.Open("sqlguard-sqlite", ":memory:") +// +// It returns an error if name is already registered or baseDriver is not +// a known driver. +func Register(name, baseDriver string, opts ...Option) (err error) { + // sql.Open does not connect; it only resolves the registered driver, + // so this is a cheap way to obtain the base driver.Driver by name. + probe, oerr := sql.Open(baseDriver, "") + if oerr != nil { + return fmt.Errorf("sqlguard: base driver %q: %w", baseDriver, oerr) + } + base := probe.Driver() + _ = probe.Close() + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("sqlguard: register %q: %v", name, r) + } + }() + sql.Register(name, WrapDriver(base, opts...)) + return nil +} + +// OpenDB wraps a driver.Connector and returns an analyzed *sql.DB. Use this +// when you already hold a connector — for example pgx's stdlib.GetConnector +// or a driver-specific Connector — and don't want a global registration. +// +// connector := stdlib.GetConnector(*pgxConfig) +// db := middleware.OpenDB(connector) +func OpenDB(c driver.Connector, opts ...Option) *sql.DB { + return sql.OpenDB(WrapConnector(c, opts...)) +} + +// WrapDriver returns a driver.Driver that analyzes every query executed +// through it. The returned driver also implements driver.DriverContext so +// connector-based pooling is preserved. +func WrapDriver(base driver.Driver, opts ...Option) driver.Driver { + return &wDriver{base: base, g: NewGuard(opts...)} +} + +// WrapConnector returns a driver.Connector that analyzes every query +// executed through connections it produces. +func WrapConnector(base driver.Connector, opts ...Option) driver.Connector { + return &wConnector{base: base, g: NewGuard(opts...)} +} + +// ---- driver.Driver / driver.DriverContext ---- + +type wDriver struct { + base driver.Driver + g *Guard +} + +var ( + _ driver.Driver = (*wDriver)(nil) + _ driver.DriverContext = (*wDriver)(nil) +) + +func (d *wDriver) Open(name string) (driver.Conn, error) { + c, err := d.base.Open(name) + if err != nil { + return nil, err + } + return &wConn{base: c, g: d.g}, nil +} + +// OpenConnector implements driver.DriverContext. If the base driver supports +// connectors we wrap its connector; otherwise we synthesize a DSN-based +// connector equivalent to the one database/sql builds internally. +func (d *wDriver) OpenConnector(name string) (driver.Connector, error) { + if dc, ok := d.base.(driver.DriverContext); ok { + bc, err := dc.OpenConnector(name) + if err != nil { + return nil, err + } + return &wConnector{base: bc, g: d.g}, nil + } + return &wConnector{base: dsnConnector{dsn: name, driver: d.base}, g: d.g}, nil +} + +// dsnConnector mirrors database/sql's internal dsnConnector for base drivers +// that do not implement driver.DriverContext. +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (c dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return c.driver.Open(c.dsn) +} +func (c dsnConnector) Driver() driver.Driver { return c.driver } + +// ---- driver.Connector ---- + +type wConnector struct { + base driver.Connector + g *Guard +} + +var _ driver.Connector = (*wConnector)(nil) + +func (c *wConnector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.base.Connect(ctx) + if err != nil { + return nil, err + } + return &wConn{base: conn, g: c.g}, nil +} + +func (c *wConnector) Driver() driver.Driver { + return &wDriver{base: c.base.Driver(), g: c.g} +} + +// ---- driver.Conn and its optional interfaces ---- + +type wConn struct { + base driver.Conn + g *Guard +} + +var ( + _ driver.Conn = (*wConn)(nil) + _ driver.ConnPrepareContext = (*wConn)(nil) + _ driver.ConnBeginTx = (*wConn)(nil) + _ driver.QueryerContext = (*wConn)(nil) + _ driver.ExecerContext = (*wConn)(nil) + _ driver.Pinger = (*wConn)(nil) + _ driver.SessionResetter = (*wConn)(nil) + _ driver.Validator = (*wConn)(nil) + _ driver.NamedValueChecker = (*wConn)(nil) +) + +func (c *wConn) Prepare(query string) (driver.Stmt, error) { + s, err := c.base.Prepare(query) + if err != nil { + return nil, err + } + return &wStmt{base: s, query: query, g: c.g}, nil +} + +func (c *wConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + var ( + s driver.Stmt + err error + ) + if cpc, ok := c.base.(driver.ConnPrepareContext); ok { + s, err = cpc.PrepareContext(ctx, query) + } else { + s, err = c.base.Prepare(query) + } + if err != nil { + return nil, err + } + return &wStmt{base: s, query: query, g: c.g}, nil +} + +func (c *wConn) Close() error { return c.base.Close() } + +func (c *wConn) Begin() (driver.Tx, error) { + tx, err := c.base.Begin() //nolint:staticcheck // delegated deprecated path + if err != nil { + return nil, err + } + return &wTx{base: tx}, nil +} + +func (c *wConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var ( + tx driver.Tx + err error + ) + if cbt, ok := c.base.(driver.ConnBeginTx); ok { + tx, err = cbt.BeginTx(ctx, opts) + } else { + tx, err = c.base.Begin() //nolint:staticcheck // delegated deprecated path + } + if err != nil { + return nil, err + } + return &wTx{base: tx}, nil +} + +func (c *wConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // Observe only on a path that actually executes. When the base has no direct + // Query entry point we return driver.ErrSkip *without* analyzing, so + // database/sql's Prepare+Query fallback — which re-enters through wStmt — is + // the single place this query is analyzed. Analyzing here too would count + // the same logical query twice (a duplicate finding and an inflated N+1). + if qc, ok := c.base.(driver.QueryerContext); ok { + done := c.g.Observe(query) + rows, err := qc.QueryContext(ctx, query, args) + done(err) + return rows, err + } + if q, ok := c.base.(driver.Queryer); ok { //nolint:staticcheck // legacy fallback + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + done := c.g.Observe(query) + rows, err := q.Query(query, values) //nolint:staticcheck // legacy fallback + done(err) + return rows, err + } + return nil, driver.ErrSkip +} + +func (c *wConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + // See QueryContext: analyze only when this path executes. Returning ErrSkip + // without analyzing lets the Prepare+Exec fallback (via wStmt) be the single + // analysis point, avoiding a double count. + if ec, ok := c.base.(driver.ExecerContext); ok { + done := c.g.Observe(query) + res, err := ec.ExecContext(ctx, query, args) + done(err) + return res, err + } + if e, ok := c.base.(driver.Execer); ok { //nolint:staticcheck // legacy fallback + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + done := c.g.Observe(query) + res, err := e.Exec(query, values) //nolint:staticcheck // legacy fallback + done(err) + return res, err + } + return nil, driver.ErrSkip +} + +func (c *wConn) Ping(ctx context.Context) error { + if p, ok := c.base.(driver.Pinger); ok { + return p.Ping(ctx) + } + // Base is not a Pinger; ErrSkip tells database/sql ping is unsupported + // and the connection should be assumed valid, matching the bare driver. + return driver.ErrSkip +} + +func (c *wConn) ResetSession(ctx context.Context) error { + if r, ok := c.base.(driver.SessionResetter); ok { + return r.ResetSession(ctx) + } + return nil +} + +func (c *wConn) IsValid() bool { + if v, ok := c.base.(driver.Validator); ok { + return v.IsValid() + } + return true +} + +func (c *wConn) CheckNamedValue(nv *driver.NamedValue) error { + if ck, ok := c.base.(driver.NamedValueChecker); ok { + return ck.CheckNamedValue(nv) + } + // Defer to database/sql's default argument conversion. + return driver.ErrSkip +} + +// ---- driver.Stmt and its optional interfaces ---- + +type wStmt struct { + base driver.Stmt + query string + g *Guard +} + +var ( + _ driver.Stmt = (*wStmt)(nil) + _ driver.StmtExecContext = (*wStmt)(nil) + _ driver.StmtQueryContext = (*wStmt)(nil) + _ driver.NamedValueChecker = (*wStmt)(nil) +) + +func (s *wStmt) Close() error { return s.base.Close() } +func (s *wStmt) NumInput() int { return s.base.NumInput() } + +func (s *wStmt) Exec(args []driver.Value) (driver.Result, error) { + done := s.g.Observe(s.query) + res, err := s.base.Exec(args) //nolint:staticcheck // delegated deprecated path + done(err) + return res, err +} + +func (s *wStmt) Query(args []driver.Value) (driver.Rows, error) { + done := s.g.Observe(s.query) + rows, err := s.base.Query(args) //nolint:staticcheck // delegated deprecated path + done(err) + return rows, err +} + +func (s *wStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + done := s.g.Observe(s.query) + if ec, ok := s.base.(driver.StmtExecContext); ok { + res, err := ec.ExecContext(ctx, args) + done(err) + return res, err + } + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + res, err := s.base.Exec(values) //nolint:staticcheck // legacy fallback + done(err) + return res, err +} + +func (s *wStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + done := s.g.Observe(s.query) + if qc, ok := s.base.(driver.StmtQueryContext); ok { + rows, err := qc.QueryContext(ctx, args) + done(err) + return rows, err + } + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + rows, err := s.base.Query(values) //nolint:staticcheck // legacy fallback + done(err) + return rows, err +} + +func (s *wStmt) CheckNamedValue(nv *driver.NamedValue) error { + if ck, ok := s.base.(driver.NamedValueChecker); ok { + return ck.CheckNamedValue(nv) + } + return driver.ErrSkip +} + +// ---- driver.Tx ---- + +type wTx struct { + base driver.Tx +} + +var _ driver.Tx = (*wTx)(nil) + +func (t *wTx) Commit() error { return t.base.Commit() } +func (t *wTx) Rollback() error { return t.base.Rollback() } + +// ---- helpers ---- + +// namedToValues converts named values to positional values for the legacy +// Queryer/Execer/Stmt fallback paths, which predate named parameters. +func namedToValues(named []driver.NamedValue) ([]driver.Value, error) { + values := make([]driver.Value, len(named)) + for i, nv := range named { + if nv.Name != "" { + return nil, errors.New("sqlguard: driver does not support named parameters") + } + values[i] = nv.Value + } + return values, nil +} diff --git a/middleware/driver_fallback_test.go b/middleware/driver_fallback_test.go new file mode 100644 index 0000000..071ee3b --- /dev/null +++ b/middleware/driver_fallback_test.go @@ -0,0 +1,111 @@ +package middleware + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "io" + "testing" + "time" +) + +// fakeNoQueryerDriver is a minimal driver whose Conn implements neither +// QueryerContext/ExecerContext nor the legacy Queryer/Execer. database/sql is +// therefore forced down its Prepare+Stmt fallback path for every Query/Exec — +// the path where wConn.{Query,Exec}Context return driver.ErrSkip. It exists to +// prove a single logical query is analyzed exactly once even then. +type fakeNoQueryerDriver struct{} + +func (fakeNoQueryerDriver) Open(string) (driver.Conn, error) { return &fakeConn{}, nil } + +type fakeConn struct{} + +func (*fakeConn) Prepare(string) (driver.Stmt, error) { return &fakeStmt{}, nil } +func (*fakeConn) Close() error { return nil } +func (*fakeConn) Begin() (driver.Tx, error) { return &fakeTx{}, nil } + +type fakeStmt struct{} + +func (*fakeStmt) Close() error { return nil } +func (*fakeStmt) NumInput() int { return -1 } // skip arg-count checking +func (*fakeStmt) Exec([]driver.Value) (driver.Result, error) { return driver.RowsAffected(0), nil } +func (*fakeStmt) Query([]driver.Value) (driver.Rows, error) { return &fakeRows{}, nil } + +type fakeRows struct{} + +func (*fakeRows) Columns() []string { return nil } +func (*fakeRows) Close() error { return nil } +func (*fakeRows) Next([]driver.Value) error { return io.EOF } + +type fakeTx struct{} + +func (*fakeTx) Commit() error { return nil } +func (*fakeTx) Rollback() error { return nil } + +// openFakeGuarded registers a wrapped fakeNoQueryerDriver and returns the DB +// plus the reporter that records findings. Dedup is off so every analysis is +// counted (the bug would surface as 2 findings for one query). +func openFakeGuarded(t *testing.T) (*sql.DB, *countingReporter) { + t.Helper() + rep := &countingReporter{} + name := fmt.Sprintf("sqlguard-fake-%d", driverSeq.Add(1)) + sql.Register(name, WrapDriver(fakeNoQueryerDriver{}, WithReporter(rep), WithFindingDedup(0))) + db, err := sql.Open(name, "") + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db, rep +} + +func TestDriver_NoQueryerContextAnalyzedOnce(t *testing.T) { + db, rep := openFakeGuarded(t) + + rows, err := db.Query("DELETE FROM accounts") // flagged: delete-without-where + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if got := rep.count(); got != 1 { + t.Errorf("expected one logical query analyzed once via the prepare fallback, got %d", got) + } +} + +func TestDriver_NoExecerContextAnalyzedOnce(t *testing.T) { + db, rep := openFakeGuarded(t) + + if _, err := db.Exec("DELETE FROM accounts"); err != nil { + t.Fatalf("exec: %v", err) + } + + if got := rep.count(); got != 1 { + t.Errorf("expected one logical query analyzed once via the prepare fallback, got %d", got) + } +} + +// With N+1 enabled, each logical query must increment the counter once. If the +// ErrSkip path double-counted, threshold=2 would trip after a single query. +func TestDriver_NoQueryerContextN1CountedOnce(t *testing.T) { + rep := &countingReporter{} + name := fmt.Sprintf("sqlguard-fake-%d", driverSeq.Add(1)) + sql.Register(name, WrapDriver(fakeNoQueryerDriver{}, + WithReporter(rep), WithFindingDedup(0), WithN1Detection(2, time.Minute))) + db, err := sql.Open(name, "") + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + + // One execution of a non-flagged query: no static finding, and the N+1 + // counter should be at 1 (below threshold 2), so nothing is reported. + rows, err := db.Query("SELECT id, name FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if got := rep.count(); got != 0 { + t.Errorf("one logical query must not trip N+1 (threshold 2); got %d reports", got) + } +} diff --git a/middleware/driver_test.go b/middleware/driver_test.go new file mode 100644 index 0000000..d3e4328 --- /dev/null +++ b/middleware/driver_test.go @@ -0,0 +1,265 @@ +package middleware + +import ( + "bytes" + "database/sql" + "fmt" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/reporter" + + _ "github.com/mattn/go-sqlite3" +) + +var driverSeq atomic.Int64 + +// newGuardedDB registers a uniquely-named wrapped sqlite3 driver with the +// given options and returns an analyzed *sql.DB backed by a temp-file +// database (so the connection pool sees a consistent schema). +func newGuardedDB(t *testing.T, opts ...Option) *sql.DB { + t.Helper() + name := fmt.Sprintf("sqlguard-test-%d", driverSeq.Add(1)) + if err := Register(name, "sqlite3", opts...); err != nil { + t.Fatalf("Register: %v", err) + } + dsn := filepath.Join(t.TempDir(), "test.db") + db, err := sql.Open(name, dsn) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { db.Close() }) + + if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)"); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := db.Exec("INSERT INTO users (name, email) VALUES ('alice', 'alice@example.com')"); err != nil { + t.Fatalf("insert: %v", err) + } + return db +} + +func guardedWithBuffer(t *testing.T, extra ...Option) (*sql.DB, *bytes.Buffer) { + t.Helper() + var buf bytes.Buffer + opts := append([]Option{WithReporter(&reporter.ConsoleReporter{Out: &buf})}, extra...) + return newGuardedDB(t, opts...), &buf +} + +func TestDriver_ReturnsRealSQLDB(t *testing.T) { + db, _ := guardedWithBuffer(t) + // The whole point: Register/sql.Open yield a real *sql.DB, usable + // anywhere one is expected (no wrapper type to thread through). + if db == nil { + t.Fatal("expected a *sql.DB") + } +} + +func TestDriver_QueryDetectsSelectStar(t *testing.T) { + db, buf := guardedWithBuffer(t) + + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if !strings.Contains(buf.String(), "select-star") { + t.Errorf("expected select-star warning, got: %q", buf.String()) + } +} + +func TestDriver_NoWarningForSafeQuery(t *testing.T) { + db, buf := guardedWithBuffer(t) + + rows, err := db.Query("SELECT id, name FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if buf.Len() != 0 { + t.Errorf("expected no warnings, got: %q", buf.String()) + } +} + +func TestDriver_ExecDetectsDeleteWithoutWhere(t *testing.T) { + db, buf := guardedWithBuffer(t) + + if _, err := db.Exec("DELETE FROM users"); err != nil { + t.Fatalf("exec: %v", err) + } + + if !strings.Contains(buf.String(), "delete-without-where") { + t.Errorf("expected delete-without-where, got: %q", buf.String()) + } + if !strings.Contains(buf.String(), "CRITICAL") { + t.Error("expected CRITICAL severity") + } +} + +func TestDriver_QueryRowDetectsLeadingWildcard(t *testing.T) { + db, buf := guardedWithBuffer(t) + + _ = db.QueryRow("SELECT id FROM users WHERE email LIKE '%gmail%'") + + if !strings.Contains(buf.String(), "leading-wildcard") { + t.Errorf("expected leading-wildcard, got: %q", buf.String()) + } +} + +func TestDriver_PreparedStatementIsAnalyzed(t *testing.T) { + db, buf := guardedWithBuffer(t) + + stmt, err := db.Prepare("SELECT * FROM users") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + t.Fatalf("stmt query: %v", err) + } + rows.Close() + + if !strings.Contains(buf.String(), "select-star") { + t.Errorf("expected select-star on prepared exec, got: %q", buf.String()) + } +} + +func TestDriver_TransactionIsAnalyzed(t *testing.T) { + db, buf := guardedWithBuffer(t) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + if _, err := tx.Exec("DELETE FROM users"); err != nil { + t.Fatalf("tx exec: %v", err) + } + if err := tx.Rollback(); err != nil { + t.Fatalf("rollback: %v", err) + } + + if !strings.Contains(buf.String(), "delete-without-where") { + t.Errorf("expected delete-without-where in tx, got: %q", buf.String()) + } +} + +func TestDriver_TransactionCommitRollback(t *testing.T) { + db, _ := guardedWithBuffer(t) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + if _, err := tx.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "bob", "bob@example.com"); err != nil { + t.Fatalf("exec: %v", err) + } + if err := tx.Commit(); err != nil { + t.Fatalf("commit: %v", err) + } + + var count int + if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil { + t.Fatalf("scan: %v", err) + } + if count != 2 { + t.Errorf("expected 2 rows after commit, got %d", count) + } + + tx2, err := db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + if _, err := tx2.Exec("DELETE FROM users WHERE name = ?", "bob"); err != nil { + t.Fatalf("exec: %v", err) + } + if err := tx2.Rollback(); err != nil { + t.Fatalf("rollback: %v", err) + } + + if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil { + t.Fatalf("scan: %v", err) + } + if count != 2 { + t.Errorf("expected 2 rows after rollback, got %d", count) + } +} + +func TestDriver_SlowQueryDetection(t *testing.T) { + db, buf := guardedWithBuffer(t, WithSlowQueryThreshold(1*time.Nanosecond)) + + rows, err := db.Query("SELECT id FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if !strings.Contains(buf.String(), "slow-query") { + t.Errorf("expected slow-query with 1ns threshold, got: %q", buf.String()) + } +} + +func TestDriver_NoSlowQueryBelowThreshold(t *testing.T) { + db, buf := guardedWithBuffer(t, WithSlowQueryThreshold(1*time.Hour)) + + rows, err := db.Query("SELECT id FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if strings.Contains(buf.String(), "slow-query") { + t.Errorf("did not expect slow-query, got: %q", buf.String()) + } +} + +func TestDriver_CustomAnalyzer(t *testing.T) { + db, buf := guardedWithBuffer(t, WithAnalyzer(analyzer.New(analyzer.CheckDeleteWithoutWhere))) + + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if strings.Contains(buf.String(), "select-star") { + t.Errorf("did not expect select-star with custom analyzer, got: %q", buf.String()) + } +} + +func TestDriver_N1Detection(t *testing.T) { + db, buf := guardedWithBuffer(t, WithN1Detection(3, time.Second)) + + for i := range 5 { + row := db.QueryRow("SELECT name FROM users WHERE id = ?", i) + var name string + _ = row.Scan(&name) + } + + if !strings.Contains(buf.String(), "n-plus-one") { + t.Errorf("expected n-plus-one warning, got: %q", buf.String()) + } +} + +func TestRegister_DuplicateNameErrors(t *testing.T) { + name := fmt.Sprintf("sqlguard-dup-%d", driverSeq.Add(1)) + if err := Register(name, "sqlite3"); err != nil { + t.Fatalf("first Register: %v", err) + } + if err := Register(name, "sqlite3"); err == nil { + t.Error("expected error registering duplicate name") + } +} + +func TestRegister_UnknownBaseDriverErrors(t *testing.T) { + if err := Register("sqlguard-x", "no-such-driver"); err == nil { + t.Error("expected error for unknown base driver") + } +} diff --git a/middleware/guard.go b/middleware/guard.go new file mode 100644 index 0000000..e29e92b --- /dev/null +++ b/middleware/guard.go @@ -0,0 +1,137 @@ +package middleware + +import ( + "fmt" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// Guard is the single shared analysis core. It runs the configured analyzer +// and reporter against every executed query, measures latency, and feeds the +// N+1 tracker. Every interception point — the database/sql driver chain and +// every out-of-tree integration (pgxguard, …) — drives the same Guard so +// analysis logic, redaction, fingerprinting, N+1, the parser seam and config +// live here exactly once. Integrations must build on Guard rather than +// re-implementing check/latency by hand (that path silently loses +// redaction-by-default and fingerprints). +// +// A Guard is safe for concurrent use. +type Guard struct { + opts options + tracker *QueryTracker + deduper *deduper + cache *analysisCache +} + +// NewGuard builds a Guard from the given options. +func NewGuard(opts ...Option) *Guard { + o := defaultOptions() + for _, opt := range opts { + opt(&o) + } + if o.parser != nil { + o.analyzer = o.analyzer.WithParser(o.parser) + } + g := &Guard{opts: o, deduper: newDeduper(o.dedupWindow)} + if o.cacheSize > 0 { + g.cache = newAnalysisCache(o.cacheSize) + } + if o.enableN1 { + g.tracker = NewQueryTracker(o.n1Threshold, o.n1Window, func(results []analyzer.Result) { + o.reporter.Report(results) + }) + } + return g +} + +// Analyzer returns the configured analyzer. Useful for integrations that need +// the canonical redact/fingerprint helpers without re-deriving policy. +func (g *Guard) Analyzer() *analyzer.Analyzer { return g.opts.analyzer } + +// Check runs the static rules against the query and feeds the N+1 tracker. +func (g *Guard) Check(query string) { + results := g.analyze(query) + if len(results) > 0 { + g.report(results) + } + if g.tracker != nil { + g.tracker.Track(query) + } +} + +// analyze returns the static findings for query, memoizing per distinct query +// string so a recurring query is parsed and rule-checked once. The cache is +// keyed on the exact query string because a few rules read literal-derived +// facts the fingerprint folds away (see analysisCache). The returned slice may +// be shared from the cache and must be treated as read-only. +func (g *Guard) analyze(query string) []analyzer.Result { + if g.cache == nil { + return g.opts.analyzer.Analyze(query) + } + if cached, ok := g.cache.get(query); ok { + return cached + } + results := g.opts.analyzer.Analyze(query) + g.cache.put(query, results) + return results +} + +// report emits static findings, suppressing repeats of the same +// (fingerprint, rule) within the dedup window so a recurring query does not +// flood the reporter. results may be a shared cache entry, so it is never +// mutated; kept is allocated only when a finding actually passes dedup (rare +// after the first occurrence, and never for the common no-findings case). +func (g *Guard) report(results []analyzer.Result) { + now := time.Now() + var kept []analyzer.Result + for _, r := range results { + if g.deduper.allow(r.Fingerprint, r.RuleName, now) { + kept = append(kept, r) + } + } + if len(kept) > 0 { + g.opts.reporter.Report(kept) + } +} + +// CheckLatency reports a slow-query finding if elapsed exceeds the threshold. +func (g *Guard) CheckLatency(query string, elapsed time.Duration) { + if elapsed >= g.opts.slowThreshold { + display, fingerprint := g.opts.analyzer.PrepareQuery(query) + g.opts.reporter.Report([]analyzer.Result{{ + RuleName: "slow-query", + Severity: analyzer.SeverityWarning, + Query: display, + Fingerprint: fingerprint, + Message: fmt.Sprintf("Query took %s (threshold: %s)", elapsed.Round(time.Millisecond), g.opts.slowThreshold), + Suggestion: "Consider adding indexes or optimizing the query.", + }}) + } +} + +// Observe analyzes a query and times its execution. The returned function +// must be called once the underlying operation completes; it records latency +// only when err is nil (a failed query's latency is meaningless). It is +// designed for split start/end interception points such as pgx tracers: +// call Observe in the start hook, stash the closure, invoke it in the end +// hook with the operation error. +func (g *Guard) Observe(query string) func(err error) { + g.Check(query) + start := time.Now() + return func(err error) { + if err == nil { + g.CheckLatency(query, time.Since(start)) + } + } +} + +// ResetN1 clears the N+1 tracker's accumulated state. Call this at a +// per-request boundary (e.g. end of an HTTP handler) so N+1 detection is +// scoped to a single logical unit of work rather than process-global. It is +// a no-op when N+1 detection is not enabled. +func (g *Guard) ResetN1() { + if g.tracker != nil { + g.tracker.Reset() + } +} diff --git a/middleware/n_plus_one.go b/middleware/n_plus_one.go new file mode 100644 index 0000000..38bcaba --- /dev/null +++ b/middleware/n_plus_one.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "fmt" + "sync" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// normalizeQuery is the N+1 grouping key: the canonical, literal-free query +// fingerprint. It delegates to analyzer.Fingerprint so there is a single +// normalizer in the codebase (the comment/string-literal-aware one) rather +// than a second, subtly different regex pass. +func normalizeQuery(query string) string { + return analyzer.Fingerprint(query) +} + +type queryRecord struct { + count int + firstSeen time.Time + reported bool +} + +// QueryTracker detects N+1 query patterns at runtime. +// It tracks normalized query patterns and flags when the same pattern +// is executed more than a threshold number of times within a time window. +type QueryTracker struct { + mu sync.Mutex + queries map[string]*queryRecord + threshold int + window time.Duration + maxKeys int + reporter func(results []analyzer.Result) +} + +// NewQueryTracker creates a tracker that flags when the same query pattern +// appears more than threshold times within the given window. +func NewQueryTracker(threshold int, window time.Duration, reportFn func([]analyzer.Result)) *QueryTracker { + return &QueryTracker{ + queries: make(map[string]*queryRecord), + threshold: threshold, + window: window, + maxKeys: 10000, + reporter: reportFn, + } +} + +// Track records a query execution and reports if N+1 pattern is detected. +func (qt *QueryTracker) Track(query string) { + normalized := normalizeQuery(query) + + qt.mu.Lock() + + now := time.Now() + + // Bound memory: when the map is at capacity, evict expired entries first. + if len(qt.queries) >= qt.maxKeys { + qt.evictExpired(now) + } + + rec, exists := qt.queries[normalized] + if !exists { + // A new key past the cap (eviction freed nothing — every entry is still + // in-window) is dropped rather than grown without bound: a rare, + // harmless false negative under pathological query-shape cardinality. + // Already-tracked keys (the exists path below) are always honored, so + // in-flight N+1 detection is never lost. + if len(qt.queries) >= qt.maxKeys { + qt.mu.Unlock() + return + } + qt.queries[normalized] = &queryRecord{count: 1, firstSeen: now} + qt.mu.Unlock() + return + } + + // If outside the window, reset + if now.Sub(rec.firstSeen) > qt.window { + rec.count = 1 + rec.firstSeen = now + rec.reported = false + qt.mu.Unlock() + return + } + + rec.count++ + + shouldReport := rec.count >= qt.threshold && !rec.reported + if shouldReport { + rec.reported = true + } + + // Release lock before calling reporter to avoid holding mutex during I/O + count := rec.count + qt.mu.Unlock() + + if shouldReport { + qt.reporter([]analyzer.Result{{ + RuleName: "n-plus-one", + Severity: analyzer.SeverityWarning, + Query: normalized, + Fingerprint: normalized, + Message: fmt.Sprintf("Possible N+1 query detected: same pattern executed %d times in %s", count, qt.window), + Suggestion: "Consider using a JOIN or IN clause to batch these queries.", + }}) + } +} + +// evictExpired removes entries older than the window. Must be called with mutex held. +func (qt *QueryTracker) evictExpired(now time.Time) { + for key, rec := range qt.queries { + if now.Sub(rec.firstSeen) > qt.window { + delete(qt.queries, key) + } + } +} + +// Reset clears all tracked queries. Call this between requests. +func (qt *QueryTracker) Reset() { + qt.mu.Lock() + defer qt.mu.Unlock() + qt.queries = make(map[string]*queryRecord) +} diff --git a/middleware/n_plus_one_test.go b/middleware/n_plus_one_test.go new file mode 100644 index 0000000..141b2c1 --- /dev/null +++ b/middleware/n_plus_one_test.go @@ -0,0 +1,168 @@ +package middleware + +import ( + "fmt" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestNormalizeQuery(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"numbers", "SELECT * FROM users WHERE id = 42", "SELECT * FROM users WHERE id = ?"}, + {"strings", "SELECT * FROM users WHERE name = 'alice'", "SELECT * FROM users WHERE name = ?"}, + {"mixed", "SELECT * FROM users WHERE id = 1 AND name = 'bob'", "SELECT * FROM users WHERE id = ? AND name = ?"}, + {"no literals", "SELECT * FROM users WHERE id = ?", "SELECT * FROM users WHERE id = ?"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeQuery(tt.input) + if got != tt.want { + t.Errorf("normalizeQuery(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestQueryTracker_DetectsN1(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(3, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + // Same pattern 3 times should trigger + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + tracker.Track("SELECT * FROM orders WHERE user_id = 3") + + if len(reported) != 1 { + t.Fatalf("expected 1 N+1 report, got %d", len(reported)) + } + if reported[0].RuleName != "n-plus-one" { + t.Errorf("expected rule n-plus-one, got %s", reported[0].RuleName) + } +} + +func TestQueryTracker_DifferentPatterns(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(3, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + // Different patterns should not trigger + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM users WHERE id = 1") + tracker.Track("SELECT * FROM products WHERE id = 1") + + if len(reported) != 0 { + t.Errorf("expected no reports for different patterns, got %d", len(reported)) + } +} + +func TestQueryTracker_BelowThreshold(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(5, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + // Only 3 of same pattern, threshold is 5 + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + tracker.Track("SELECT * FROM orders WHERE user_id = 3") + + if len(reported) != 0 { + t.Errorf("expected no reports below threshold, got %d", len(reported)) + } +} + +func TestQueryTracker_ReportsOnlyOnce(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(2, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + tracker.Track("SELECT * FROM orders WHERE user_id = 3") + tracker.Track("SELECT * FROM orders WHERE user_id = 4") + + if len(reported) != 1 { + t.Errorf("expected exactly 1 report (not per-query), got %d", len(reported)) + } +} + +func TestQueryTracker_Reset(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(2, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Reset() + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + + // After reset, count should restart + if len(reported) != 0 { + t.Errorf("expected no reports after reset, got %d", len(reported)) + } +} + +// N+1 detection through the driver path is covered by +// TestDriver_N1Detection in driver_test.go. QueryTracker.Reset is +// exercised directly above; per-request reset is no longer exposed on +// the *sql.DB returned by the driver wrapper. + +func TestQueryTracker_BoundedAtMaxKeys(t *testing.T) { + qt := &QueryTracker{ + queries: make(map[string]*queryRecord), + threshold: 1000, // high so nothing reports + window: time.Hour, // long so nothing expires (eviction frees nothing) + maxKeys: 3, + reporter: func([]analyzer.Result) {}, + } + + // Distinct query *shapes* (distinct column names → distinct fingerprints), + // far more than maxKeys, all in-window. The map must stay capped. + for i := range 50 { + qt.Track(fmt.Sprintf("SELECT col%d FROM t WHERE id = ?", i)) + } + + if len(qt.queries) > qt.maxKeys { + t.Errorf("tracker map grew past maxKeys: %d > %d", len(qt.queries), qt.maxKeys) + } +} + +func TestQueryTracker_TrackedKeyHonoredAtCap(t *testing.T) { + var reports int + qt := &QueryTracker{ + queries: make(map[string]*queryRecord), + threshold: 3, + window: time.Hour, + maxKeys: 2, + reporter: func([]analyzer.Result) { reports++ }, + } + + // "cola" gets tracked to count 2 (below threshold), then the cap fills. + qt.Track("SELECT cola FROM t") + qt.Track("SELECT cola FROM t") + qt.Track("SELECT colb FROM t") // map now {cola, colb}, at cap + + // A brand-new key at the cap is dropped (map stays bounded)... + qt.Track("SELECT colc FROM t") + if len(qt.queries) > qt.maxKeys { + t.Fatalf("map exceeded cap: %d", len(qt.queries)) + } + + // ...but the already-tracked "cola" still increments to threshold and fires + // exactly once — in-flight detection is never lost to the cap. + qt.Track("SELECT cola FROM t") + if reports != 1 { + t.Errorf("expected the tracked key to still reach threshold and report once, got %d", reports) + } +} diff --git a/middleware/options.go b/middleware/options.go new file mode 100644 index 0000000..8285bf6 --- /dev/null +++ b/middleware/options.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/reporter" +) + +type options struct { + slowThreshold time.Duration + reporter reporter.Reporter + analyzer *analyzer.Analyzer + parser analyzer.Parser + n1Threshold int + n1Window time.Duration + enableN1 bool + dedupWindow time.Duration + cacheSize int +} + +// Option configures the runtime guard. +type Option func(*options) + +// WithSlowQueryThreshold sets the duration above which a query is flagged as slow. +// Default is 200ms. +func WithSlowQueryThreshold(d time.Duration) Option { + return func(o *options) { + o.slowThreshold = d + } +} + +// WithReporter sets a custom reporter. Default is ConsoleReporter. +func WithReporter(r reporter.Reporter) Option { + return func(o *options) { + o.reporter = r + } +} + +// WithAnalyzer sets a custom analyzer. Default is analyzer.Default(). +func WithAnalyzer(a *analyzer.Analyzer) Option { + return func(o *options) { + o.analyzer = a + } +} + +// WithParser sets the SQL parser the analyzer uses. Default is the +// zero-dependency analyzer.FallbackParser. Pass a real dialect parser +// (e.g. from sqlguard/parsers/pgparser) for exact, structural analysis. +func WithParser(p analyzer.Parser) Option { + return func(o *options) { + o.parser = p + } +} + +// WithN1Detection enables N+1 query detection with the given threshold and window. +// When the same query pattern is executed threshold times within window, a warning is reported. +func WithN1Detection(threshold int, window time.Duration) Option { + return func(o *options) { + o.enableN1 = true + o.n1Threshold = threshold + o.n1Window = window + } +} + +// WithFindingDedup sets the window within which a repeated static finding — +// the same rule firing on the same canonical query shape — is reported at most +// once. This keeps a recurring query (or a prepared statement run in a loop) +// from flooding the log sink with the same warning on every execution. The +// default is one minute. Pass 0 to disable dedup and report every occurrence +// (the legacy behavior). Slow-query and N+1 findings have their own emission +// policy and are unaffected. +func WithFindingDedup(window time.Duration) Option { + return func(o *options) { + o.dedupWindow = window + } +} + +// WithAnalysisCacheSize sets the maximum number of distinct query strings whose +// analysis results are memoized, so a recurring query is parsed and rule-checked +// once instead of on every execution. The cache is an LRU keyed on the exact +// query string (correct even for the literal-sensitive rules). Default is 1024. +// Pass 0 to disable the cache and analyze every query. +func WithAnalysisCacheSize(n int) Option { + return func(o *options) { + o.cacheSize = n + } +} + +func defaultOptions() options { + return options{ + slowThreshold: 200 * time.Millisecond, + reporter: reporter.NewConsoleReporter(), + analyzer: analyzer.Default(), + dedupWindow: time.Minute, + cacheSize: 1024, + } +} diff --git a/parsers/mysqlparser/go.mod b/parsers/mysqlparser/go.mod new file mode 100644 index 0000000..6360f11 --- /dev/null +++ b/parsers/mysqlparser/go.mod @@ -0,0 +1,9 @@ +module github.com/KARTIKrocks/sqlguard/parsers/mysqlparser + +go 1.26 + +require github.com/KARTIKrocks/sqlguard v0.0.0 + +require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/parsers/mysqlparser/go.sum b/parsers/mysqlparser/go.sum new file mode 100644 index 0000000..6354a21 --- /dev/null +++ b/parsers/mysqlparser/go.sum @@ -0,0 +1,2 @@ +github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ= +github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY= diff --git a/parsers/mysqlparser/mysqlparser.go b/parsers/mysqlparser/mysqlparser.go new file mode 100644 index 0000000..2df5792 --- /dev/null +++ b/parsers/mysqlparser/mysqlparser.go @@ -0,0 +1,136 @@ +// Package mysqlparser is an optional sqlguard Parser backed by a real +// MySQL grammar (github.com/xwb1989/sqlparser — a pure-Go, no-cgo, +// lightweight Vitess-derived MySQL parser). +// +// It produces exact, structural answers for the false-positive-prone facts +// (statement kind, WHERE/LIMIT/ORDER BY/FROM presence, SELECT *, explicit +// INSERT columns) instead of regex guesses. SQL the grammar rejects — +// CTEs it doesn't support, dynamic fragments, dialect extensions — +// transparently degrades to sqlguard's zero-dependency FallbackParser, so +// analysis never breaks the caller's query path. +// +// Usage: +// +// sqlguard.Register("sqlguard-mysql", "mysql", middleware.WithParser(mysqlparser.New())) +// db, _ := sql.Open("sqlguard-mysql", dsn) +package mysqlparser + +import ( + "strconv" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/xwb1989/sqlparser" +) + +// Parser implements analyzer.Parser using a MySQL grammar. +type Parser struct { + fallback analyzer.Parser +} + +// New returns a MySQL-dialect Parser that falls back to the +// zero-dependency FallbackParser on parse failure. +func New() *Parser { + return &Parser{fallback: analyzer.NewFallbackParser()} +} + +var _ analyzer.Parser = (*Parser)(nil) + +// Parse implements analyzer.Parser. It never returns an error: unparseable +// SQL yields the fallback parser's best-effort Statement (Exact=false). +func (p *Parser) Parse(sql string) (*analyzer.Statement, error) { + // Baseline from the fallback. It detects the literal/text-level fields + // (leading-wildcard LIKE, non-sargable predicates, unsafe NOT NULL adds) + // that the AST loses after parsing, so those fields are kept; only + // structural fields are overwritten. + st, _ := p.fallback.Parse(sql) + if st == nil { + st = &analyzer.Statement{Raw: sql} + } + + ast, err := sqlparser.Parse(sql) + if err != nil || ast == nil { + return st, nil // keep best-effort fallback Statement + } + + st.Kind = analyzer.StmtOther + st.HasWhere = false + st.HasLimit = false + st.HasOrderBy = false + st.HasFrom = false + st.SelectStar = false + st.SelectDistinct = false + st.OffsetValue = 0 + st.InsertColumnsListed = false + + switch n := ast.(type) { + case *sqlparser.Select: + st.Kind = analyzer.StmtSelect + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.HasFrom = hasRealFrom(n.From) + st.SelectDistinct = n.Distinct != "" + st.OffsetValue = offsetValue(n.Limit) + for _, e := range n.SelectExprs { + if _, ok := e.(*sqlparser.StarExpr); ok { // '*' or 'table.*' + st.SelectStar = true + } + } + case *sqlparser.Delete: + st.Kind = analyzer.StmtDelete + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *sqlparser.Update: + st.Kind = analyzer.StmtUpdate + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *sqlparser.Insert: + st.Kind = analyzer.StmtInsert + st.InsertColumnsListed = len(n.Columns) > 0 + } + + st.Exact = true + return st, nil +} + +// offsetValue extracts a literal OFFSET as an int, or 0 when there is no limit +// clause, no offset, or a non-literal (parameterized) offset — matching the +// large-offset rule's contract that only statically-known offsets are flagged. +// Covers both "LIMIT count OFFSET n" and MySQL's "LIMIT n, count" (the parser +// puts n in Offset for both). +func offsetValue(lim *sqlparser.Limit) int { + if lim == nil || lim.Offset == nil { + return 0 + } + v, ok := lim.Offset.(*sqlparser.SQLVal) + if !ok || v.Type != sqlparser.IntVal { + return 0 + } + n, err := strconv.Atoi(string(v.Val)) + if err != nil || n < 0 { + return 0 + } + return n +} + +// hasRealFrom reports whether a FROM clause references a real table, not the +// implicit "dual" the parser injects for FROM-less selects like SELECT 1. +func hasRealFrom(from sqlparser.TableExprs) bool { + for _, te := range from { + ate, ok := te.(*sqlparser.AliasedTableExpr) + if !ok { + return true // join / subquery / etc. — a real source + } + if tn, ok := ate.Expr.(sqlparser.TableName); ok { + if tn.Name.String() == "dual" { + continue + } + } + return true + } + return false +} diff --git a/parsers/mysqlparser/mysqlparser_test.go b/parsers/mysqlparser/mysqlparser_test.go new file mode 100644 index 0000000..f9977c2 --- /dev/null +++ b/parsers/mysqlparser/mysqlparser_test.go @@ -0,0 +1,136 @@ +package mysqlparser + +import ( + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestParser_ExactStructuralFacts(t *testing.T) { + p := New() + tests := []struct { + name string + sql string + want analyzer.Statement + }{ + { + name: "delete without where", + sql: "DELETE FROM users", + want: analyzer.Statement{Kind: analyzer.StmtDelete, Exact: true}, + }, + { + name: "delete with where", + sql: "DELETE FROM users WHERE id = 1", + want: analyzer.Statement{Kind: analyzer.StmtDelete, HasWhere: true, Exact: true}, + }, + { + name: "update without where", + sql: "UPDATE users SET name = 'x'", + want: analyzer.Statement{Kind: analyzer.StmtUpdate, Exact: true}, + }, + { + name: "select star with from", + sql: "SELECT * FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "qualified star", + sql: "SELECT u.* FROM users u", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "count star is not select star", + sql: "SELECT COUNT(*) FROM users WHERE id = 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, Exact: true}, + }, + { + name: "select 1 has no real from", + sql: "SELECT 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "insert with columns", + sql: "INSERT INTO users (name) VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, InsertColumnsListed: true, Exact: true}, + }, + { + name: "insert without columns", + sql: "INSERT INTO users VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, Exact: true}, + }, + { + name: "order by without limit", + sql: "SELECT id FROM users ORDER BY name", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasOrderBy: true, Exact: true}, + }, + { + name: "select distinct", + sql: "SELECT DISTINCT name FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, SelectDistinct: true, Exact: true}, + }, + { + name: "count distinct is not select distinct", + sql: "SELECT COUNT(DISTINCT id) FROM users WHERE id = 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, Exact: true}, + }, + { + name: "literal offset (OFFSET form)", + sql: "SELECT id FROM users WHERE x = 1 ORDER BY id LIMIT 10 OFFSET 5000", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasOrderBy: true, HasLimit: true, OffsetValue: 5000, Exact: true}, + }, + { + name: "literal offset (LIMIT n, count form)", + sql: "SELECT id FROM users WHERE x = 1 LIMIT 5000, 10", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasLimit: true, OffsetValue: 5000, Exact: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st, err := p.Parse(tt.sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if st.Kind != tt.want.Kind || + st.HasWhere != tt.want.HasWhere || + st.HasLimit != tt.want.HasLimit || + st.HasOrderBy != tt.want.HasOrderBy || + st.HasFrom != tt.want.HasFrom || + st.SelectStar != tt.want.SelectStar || + st.SelectDistinct != tt.want.SelectDistinct || + st.OffsetValue != tt.want.OffsetValue || + st.InsertColumnsListed != tt.want.InsertColumnsListed || + st.Exact != tt.want.Exact { + t.Errorf("Parse(%q)\n got: %+v\nwant: %+v", tt.sql, *st, tt.want) + } + }) + } +} + +func TestParser_FallsBackOnUnparseable(t *testing.T) { + p := New() + // Postgres-style placeholders the MySQL grammar rejects must not error + // and must come back as a best-effort (non-exact) Statement. + st, err := p.Parse("SELECT * FROM t WHERE id = $1") + if err != nil { + t.Fatalf("fallback path must not error: %v", err) + } + if st == nil || st.Exact { + t.Errorf("expected non-nil, non-exact fallback statement, got %+v", st) + } +} + +func TestParser_IntegratesWithAnalyzer(t *testing.T) { + a := analyzer.Default().WithParser(New()) + + got := a.Analyze("UPDATE users SET active = 0 /* WHERE id = 1 */") + found := false + for _, r := range got { + if r.RuleName == "update-without-where" { + found = true + } + } + if !found { + t.Errorf("expected update-without-where (WHERE only in comment), got %+v", got) + } +} diff --git a/parsers/pgparser/go.mod b/parsers/pgparser/go.mod new file mode 100644 index 0000000..88cf8cb --- /dev/null +++ b/parsers/pgparser/go.mod @@ -0,0 +1,35 @@ +module github.com/KARTIKrocks/sqlguard/parsers/pgparser + +go 1.26 + +require github.com/KARTIKrocks/sqlguard v0.0.0 + +require ( + github.com/auxten/postgresql-parser v1.0.1 + github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 // indirect + github.com/cockroachdb/apd v1.1.1-0.20181017181144-bced77f817b4 // indirect + github.com/cockroachdb/errors v1.8.2 // indirect + github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f // indirect + github.com/cockroachdb/redact v1.0.8 // indirect + github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 // indirect + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/getsentry/raven-go v0.2.0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.4.3 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect + github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect + github.com/kr/pretty v0.2.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.9.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.6.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e // indirect + golang.org/x/text v0.3.4 // indirect + google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f // indirect + google.golang.org/grpc v1.33.1 // indirect + google.golang.org/protobuf v1.25.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/parsers/pgparser/go.sum b/parsers/pgparser/go.sum new file mode 100644 index 0000000..c8df32f --- /dev/null +++ b/parsers/pgparser/go.sum @@ -0,0 +1,347 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/CloudyKit/fastprinter v0.0.0-20170127035650-74b38d55f37a/go.mod h1:EFZQ978U7x8IRnstaskI3IysnWY5Ao3QgZUKOXlsAdw= +github.com/CloudyKit/jet v2.1.3-0.20180809161101-62edd43e4f88+incompatible/go.mod h1:HPYO+50pSWkPoj9Q/eq0aRGByCL6ScRlUmiEX5Zgm+w= +github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY= +github.com/Joker/jade v1.0.1-0.20190614124447-d475f43051e7/go.mod h1:6E6s8o2AE4KhCrqr6GRJjdC/gNfTdxkIXvuGZZda2VM= +github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= +github.com/auxten/postgresql-parser v1.0.1 h1:x+qiEHAe2cH55Kly64dWh4tGvUKEQwMmJgma7a1kbj4= +github.com/auxten/postgresql-parser v1.0.1/go.mod h1:Nf27dtv8EU1C+xNkoLD3zEwfgJfDDVi8Zl86gznxPvI= +github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 h1:uH66TXeswKn5PW5zdZ39xEwfS9an067BirqA+P4QaLI= +github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cockroachdb/apd v1.1.1-0.20181017181144-bced77f817b4 h1:XWEdfNxDkZI3DXXlpo0hZJ1xdaH/f3CKuZpk93pS/Y0= +github.com/cockroachdb/apd v1.1.1-0.20181017181144-bced77f817b4/go.mod h1:mdGz2CnkJrefFtlLevmE7JpL2zB9tKofya/6w7wWzNA= +github.com/cockroachdb/datadriven v1.0.0/go.mod h1:5Ib8Meh+jk1RlHIXej6Pzevx/NLlNvQB9pmSBZErGA4= +github.com/cockroachdb/errors v1.6.1/go.mod h1:tm6FTP5G81vwJ5lC0SizQo374JNCOPrHyXGitRJoDqM= +github.com/cockroachdb/errors v1.8.2 h1:rnnWK9Nn5kEMOGz9531HuDx/FOleL4NVH20VsDexVC8= +github.com/cockroachdb/errors v1.8.2/go.mod h1:qGwQn6JmZ+oMjuLwjWzUNqblqk0xl4CVV3SQbGwK7Ac= +github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f h1:o/kfcElHqOiXqcou5a3rIlMc7oJbMQkeLk0VQJ7zgqY= +github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f/go.mod h1:i/u985jwjWRlyHXQbwatDASoW0RMlZ/3i9yJHE2xLkI= +github.com/cockroachdb/redact v1.0.8 h1:8QG/764wK+vmEYoOlfobpe12EQcS81ukx/a4hdVMxNw= +github.com/cockroachdb/redact v1.0.8/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 h1:IKgmqgMQlVJIZj19CdocBeSfSaiCbEBZGKODaixqtHM= +github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2/go.mod h1:8BT+cPK6xvFOcRlk0R8eg+OTkcqI6baNH4xAkpiYVvQ= +github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= +github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw= +github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/flosch/pongo2 v0.0.0-20190707114632-bbf5a6c351f4/go.mod h1:T9YF2M40nIgbVgp3rreNmTged+9HrbNTIQf1PsaIiTA= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= +github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= +github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/gogo/googleapis v0.0.0-20180223154316-0cd9801be74a/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/gogo/status v1.1.0/go.mod h1:BFv9nrluPLmrS0EmGVvLaPNmRosr9KapBYd5/hpY1WM= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/hydrogen18/memlistener v0.0.0-20141126152155-54553eb933fb/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= +github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/iris-contrib/blackfriday v2.0.0+incompatible/go.mod h1:UzZ2bDEoaSGPbkg6SAB4att1aAwTmVIx/5gCVqeyUdI= +github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/+fafWORmlnuysV2EMP8MW+qe0= +github.com/iris-contrib/i18n v0.0.0-20171121225848-987a633949d0/go.mod h1:pMCz62A0xJL6I+umB2YTlFRwWXaDFA0jy+5HzGiJjqI= +github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q= +github.com/juju/loggo v0.0.0-20180524022052-584905176618/go.mod h1:vgyd7OREkbtVEN/8IXZe5Ooef3LQePvuBm9UWj6ZL8U= +github.com/juju/testing v0.0.0-20180920084828-472a3e8b2073/go.mod h1:63prj8cnj0tU0S9OHjGJn+b1h0ZghCndfnbQolrYTwA= +github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= +github.com/kataras/golog v0.0.9/go.mod h1:12HJgwBIZFNGL0EJnMRhmvGA0PQGx8VFwrZtM4CqbAk= +github.com/kataras/iris/v12 v12.0.1/go.mod h1:udK4vLQKkdDqMGJJVd/msuMtN6hpYJhg/lSzuxjhO+U= +github.com/kataras/neffos v0.0.10/go.mod h1:ZYmJC07hQPW67eKuzlfY7SO3bC0mw83A3j6im82hfqw= +github.com/kataras/pio v0.0.0-20190103105442-ea782b38602d/go.mod h1:NV88laa9UiiDuX9AhMbDPkGYSPugBOV6yTZB1l2K9Z0= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/labstack/echo/v4 v4.1.11/go.mod h1:i541M3Fj6f76NZtHSj7TXnyM8n2gaodfvfxNnFqi74g= +github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= +github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= +github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= +github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= +github.com/mediocregopher/mediocre-go-lib v0.0.0-20181029021733-cb65787f37ed/go.mod h1:dSsfyI2zABAdhcbvkXqgxOxrCsbYeHCPgrZkku60dSg= +github.com/mediocregopher/radix/v3 v3.3.0/go.mod h1:EmfVyvspXz1uZEyPBMyGK+kjWiKQGvsUt6O3Pj+LDCQ= +github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= +github.com/nats-io/nats.go v1.8.1/go.mod h1:BrFz9vVn0fU3AcH9Vn4Kd7W0NpJ651tD5omQ3M8LwxM= +github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7TDb/4= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= +github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= +github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= +github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= +github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= +github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= +github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= +github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= +github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e h1:AyodaIpKjppX+cBfTASF2E1US3H2JFBj920Ot3rtDjs= +golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f h1:Yv4xsIx7HZOoyUGSJ2ksDyWE2qIBXROsZKt2ny3hCGM= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/grpc v1.33.1 h1:DGeFlSan2f+WEtCERJ4J9GJWk15TxUi8QGagfI87Xyc= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= +gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y= +gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/parsers/pgparser/pgparser.go b/parsers/pgparser/pgparser.go new file mode 100644 index 0000000..e0b7123 --- /dev/null +++ b/parsers/pgparser/pgparser.go @@ -0,0 +1,149 @@ +// Package pgparser is an optional sqlguard Parser backed by a real +// PostgreSQL grammar (github.com/auxten/postgresql-parser, pure Go, no cgo). +// +// It produces exact, structural answers for the false-positive-prone facts +// (statement kind, WHERE/LIMIT/ORDER BY/FROM presence, SELECT *, explicit +// INSERT columns) instead of regex guesses. SQL the grammar rejects — +// dynamic fragments, dialect extensions, driver placeholders it can't +// handle — transparently degrades to sqlguard's zero-dependency +// FallbackParser, so analysis never breaks the caller's query path. +// +// Usage: +// +// sqlguard.Register("sqlguard-pg", "pgx", middleware.WithParser(pgparser.New())) +// db, _ := sql.Open("sqlguard-pg", dsn) +package pgparser + +import ( + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/auxten/postgresql-parser/pkg/sql/parser" + "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" +) + +// Parser implements analyzer.Parser using a PostgreSQL grammar. +type Parser struct { + fallback analyzer.Parser +} + +// New returns a Postgres-dialect Parser that falls back to the +// zero-dependency FallbackParser on parse failure. +func New() *Parser { + return &Parser{fallback: analyzer.NewFallbackParser()} +} + +var _ analyzer.Parser = (*Parser)(nil) + +// Parse implements analyzer.Parser. It never returns an error: unparseable +// SQL yields the fallback parser's best-effort Statement (Exact=false). +func (p *Parser) Parse(sql string) (*analyzer.Statement, error) { + // The fallback result is the baseline. It already detects the literal/text- + // level fields (leading-wildcard LIKE, non-sargable predicates, unsafe + // NOT NULL adds) that the AST loses after parsing, so we keep those fields + // and overwrite only the structural ones. + st, _ := p.fallback.Parse(sql) + if st == nil { + st = &analyzer.Statement{Raw: sql} + } + + stmts, err := parser.Parse(sql) + if err != nil || len(stmts) == 0 || stmts[0].AST == nil { + return st, nil // keep best-effort fallback Statement + } + + st.Kind = analyzer.StmtOther + st.HasWhere = false + st.HasLimit = false + st.HasOrderBy = false + st.HasFrom = false + st.SelectStar = false + st.SelectDistinct = false + st.OffsetValue = 0 + st.InsertColumnsListed = false + + switch n := stmts[0].AST.(type) { + case *tree.Select: + st.Kind = analyzer.StmtSelect + st.HasOrderBy = len(n.OrderBy) > 0 + st.HasLimit = n.Limit != nil + st.OffsetValue = offsetValue(n.Limit) + fillSelectBody(st, n.Select) + case *tree.SelectClause: + st.Kind = analyzer.StmtSelect + fillSelectClause(st, n) + case *tree.Delete: + st.Kind = analyzer.StmtDelete + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *tree.Update: + st.Kind = analyzer.StmtUpdate + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *tree.Insert: + st.Kind = analyzer.StmtInsert + st.InsertColumnsListed = len(n.Columns) > 0 + } + + st.Exact = true + return st, nil +} + +// fillSelectBody unwraps the inner SelectStatement of a *tree.Select. +func fillSelectBody(st *analyzer.Statement, sel tree.SelectStatement) { + switch c := sel.(type) { + case *tree.SelectClause: + fillSelectClause(st, c) + case *tree.ParenSelect: + if c.Select != nil { + st.HasOrderBy = st.HasOrderBy || len(c.Select.OrderBy) > 0 + st.HasLimit = st.HasLimit || c.Select.Limit != nil + if v := offsetValue(c.Select.Limit); v > st.OffsetValue { + st.OffsetValue = v + } + fillSelectBody(st, c.Select.Select) + } + } + // UnionClause / ValuesClause: leave structural defaults; the rules that + // matter for those forms don't trigger on set operations. +} + +// offsetValue extracts a literal OFFSET as an int, or 0 when there is no limit +// clause, no offset, or a non-literal (parameterized) offset — matching the +// large-offset rule's contract that only statically-known offsets are flagged. +func offsetValue(lim *tree.Limit) int { + if lim == nil { + return 0 + } + nv, ok := lim.Offset.(*tree.NumVal) + if !ok { + return 0 + } + n, err := nv.AsInt64() + if err != nil || n < 0 { + return 0 + } + return int(n) +} + +func fillSelectClause(st *analyzer.Statement, c *tree.SelectClause) { + st.HasWhere = c.Where != nil + st.HasFrom = len(c.From.Tables) > 0 + // DISTINCT and DISTINCT ON both set the select-level distinct flag; an + // aggregate-level DISTINCT (count(DISTINCT x)) lives in the expr, not here. + st.SelectDistinct = c.Distinct || len(c.DistinctOn) > 0 + for _, e := range c.Exprs { + switch ex := e.Expr.(type) { + case tree.UnqualifiedStar, *tree.UnqualifiedStar: + st.SelectStar = true // SELECT * + case *tree.AllColumnsSelector: + st.SelectStar = true // SELECT t.* (resolved form) + case *tree.UnresolvedName: + if ex.Star { // SELECT t.* (unresolved form) + st.SelectStar = true + } + } + } +} diff --git a/parsers/pgparser/pgparser_test.go b/parsers/pgparser/pgparser_test.go new file mode 100644 index 0000000..b2e51b2 --- /dev/null +++ b/parsers/pgparser/pgparser_test.go @@ -0,0 +1,137 @@ +package pgparser + +import ( + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestParser_ExactStructuralFacts(t *testing.T) { + p := New() + tests := []struct { + name string + sql string + want analyzer.Statement + }{ + { + name: "cte-wrapped delete with where", + sql: "WITH r AS (SELECT id FROM o WHERE ts > now()) DELETE FROM o WHERE id IN (SELECT id FROM r)", + want: analyzer.Statement{Kind: analyzer.StmtDelete, HasWhere: true, Exact: true}, + }, + { + name: "delete without where", + sql: "DELETE FROM users", + want: analyzer.Statement{Kind: analyzer.StmtDelete, HasWhere: false, Exact: true}, + }, + { + name: "select star with from", + sql: "SELECT * FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "qualified star", + sql: "SELECT u.* FROM users u", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "count star is not select star", + sql: "SELECT count(*) FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: false, HasFrom: true, Exact: true}, + }, + { + name: "select no from", + sql: "SELECT 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "insert with columns", + sql: "INSERT INTO users (name) VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, InsertColumnsListed: true, Exact: true}, + }, + { + name: "insert without columns", + sql: "INSERT INTO users VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, InsertColumnsListed: false, Exact: true}, + }, + { + name: "order by without limit", + sql: "SELECT id FROM users ORDER BY name", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasOrderBy: true, Exact: true}, + }, + { + name: "select distinct", + sql: "SELECT DISTINCT name FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, SelectDistinct: true, Exact: true}, + }, + { + name: "distinct on", + sql: "SELECT DISTINCT ON (dept) dept, name FROM emp", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, SelectDistinct: true, Exact: true}, + }, + { + name: "count distinct is not select distinct", + sql: "SELECT count(DISTINCT id) FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, Exact: true}, + }, + { + name: "literal offset", + sql: "SELECT id FROM users WHERE x = 1 ORDER BY id LIMIT 10 OFFSET 5000", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasOrderBy: true, HasLimit: true, OffsetValue: 5000, Exact: true}, + }, + { + name: "parameterized offset is zero", + sql: "SELECT id FROM users WHERE x = 1 LIMIT 10 OFFSET $1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasLimit: true, Exact: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st, err := p.Parse(tt.sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if st.Kind != tt.want.Kind || + st.HasWhere != tt.want.HasWhere || + st.HasLimit != tt.want.HasLimit || + st.HasOrderBy != tt.want.HasOrderBy || + st.HasFrom != tt.want.HasFrom || + st.SelectStar != tt.want.SelectStar || + st.SelectDistinct != tt.want.SelectDistinct || + st.OffsetValue != tt.want.OffsetValue || + st.InsertColumnsListed != tt.want.InsertColumnsListed || + st.Exact != tt.want.Exact { + t.Errorf("Parse(%q)\n got: %+v\nwant: %+v", tt.sql, *st, tt.want) + } + }) + } +} + +func TestParser_FallsBackOnUnparseable(t *testing.T) { + p := New() + // Driver placeholders the PG grammar won't accept as-is still must not + // error, and must come back as a best-effort (non-exact) Statement. + st, err := p.Parse("SELECT * FROM t WHERE id = ?") + if err != nil { + t.Fatalf("fallback path must not error: %v", err) + } + if st == nil { + t.Fatal("nil statement") + } + if st.Exact { + t.Error("expected Exact=false when grammar rejected the SQL") + } +} + +func TestParser_IntegratesWithAnalyzer(t *testing.T) { + a := analyzer.Default().WithParser(New()) + + got := a.Analyze("DELETE FROM users -- WHERE id = 1") + if len(got) != 1 || got[0].RuleName != "delete-without-where" { + t.Errorf("expected delete-without-where (WHERE only in comment), got %+v", got) + } + + if r := a.Analyze("SELECT id FROM users WHERE id = 1 LIMIT 1"); len(r) != 0 { + t.Errorf("expected no findings for safe query, got %+v", r) + } +} diff --git a/reporter/console.go b/reporter/console.go new file mode 100644 index 0000000..fb92fbb --- /dev/null +++ b/reporter/console.go @@ -0,0 +1,57 @@ +package reporter + +import ( + "fmt" + "io" + "os" + "sync" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +const ( + colorReset = "\033[0m" + colorRed = "\033[31m" + colorYellow = "\033[33m" + colorCyan = "\033[36m" +) + +// ConsoleReporter prints analysis results to the terminal with color. +type ConsoleReporter struct { + Out io.Writer + mu sync.Mutex +} + +// NewConsoleReporter creates a ConsoleReporter that writes to stderr. +func NewConsoleReporter() *ConsoleReporter { + return &ConsoleReporter{Out: os.Stderr} +} + +// Report writes each result to the configured output, colored by severity. +func (c *ConsoleReporter) Report(results []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, r := range results { + color := colorCyan + switch r.Severity { + case analyzer.SeverityWarning: + color = colorYellow + case analyzer.SeverityCritical: + color = colorRed + } + + _, _ = fmt.Fprintf(c.Out, "\n%s[SQLGUARD %s]%s %s\n", color, r.Severity, colorReset, r.RuleName) + + if r.File != "" { + _, _ = fmt.Fprintf(c.Out, " File: %s:%d\n", r.File, r.Line) + } + + _, _ = fmt.Fprintf(c.Out, " Query: %s\n", r.Query) + _, _ = fmt.Fprintf(c.Out, " Issue: %s\n", r.Message) + + if r.Suggestion != "" { + _, _ = fmt.Fprintf(c.Out, " Fix: %s\n", r.Suggestion) + } + } +} diff --git a/reporter/console_test.go b/reporter/console_test.go new file mode 100644 index 0000000..e4adca2 --- /dev/null +++ b/reporter/console_test.go @@ -0,0 +1,86 @@ +package reporter + +import ( + "bytes" + "strings" + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestConsoleReporter_Report(t *testing.T) { + var buf bytes.Buffer + rep := &ConsoleReporter{Out: &buf} + + results := []analyzer.Result{ + { + RuleName: "select-star", + Severity: analyzer.SeverityWarning, + Query: "SELECT * FROM users", + Message: "SELECT * detected.", + Suggestion: "Select only needed columns.", + }, + } + + rep.Report(results) + output := buf.String() + + if !strings.Contains(output, "SQLGUARD WARNING") { + t.Error("expected WARNING label in output") + } + if !strings.Contains(output, "select-star") { + t.Error("expected rule name in output") + } + if !strings.Contains(output, "SELECT * FROM users") { + t.Error("expected query in output") + } + if !strings.Contains(output, "Select only needed columns.") { + t.Error("expected suggestion in output") + } +} + +func TestConsoleReporter_CriticalSeverity(t *testing.T) { + var buf bytes.Buffer + rep := &ConsoleReporter{Out: &buf} + + rep.Report([]analyzer.Result{{ + RuleName: "delete-without-where", + Severity: analyzer.SeverityCritical, + Query: "DELETE FROM users", + Message: "DELETE without WHERE.", + }}) + + if !strings.Contains(buf.String(), "SQLGUARD CRITICAL") { + t.Error("expected CRITICAL label in output") + } +} + +func TestConsoleReporter_WithFileInfo(t *testing.T) { + var buf bytes.Buffer + rep := &ConsoleReporter{Out: &buf} + + rep.Report([]analyzer.Result{{ + RuleName: "select-star", + Severity: analyzer.SeverityWarning, + Query: "SELECT * FROM users", + Message: "SELECT * detected.", + File: "repo/user.go", + Line: 42, + }}) + + output := buf.String() + if !strings.Contains(output, "repo/user.go:42") { + t.Error("expected file:line in output") + } +} + +func TestConsoleReporter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + rep := &ConsoleReporter{Out: &buf} + + rep.Report(nil) + + if buf.Len() != 0 { + t.Error("expected no output for empty results") + } +} diff --git a/reporter/json.go b/reporter/json.go new file mode 100644 index 0000000..cabeded --- /dev/null +++ b/reporter/json.go @@ -0,0 +1,56 @@ +package reporter + +import ( + "encoding/json" + "io" + "os" + "sync" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// JSONReporter outputs analysis results as JSON. +type JSONReporter struct { + Out io.Writer + mu sync.Mutex +} + +// NewJSONReporter creates a JSONReporter that writes to stderr. +func NewJSONReporter() *JSONReporter { + return &JSONReporter{Out: os.Stderr} +} + +type jsonResult struct { + Rule string `json:"rule"` + Severity string `json:"severity"` + Query string `json:"query"` + Fingerprint string `json:"fingerprint,omitempty"` + Message string `json:"message"` + Suggestion string `json:"suggestion,omitempty"` + File string `json:"file,omitempty"` + Line int `json:"line,omitempty"` +} + +// Report writes the results to the configured output as a JSON array. +func (j *JSONReporter) Report(results []analyzer.Result) { + j.mu.Lock() + defer j.mu.Unlock() + + out := make([]jsonResult, len(results)) + for i, r := range results { + out[i] = jsonResult{ + Rule: r.RuleName, + Severity: r.Severity.String(), + Query: r.Query, + Fingerprint: r.Fingerprint, + Message: r.Message, + Suggestion: r.Suggestion, + File: r.File, + Line: r.Line, + } + } + + enc := json.NewEncoder(j.Out) + enc.SetIndent("", " ") + _ = enc.Encode(out) +} diff --git a/reporter/json_test.go b/reporter/json_test.go new file mode 100644 index 0000000..05e31dd --- /dev/null +++ b/reporter/json_test.go @@ -0,0 +1,77 @@ +package reporter + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestJSONReporter_Report(t *testing.T) { + var buf bytes.Buffer + rep := &JSONReporter{Out: &buf} + + results := []analyzer.Result{ + { + RuleName: "select-star", + Severity: analyzer.SeverityWarning, + Query: "SELECT * FROM users", + Message: "SELECT * detected.", + Suggestion: "Select only needed columns.", + File: "user.go", + Line: 10, + }, + { + RuleName: "delete-without-where", + Severity: analyzer.SeverityCritical, + Query: "DELETE FROM users", + Message: "DELETE without WHERE.", + }, + } + + rep.Report(results) + + var parsed []map[string]any + if err := json.Unmarshal(buf.Bytes(), &parsed); err != nil { + t.Fatalf("invalid JSON output: %v\nGot: %s", err, buf.String()) + } + + if len(parsed) != 2 { + t.Fatalf("expected 2 results, got %d", len(parsed)) + } + + if parsed[0]["rule"] != "select-star" { + t.Errorf("expected rule 'select-star', got %v", parsed[0]["rule"]) + } + if parsed[0]["severity"] != "WARNING" { + t.Errorf("expected severity 'WARNING', got %v", parsed[0]["severity"]) + } + if parsed[0]["file"] != "user.go" { + t.Errorf("expected file 'user.go', got %v", parsed[0]["file"]) + } + + if parsed[1]["severity"] != "CRITICAL" { + t.Errorf("expected severity 'CRITICAL', got %v", parsed[1]["severity"]) + } + // file should be omitted (empty) + if _, ok := parsed[1]["file"]; ok && parsed[1]["file"] != "" { + t.Errorf("expected file to be omitted, got %v", parsed[1]["file"]) + } +} + +func TestJSONReporter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + rep := &JSONReporter{Out: &buf} + + rep.Report([]analyzer.Result{}) + + var parsed []map[string]any + if err := json.Unmarshal(buf.Bytes(), &parsed); err != nil { + t.Fatalf("invalid JSON output: %v", err) + } + + if len(parsed) != 0 { + t.Errorf("expected empty array, got %d items", len(parsed)) + } +} diff --git a/reporter/reporter.go b/reporter/reporter.go new file mode 100644 index 0000000..fec8431 --- /dev/null +++ b/reporter/reporter.go @@ -0,0 +1,8 @@ +package reporter + +import "github.com/KARTIKrocks/sqlguard/analyzer" + +// Reporter defines the interface for reporting analysis results. +type Reporter interface { + Report(results []analyzer.Result) +} diff --git a/sqlguard.go b/sqlguard.go new file mode 100644 index 0000000..912ac95 --- /dev/null +++ b/sqlguard.go @@ -0,0 +1,40 @@ +// Package sqlguard is a production-safe SQL query analyzer for Go applications. +// +// It detects slow queries, dangerous SQL patterns, and performance issues +// both at runtime (via a database/sql driver wrapper) and statically +// (via the CLI). +// +// The runtime guard wraps at the driver.Driver layer, so it returns a real +// *sql.DB and analyzes every query — including those issued by ORMs and +// query builders — without a method list to keep in sync. +// +// Register a wrapped driver by name: +// +// sqlguard.Register("sqlguard-pg", "pgx") +// db, _ := sql.Open("sqlguard-pg", dsn) +// db.Query("SELECT * FROM users") // logs warning about SELECT * +// +// Or wrap an existing driver.Connector directly: +// +// db := sqlguard.OpenDB(connector) +package sqlguard + +import ( + "database/sql" + "database/sql/driver" + + "github.com/KARTIKrocks/sqlguard/middleware" +) + +// Register wraps the database/sql driver registered under baseDriver and +// registers the analyzed result under name. Afterwards sql.Open(name, dsn) +// yields a *sql.DB whose every query is analyzed. +func Register(name, baseDriver string, opts ...middleware.Option) error { + return middleware.Register(name, baseDriver, opts...) +} + +// OpenDB wraps a driver.Connector and returns an analyzed *sql.DB. Use this +// when you already hold a connector (e.g. pgx's stdlib.GetConnector). +func OpenDB(c driver.Connector, opts ...middleware.Option) *sql.DB { + return middleware.OpenDB(c, opts...) +} From 909d07e3a95382d363e594354ef8670eb5dcffa7 Mon Sep 17 00:00:00 2001 From: KARTIKrocks <105914814+KARTIKrocks@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:13:07 +0530 Subject: [PATCH 2/9] Update .github/workflows/ci.yml Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d6ba9b2..610233e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + permissions: contents: read From b38719e6683bfd5920ce6ec775c57f81a52fbf41 Mon Sep 17 00:00:00 2001 From: KARTIKrocks <105914814+KARTIKrocks@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:16:46 +0530 Subject: [PATCH 3/9] Update .github/workflows/codeql.yml Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .github/workflows/codeql.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 90b96ee..1189bf6 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -14,6 +14,7 @@ concurrency: cancel-in-progress: ${{ github.event_name != 'schedule' }} permissions: + # CodeQL requires security-events: write to upload SARIF results security-events: write contents: read From b5f7fa9eb6374d207ac23175810e36e9cd581697 Mon Sep 17 00:00:00 2001 From: kartik Date: Mon, 8 Jun 2026 16:18:54 +0530 Subject: [PATCH 4/9] fix: disable credential persistence for checkout actions in CI workflows --- .github/workflows/ci.yml | 8 ++++++++ .github/workflows/codeql.yml | 2 ++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 610233e..90fdb30 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,8 @@ jobs: go-version: ["1.26"] steps: - uses: actions/checkout@v6 + with: + persist-credentials: false - uses: actions/setup-go@v6 with: @@ -57,6 +59,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 + with: + persist-credentials: false - uses: actions/setup-go@v6 with: @@ -71,6 +75,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 + with: + persist-credentials: false - uses: actions/setup-go@v6 with: @@ -83,6 +89,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 + with: + persist-credentials: false - uses: actions/setup-go@v6 with: diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 1189bf6..7d5a6da 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -25,6 +25,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v6 + with: + persist-credentials: false - name: Setup Go uses: actions/setup-go@v6 From 5c7ab7a21a200e7e045de5fca1d345e5373b7a1b Mon Sep 17 00:00:00 2001 From: KARTIKrocks <105914814+KARTIKrocks@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:24:25 +0530 Subject: [PATCH 5/9] Update cmd/sqlguard/scan.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- cmd/sqlguard/scan.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cmd/sqlguard/scan.go b/cmd/sqlguard/scan.go index 23f94ef..f1fd6e9 100644 --- a/cmd/sqlguard/scan.go +++ b/cmd/sqlguard/scan.go @@ -115,7 +115,10 @@ func newReporter(format string) (reporter.Reporter, error) { // dependency-free go/parser walk that still handles inline string literals, so // a broken or module-less tree is never silently skipped. func scanDir(dir string, a *analyzer.Analyzer, exclude func(string) bool) ([]analyzer.Result, int, error) { - absDir, _ := filepath.Abs(dir) + absDir, err := filepath.Abs(dir) + if err != nil { + return nil, 0, fmt.Errorf("cannot resolve absolute path: %w", err) + } if results, n, ok := scanViaPackages(absDir, a, exclude); ok { return results, n, nil From 95b90edd53c4c310920002618c2c7d56d3d2e422 Mon Sep 17 00:00:00 2001 From: KARTIKrocks <105914814+KARTIKrocks@users.noreply.github.com> Date: Mon, 8 Jun 2026 16:25:16 +0530 Subject: [PATCH 6/9] Update reporter/json.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- reporter/json.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/reporter/json.go b/reporter/json.go index cabeded..81712b6 100644 --- a/reporter/json.go +++ b/reporter/json.go @@ -52,5 +52,8 @@ func (j *JSONReporter) Report(results []analyzer.Result) { enc := json.NewEncoder(j.Out) enc.SetIndent("", " ") - _ = enc.Encode(out) + if err := enc.Encode(out); err != nil { + // Fallback: log encoding failure since Reporter interface can't return error + fmt.Fprintf(os.Stderr, "sqlguard: failed to encode JSON report: %v\n", err) + } } From f3131075f32151e7aca5af9d4a7aa2892d7ec759 Mon Sep 17 00:00:00 2001 From: kartik Date: Mon, 8 Jun 2026 16:28:55 +0530 Subject: [PATCH 7/9] fix: correct comment typo in skipSingleQuoted function --- analyzer/redact.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/analyzer/redact.go b/analyzer/redact.go index ea3de97..41b97df 100644 --- a/analyzer/redact.go +++ b/analyzer/redact.go @@ -123,7 +123,7 @@ func scanNumber(s string, i int) int { } // skipSingleQuoted returns the index just past the single-quoted string -// literal starting at s[i] == '\”, honoring ” doubled-quote escapes. +// literal starting at s[i] == '\”, honoring '' doubled-quote escapes. func skipSingleQuoted(s string, i int) int { i++ // opening quote for i < len(s) { From c51e319e648b3a497dc5cbbb7f77e00834663c23 Mon Sep 17 00:00:00 2001 From: kartik Date: Mon, 8 Jun 2026 16:29:36 +0530 Subject: [PATCH 8/9] fix: add missing import for fmt in JSONReporter --- reporter/json.go | 1 + 1 file changed, 1 insertion(+) diff --git a/reporter/json.go b/reporter/json.go index 81712b6..6f89099 100644 --- a/reporter/json.go +++ b/reporter/json.go @@ -2,6 +2,7 @@ package reporter import ( "encoding/json" + "fmt" "io" "os" "sync" From d203993c0a494bf475e59099ccf89294f10afb9d Mon Sep 17 00:00:00 2001 From: kartik Date: Mon, 8 Jun 2026 16:49:59 +0530 Subject: [PATCH 9/9] feat: enhance reporting capabilities and improve documentation --- .coderabbit.yaml | 6 +++++- CHANGELOG.md | 4 ++-- README.md | 2 ++ analyzer/fallback.go | 5 +++-- analyzer/redact.go | 3 ++- config/middleware_test.go | 2 +- middleware/driver_test.go | 2 +- parsers/mysqlparser/mysqlparser.go | 5 ++++- parsers/mysqlparser/mysqlparser_test.go | 15 +++++++++++++++ reporter/console.go | 21 ++++++++++++++------- reporter/console_test.go | 8 ++++---- reporter/json.go | 13 ++++++++++--- reporter/json_test.go | 4 ++-- 13 files changed, 65 insertions(+), 25 deletions(-) diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 3221c5b..81a41ed 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -121,7 +121,11 @@ reviews: comment/string-aware multi-statement rejection and SELECT/WITH-only (DML behind WithAllowDML) policy. EXPLAIN takes no bind params, so concatenation is by design — the defense is validate() + the rolled-back - read-only tx; do not "fix" it with parameterization. + read-only tx; do not "fix" it with parameterization. Deliberate + carve-out: explain keeps Result.Query (and the inner analyzer.Result + .Query of its findings) RAW — the user typed the query on their own CLI, + it never reaches a log/telemetry sink, and Fingerprint is still set. Do + NOT flag explain findings for not redacting Query; that is intended. - path: "**/*_test.go" instructions: >- diff --git a/CHANGELOG.md b/CHANGELOG.md index e602e58..31cfd92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,13 +19,13 @@ Initial public release. (`Register` / `OpenDB`), so any query — including those issued by ORMs and query builders — is analyzed and you get back a real `*sql.DB`. Zero third-party dependencies in the core. -- **Analyzer with 19 detection rules** across static, runtime, and EXPLAIN +- **Analyzer with 21 detection rules** across static, runtime, and EXPLAIN surfaces: `select-star`, `leading-wildcard`, `non-sargable-predicate`, `add-not-null-without-default`, `implicit-join`, `cartesian-join`, `in-list-too-large`, `large-offset`, `select-distinct`, `delete-without-where`, `update-without-where`, `insert-without-columns`, `select-without-limit`, `orderby-without-limit`, `n-plus-one`, `slow-query`, `seq-scan`, - `full-table-scan`, `high-cost`. + `full-table-scan`, `high-cost`, `no-index-used`, `filesort`. - **Redaction by default**: every `Result.Query` is redacted (literals → `?`) before it leaves the process, and every `Result.Fingerprint` is a PII-free, low-cardinality query identity safe as a metric label. Opt out with diff --git a/README.md b/README.md index 7886379..d4f2b00 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ go install github.com/KARTIKrocks/sqlguard/cmd/sqlguard@latest | `seq-scan` | WARNING | Sequential scan detected via EXPLAIN (postgres) | | `full-table-scan` | WARNING | Full table scan detected via EXPLAIN (mysql) | | `high-cost` | WARNING | High cost operation in query plan | +| `no-index-used` | WARNING | No index used for a table access detected via EXPLAIN (mysql) | +| `filesort` | INFO | `Using filesort` in the query plan — `ORDER BY` not covered by an index (mysql) | ## Configuration diff --git a/analyzer/fallback.go b/analyzer/fallback.go index dc9dd31..53e43df 100644 --- a/analyzer/fallback.go +++ b/analyzer/fallback.go @@ -502,8 +502,9 @@ func stripComments(s string) string { } // copyStringLiteral writes the string literal that begins at s[i] (a quote -// byte) verbatim, honoring ” / "" doubled-quote escapes, and returns the -// index just past the literal. +// byte) verbatim, treating a doubled quote (two of the same quote byte in a +// row) as an escaped quote rather than the terminator, and returns the index +// just past the literal. func copyStringLiteral(b *strings.Builder, s string, i int) int { q := s[i] b.WriteByte(q) diff --git a/analyzer/redact.go b/analyzer/redact.go index 41b97df..14f81b3 100644 --- a/analyzer/redact.go +++ b/analyzer/redact.go @@ -123,7 +123,8 @@ func scanNumber(s string, i int) int { } // skipSingleQuoted returns the index just past the single-quoted string -// literal starting at s[i] == '\”, honoring '' doubled-quote escapes. +// literal that opens at s[i], treating a doubled single-quote (two in a row) +// as an escaped quote rather than the terminator. func skipSingleQuoted(s string, i int) int { i++ // opening quote for i < len(s) { diff --git a/config/middleware_test.go b/config/middleware_test.go index 6caf642..cfe4b9f 100644 --- a/config/middleware_test.go +++ b/config/middleware_test.go @@ -23,7 +23,7 @@ func TestMiddlewareOptionsAppliesProfile(t *testing.T) { } var buf strings.Builder - opts = append(opts, middleware.WithReporter(&reporter.ConsoleReporter{Out: &buf})) + opts = append(opts, middleware.WithReporter(reporter.NewConsoleReporterTo(&buf))) name := "sqlguard-cfg-test" if err := sqlguard.Register(name, "sqlite3", opts...); err != nil { diff --git a/middleware/driver_test.go b/middleware/driver_test.go index d3e4328..a33ed1b 100644 --- a/middleware/driver_test.go +++ b/middleware/driver_test.go @@ -46,7 +46,7 @@ func newGuardedDB(t *testing.T, opts ...Option) *sql.DB { func guardedWithBuffer(t *testing.T, extra ...Option) (*sql.DB, *bytes.Buffer) { t.Helper() var buf bytes.Buffer - opts := append([]Option{WithReporter(&reporter.ConsoleReporter{Out: &buf})}, extra...) + opts := append([]Option{WithReporter(reporter.NewConsoleReporterTo(&buf))}, extra...) return newGuardedDB(t, opts...), &buf } diff --git a/parsers/mysqlparser/mysqlparser.go b/parsers/mysqlparser/mysqlparser.go index 2df5792..7943340 100644 --- a/parsers/mysqlparser/mysqlparser.go +++ b/parsers/mysqlparser/mysqlparser.go @@ -17,6 +17,7 @@ package mysqlparser import ( "strconv" + "strings" "github.com/KARTIKrocks/sqlguard/analyzer" "github.com/xwb1989/sqlparser" @@ -126,7 +127,9 @@ func hasRealFrom(from sqlparser.TableExprs) bool { return true // join / subquery / etc. — a real source } if tn, ok := ate.Expr.(sqlparser.TableName); ok { - if tn.Name.String() == "dual" { + // Case-insensitive: sqlparser preserves the casing of backticked + // identifiers, so `DUAL` would otherwise read as a real table. + if strings.EqualFold(tn.Name.String(), "dual") { continue } } diff --git a/parsers/mysqlparser/mysqlparser_test.go b/parsers/mysqlparser/mysqlparser_test.go index f9977c2..abb1271 100644 --- a/parsers/mysqlparser/mysqlparser_test.go +++ b/parsers/mysqlparser/mysqlparser_test.go @@ -48,6 +48,21 @@ func TestParser_ExactStructuralFacts(t *testing.T) { sql: "SELECT 1", want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, }, + { + name: "explicit dual is not a real from", + sql: "SELECT 1 FROM dual", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "uppercase DUAL is not a real from", + sql: "SELECT 1 FROM DUAL", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "backticked DUAL is not a real from", + sql: "SELECT 1 FROM `DUAL`", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, { name: "insert with columns", sql: "INSERT INTO users (name) VALUES ('a')", diff --git a/reporter/console.go b/reporter/console.go index fb92fbb..286aa13 100644 --- a/reporter/console.go +++ b/reporter/console.go @@ -17,14 +17,21 @@ const ( ) // ConsoleReporter prints analysis results to the terminal with color. +// The output writer is fixed at construction so Report is safe for concurrent +// use (the writer cannot be swapped out from under an in-flight Report). type ConsoleReporter struct { - Out io.Writer + out io.Writer mu sync.Mutex } // NewConsoleReporter creates a ConsoleReporter that writes to stderr. func NewConsoleReporter() *ConsoleReporter { - return &ConsoleReporter{Out: os.Stderr} + return NewConsoleReporterTo(os.Stderr) +} + +// NewConsoleReporterTo creates a ConsoleReporter that writes to w. +func NewConsoleReporterTo(w io.Writer) *ConsoleReporter { + return &ConsoleReporter{out: w} } // Report writes each result to the configured output, colored by severity. @@ -41,17 +48,17 @@ func (c *ConsoleReporter) Report(results []analyzer.Result) { color = colorRed } - _, _ = fmt.Fprintf(c.Out, "\n%s[SQLGUARD %s]%s %s\n", color, r.Severity, colorReset, r.RuleName) + _, _ = fmt.Fprintf(c.out, "\n%s[SQLGUARD %s]%s %s\n", color, r.Severity, colorReset, r.RuleName) if r.File != "" { - _, _ = fmt.Fprintf(c.Out, " File: %s:%d\n", r.File, r.Line) + _, _ = fmt.Fprintf(c.out, " File: %s:%d\n", r.File, r.Line) } - _, _ = fmt.Fprintf(c.Out, " Query: %s\n", r.Query) - _, _ = fmt.Fprintf(c.Out, " Issue: %s\n", r.Message) + _, _ = fmt.Fprintf(c.out, " Query: %s\n", r.Query) + _, _ = fmt.Fprintf(c.out, " Issue: %s\n", r.Message) if r.Suggestion != "" { - _, _ = fmt.Fprintf(c.Out, " Fix: %s\n", r.Suggestion) + _, _ = fmt.Fprintf(c.out, " Fix: %s\n", r.Suggestion) } } } diff --git a/reporter/console_test.go b/reporter/console_test.go index e4adca2..a1d0f56 100644 --- a/reporter/console_test.go +++ b/reporter/console_test.go @@ -10,7 +10,7 @@ import ( func TestConsoleReporter_Report(t *testing.T) { var buf bytes.Buffer - rep := &ConsoleReporter{Out: &buf} + rep := NewConsoleReporterTo(&buf) results := []analyzer.Result{ { @@ -41,7 +41,7 @@ func TestConsoleReporter_Report(t *testing.T) { func TestConsoleReporter_CriticalSeverity(t *testing.T) { var buf bytes.Buffer - rep := &ConsoleReporter{Out: &buf} + rep := NewConsoleReporterTo(&buf) rep.Report([]analyzer.Result{{ RuleName: "delete-without-where", @@ -57,7 +57,7 @@ func TestConsoleReporter_CriticalSeverity(t *testing.T) { func TestConsoleReporter_WithFileInfo(t *testing.T) { var buf bytes.Buffer - rep := &ConsoleReporter{Out: &buf} + rep := NewConsoleReporterTo(&buf) rep.Report([]analyzer.Result{{ RuleName: "select-star", @@ -76,7 +76,7 @@ func TestConsoleReporter_WithFileInfo(t *testing.T) { func TestConsoleReporter_EmptyResults(t *testing.T) { var buf bytes.Buffer - rep := &ConsoleReporter{Out: &buf} + rep := NewConsoleReporterTo(&buf) rep.Report(nil) diff --git a/reporter/json.go b/reporter/json.go index 6f89099..2bdd04f 100644 --- a/reporter/json.go +++ b/reporter/json.go @@ -11,14 +11,21 @@ import ( ) // JSONReporter outputs analysis results as JSON. +// The output writer is fixed at construction so Report is safe for concurrent +// use (the writer cannot be swapped out from under an in-flight Report). type JSONReporter struct { - Out io.Writer + out io.Writer mu sync.Mutex } // NewJSONReporter creates a JSONReporter that writes to stderr. func NewJSONReporter() *JSONReporter { - return &JSONReporter{Out: os.Stderr} + return NewJSONReporterTo(os.Stderr) +} + +// NewJSONReporterTo creates a JSONReporter that writes to w. +func NewJSONReporterTo(w io.Writer) *JSONReporter { + return &JSONReporter{out: w} } type jsonResult struct { @@ -51,7 +58,7 @@ func (j *JSONReporter) Report(results []analyzer.Result) { } } - enc := json.NewEncoder(j.Out) + enc := json.NewEncoder(j.out) enc.SetIndent("", " ") if err := enc.Encode(out); err != nil { // Fallback: log encoding failure since Reporter interface can't return error diff --git a/reporter/json_test.go b/reporter/json_test.go index 05e31dd..b7c30a0 100644 --- a/reporter/json_test.go +++ b/reporter/json_test.go @@ -10,7 +10,7 @@ import ( func TestJSONReporter_Report(t *testing.T) { var buf bytes.Buffer - rep := &JSONReporter{Out: &buf} + rep := NewJSONReporterTo(&buf) results := []analyzer.Result{ { @@ -62,7 +62,7 @@ func TestJSONReporter_Report(t *testing.T) { func TestJSONReporter_EmptyResults(t *testing.T) { var buf bytes.Buffer - rep := &JSONReporter{Out: &buf} + rep := NewJSONReporterTo(&buf) rep.Report([]analyzer.Result{})