Skip to content

Commit

Permalink
refact pkg/database: extract function rollbackOnError(); dry error me…
Browse files Browse the repository at this point in the history
…ssages
  • Loading branch information
mmetc committed Oct 10, 2024
1 parent 1a2cc12 commit 520a518
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ jobs:
- name: Run "make generate" and check for changes
run: |
set -e
make generate 2>/dev/null
make generate
if [[ $(git status --porcelain) ]]; then
echo "Error: Uncommitted changes found after running 'make generate'. Please commit all generated code."
git diff
Expand Down
4 changes: 2 additions & 2 deletions pkg/apiserver/alerts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func TestAlertListFilters(t *testing.T) {

w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String())
assert.Equal(t, `{"message":"invalid ip address 'gruueq'"}`, w.Body.String())

// test range (ok)

Expand All @@ -258,7 +258,7 @@ func TestAlertListFilters(t *testing.T) {

w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String())
assert.Equal(t, `{"message":"invalid ip address 'ratata'"}`, w.Body.String())

// test since (ok)

Expand Down
12 changes: 0 additions & 12 deletions pkg/apiserver/controllers/v1/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,6 @@ func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) {
case errors.Is(err, database.HashError):
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
case errors.Is(err, database.InsertFail):
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
case errors.Is(err, database.QueryFail):
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
case errors.Is(err, database.ParseTimeFail):
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
case errors.Is(err, database.ParseDurationFail):
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
default:
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/database/alertfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e
case "ip", "range":
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0])
if err != nil {
return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err)
return nil, err
}
case "since", "created_before", "until":
if err := handleTimeFilters(param, value[0], &predicates); err != nil {
Expand Down
43 changes: 13 additions & 30 deletions pkg/database/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ const (
maxLockRetries = 10 // how many times to retry a bulk operation when sqlite3.ErrBusy is encountered
)

func rollbackOnError(tx *ent.Tx, err error, msg string) error {
if rbErr := tx.Rollback(); rbErr != nil {
log.Errorf("rollback error: %v", rbErr)
}

Check warning on line 37 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L34-L37

Added lines #L34 - L37 were not covered by tests

return fmt.Errorf("%s: %w", msg, err)

Check warning on line 39 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L39

Added line #L39 was not covered by tests
}

// CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it
// if alert already exists, it checks it associated decisions already exists
// if some associated decisions are missing (ie. previous insert ended up in error) it inserts them
Expand Down Expand Up @@ -284,12 +292,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models

duration, err := time.ParseDuration(*decisionItem.Duration)
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
log.Errorf("rollback error: %s", rollbackErr)
}

return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err)
return 0,0,0, rollbackOnError(txClient, err, "parsing decision duration")

Check warning on line 295 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L295

Added line #L295 was not covered by tests
}

if decisionItem.Scope == nil {
Expand All @@ -301,12 +304,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models
if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" {
sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value)
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
log.Errorf("rollback error: %s", rollbackErr)
}

return 0, 0, 0, errors.Wrapf(InvalidIPOrRange, "invalid addr/range %s : %s", *decisionItem.Value, err)
return 0, 0, 0, rollbackOnError(txClient, err, "invalid ip addr/range")

Check warning on line 307 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L307

Added line #L307 was not covered by tests
}
}

Expand Down Expand Up @@ -348,12 +346,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models
decision.ValueIn(deleteChunk...),
)).Exec(ctx)
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
log.Errorf("rollback error: %s", rollbackErr)
}

return 0, 0, 0, fmt.Errorf("while deleting older community blocklist decisions: %w", err)
return 0, 0, 0, rollbackOnError(txClient, err, "deleting older community blocklist decisions")

Check warning on line 349 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L349

Added line #L349 was not covered by tests
}

deleted += deletedDecisions
Expand All @@ -364,12 +357,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models
for _, builderChunk := range builderChunks {
insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx)
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
log.Errorf("rollback error: %s", rollbackErr)
}

return 0, 0, 0, fmt.Errorf("while bulk creating decisions: %w", err)
return 0, 0, 0, rollbackOnError(txClient, err, "bulk creating decisions")

Check warning on line 360 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L360

Added line #L360 was not covered by tests
}

inserted += len(insertedDecisions)
Expand All @@ -379,12 +367,7 @@ func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models

err = txClient.Commit()
if err != nil {
rollbackErr := txClient.Rollback()
if rollbackErr != nil {
log.Errorf("rollback error: %s", rollbackErr)
}

return 0, 0, 0, fmt.Errorf("error committing transaction: %w", err)
return 0, 0, 0, rollbackOnError(txClient, err, "error committing transaction")

Check warning on line 370 in pkg/database/alerts.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/alerts.go#L370

Added line #L370 was not covered by tests
}

return alertRef.ID, inserted, deleted, nil
Expand Down
1 change: 0 additions & 1 deletion pkg/database/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ var (
ParseTimeFail = errors.New("unable to parse time")
ParseDurationFail = errors.New("unable to parse duration")
MarshalFail = errors.New("unable to serialize")
UnmarshalFail = errors.New("unable to parse")
BulkError = errors.New("unable to insert bulk")
ParseType = errors.New("unable to parse type")
InvalidIPOrRange = errors.New("invalid ip address / range")
Expand Down
7 changes: 3 additions & 4 deletions pkg/types/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package types

import (
"encoding/binary"
"errors"
"fmt"
"math"
"net"
Expand Down Expand Up @@ -38,20 +37,20 @@ func Addr2Ints(anyIP string) (int, int64, int64, int64, int64, error) {
if strings.Contains(anyIP, "/") {
_, net, err := net.ParseCIDR(anyIP)
if err != nil {
return -1, 0, 0, 0, 0, fmt.Errorf("while parsing range %s: %w", anyIP, err)
return -1, 0, 0, 0, 0, fmt.Errorf("invalid ip range '%s': %w", anyIP, err)
}

return Range2Ints(*net)
}

ip := net.ParseIP(anyIP)
if ip == nil {
return -1, 0, 0, 0, 0, errors.New("invalid address")
return -1, 0, 0, 0, 0, fmt.Errorf("invalid ip address '%s'", anyIP)
}

sz, start, end, err := IP2Ints(ip)
if err != nil {
return -1, 0, 0, 0, 0, fmt.Errorf("while parsing ip %s: %w", anyIP, err)
return -1, 0, 0, 0, 0, fmt.Errorf("invalid ip address '%s': %w", anyIP, err)

Check warning on line 53 in pkg/types/ip.go

View check run for this annotation

Codecov / codecov/patch

pkg/types/ip.go#L53

Added line #L53 was not covered by tests
}

return sz, start, end, start, end, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func TestAdd2Int(t *testing.T) {
},
{
in_addr: "xxx2",
exp_error: "invalid address",
exp_error: "invalid ip address 'xxx2'",
},
}

Expand Down
4 changes: 2 additions & 2 deletions test/bats/90_decisions.bats
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ teardown() {
EOT
assert_stderr --partial 'Parsing values'
assert_stderr --partial 'Imported 1 decisions'
assert_file_contains "$LOGFILE" "invalid addr/range 'whatever': invalid address"
assert_file_contains "$LOGFILE" "invalid addr/range 'whatever': invalid ip address 'whatever'"

rune -0 cscli decisions list -a -o json
assert_json '[]'
Expand All @@ -189,7 +189,7 @@ teardown() {
EOT
assert_stderr --partial 'Parsing values'
assert_stderr --partial 'Imported 3 decisions'
assert_file_contains "$LOGFILE" "invalid addr/range 'bad-apple': invalid address"
assert_file_contains "$LOGFILE" "invalid addr/range 'bad-apple': invalid ip address 'bad-apple'"

rune -0 cscli decisions list -a -o json
rune -0 jq -r '.[0].decisions | length' <(output)
Expand Down

0 comments on commit 520a518

Please sign in to comment.