Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support set custom unmarshal row handle #4312

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/stores/sqlx/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ type rowsScanner interface {
Scan(v ...any) error
}

// RowsScanner alias of rowsScanner
type RowsScanner = rowsScanner

// UnmarshalRowHandler defines the method to unmarshal row.
type UnmarshalRowHandler func(v any, rows RowsScanner, strict bool) error

// UnmarshalRowsHandler defines the method to unmarshal rows.
// alias of UnmarshalRowHandler
type UnmarshalRowsHandler = UnmarshalRowHandler

func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
rt := mapping.Deref(v.Type())
size := rt.NumField()
Expand Down
50 changes: 41 additions & 9 deletions core/stores/sqlx/sqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ type (
// Because CORBA doesn't support PREPARE, so we need to combine the
// query arguments into one string and do underlying query without arguments
commonSqlConn struct {
connProv connProvider
onError func(context.Context, error)
beginTx beginnable
brk breaker.Breaker
accept breaker.Acceptable
connProv connProvider
onError func(context.Context, error)
beginTx beginnable
brk breaker.Breaker
accept breaker.Acceptable
unmarshalRowHandler UnmarshalRowHandler
unmarshalRowsHandler UnmarshalRowsHandler
}

connProvider func() (*sql.DB, error)
Expand Down Expand Up @@ -163,10 +165,12 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
}

stmt = statement{
query: query,
stmt: st,
brk: db.brk,
accept: db.acceptable,
query: query,
stmt: st,
brk: db.brk,
accept: db.acceptable,
unmarshalRowHandler: db.unmarshalRowHandler,
unmarshalRowsHandler: db.unmarshalRowsHandler,
}
return nil
}, db.acceptable)
Expand All @@ -189,6 +193,9 @@ func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
}()

return db.queryRows(ctx, func(rows *sql.Rows) error {
if db.unmarshalRowHandler != nil {
return db.unmarshalRowHandler(v, rows, true)
}
return unmarshalRow(v, rows, true)
}, q, args...)
}
Expand All @@ -205,6 +212,9 @@ func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
}()

return db.queryRows(ctx, func(rows *sql.Rows) error {
if db.unmarshalRowHandler != nil {
return db.unmarshalRowHandler(v, rows, false)
}
return unmarshalRow(v, rows, false)
}, q, args...)
}
Expand All @@ -221,6 +231,9 @@ func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
}()

return db.queryRows(ctx, func(rows *sql.Rows) error {
if db.unmarshalRowsHandler != nil {
return db.unmarshalRowsHandler(v, rows, true)
}
return unmarshalRows(v, rows, true)
}, q, args...)
}
Expand All @@ -237,6 +250,9 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
}()

return db.queryRows(ctx, func(rows *sql.Rows) error {
if db.unmarshalRowsHandler != nil {
return db.unmarshalRowsHandler(v, rows, false)
}
return unmarshalRows(v, rows, false)
}, q, args...)
}
Expand Down Expand Up @@ -325,3 +341,19 @@ func WithAcceptable(acceptable func(err error) bool) SqlOption {
}
}
}

// WithMysqlUnmarshalRowHandler returns a SqlOption that setting the UnmarshalRowHandler.
// handler is the func to unmarshal a row.
func WithMysqlUnmarshalRowHandler(handler UnmarshalRowHandler) SqlOption {
return func(conn *commonSqlConn) {
conn.unmarshalRowHandler = handler
}
}

// WithMysqlUnmarshalRowsHandler returns a SqlOption that setting the UnmarshalRowsHandler.
// handler is the func to unmarshal rows.
func WithMysqlUnmarshalRowsHandler(handler UnmarshalRowsHandler) SqlOption {
return func(conn *commonSqlConn) {
conn.unmarshalRowsHandler = handler
}
}
16 changes: 16 additions & 0 deletions core/stores/sqlx/sqlconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,19 @@ func buildConn() (mock sqlmock.Sqlmock, err error) {
})
return
}

func TestWithMysqlUnmarshalRowHandler(t *testing.T) {
var handler UnmarshalRowHandler = func(v any, scanner rowsScanner, _ bool) error {
return nil
}
conn := NewMysql(mockedDatasource, WithMysqlUnmarshalRowHandler(handler))
assert.NotNil(t, conn.(*commonSqlConn).unmarshalRowHandler)
}

func TestWithMysqlUnmarshalRowsHandler(t *testing.T) {
var handler UnmarshalRowsHandler = func(v any, scanner rowsScanner, _ bool) error {
return nil
}
conn := NewMysql(mockedDatasource, WithMysqlUnmarshalRowsHandler(handler))
assert.NotNil(t, conn.(*commonSqlConn).unmarshalRowsHandler)
}
22 changes: 15 additions & 7 deletions core/stores/sqlx/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ type (
}

statement struct {
query string
stmt *sql.Stmt
brk breaker.Breaker
accept breaker.Acceptable
query string
stmt *sql.Stmt
brk breaker.Breaker
accept breaker.Acceptable
unmarshalRowHandler UnmarshalRowHandler
unmarshalRowsHandler UnmarshalRowsHandler
}

stmtConn interface {
Expand Down Expand Up @@ -89,6 +91,9 @@ func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err err
}()

return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
if s.unmarshalRowHandler != nil {
return s.unmarshalRowHandler(v, scanner, true)
}
return unmarshalRow(v, scanner, true)
}, v, args...)
}
Expand All @@ -104,7 +109,10 @@ func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (
}()

return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRow(v, scanner, false)
if s.unmarshalRowsHandler != nil {
return s.unmarshalRowHandler(v, scanner, false)
}
return unmarshalRows(v, scanner, false)
}, v, args...)
}

Expand All @@ -119,7 +127,7 @@ func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err er
}()

return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRows(v, scanner, true)
return s.unmarshalRowsHandler(v, scanner, true)
}, v, args...)
}

Expand All @@ -134,7 +142,7 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any)
}()

return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
return unmarshalRows(v, scanner, false)
return s.unmarshalRowsHandler(v, scanner, false)
}, v, args...)
}

Expand Down
69 changes: 59 additions & 10 deletions core/stores/sqlx/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

type (
beginnable func(*sql.DB) (trans, error)
beginnable func(*sql.DB, ...TxOption) (trans, error)

trans interface {
Session
Expand All @@ -23,7 +23,11 @@ type (

txSession struct {
*sql.Tx
unmarshalRowHandler UnmarshalRowHandler
unmarshalRowsHandler UnmarshalRowsHandler
}

TxOption func(*txSession)
)

func (s txConn) RawDB() (*sql.DB, error) {
Expand All @@ -38,10 +42,30 @@ func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session)
return errCantNestTx
}

// WithTxUnmarshalRowHandler sets the UnmarshalRowHandler for the txSession.
// It's used to customize the unmarshal behavior for QueryRow and QueryRowPartial.
func WithTxUnmarshalRowHandler(handler UnmarshalRowHandler) TxOption {
return func(ts *txSession) {
ts.unmarshalRowHandler = handler
}
}

// WithTxUnmarshalRowsHandler sets the UnmarshalRowsHandler for the txSession.
// It's used to customize the unmarshal behavior for QueryRows and QueryRowsPartial.
func WithTxUnmarshalRowsHandler(handler UnmarshalRowsHandler) TxOption {
return func(ts *txSession) {
ts.unmarshalRowsHandler = handler
}
}

// NewSessionFromTx returns a Session with the given sql.Tx.
// Use it with caution, it's provided for other ORM to interact with.
func NewSessionFromTx(tx *sql.Tx) Session {
return txSession{Tx: tx}
func NewSessionFromTx(tx *sql.Tx, opts ...TxOption) Session {
ts := txSession{Tx: tx}
for _, opt := range opts {
opt(&ts)
}
return ts
}

func (t txSession) Exec(q string, args ...any) (sql.Result, error) {
Expand Down Expand Up @@ -75,9 +99,11 @@ func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSe
}

return statement{
query: q,
stmt: stmt,
brk: breaker.NopBreaker(),
query: q,
stmt: stmt,
brk: breaker.NopBreaker(),
unmarshalRowHandler: t.unmarshalRowHandler,
unmarshalRowsHandler: t.unmarshalRowsHandler,
}, nil
}

Expand All @@ -92,6 +118,9 @@ func (t txSession) QueryRowCtx(ctx context.Context, v any, q string, args ...any
}()

return query(ctx, t.Tx, func(rows *sql.Rows) error {
if t.unmarshalRowHandler != nil {
return t.unmarshalRowHandler(v, rows, true)
}
return unmarshalRow(v, rows, true)
}, q, args...)
}
Expand All @@ -108,6 +137,9 @@ func (t txSession) QueryRowPartialCtx(ctx context.Context, v any, q string,
}()

return query(ctx, t.Tx, func(rows *sql.Rows) error {
if t.unmarshalRowHandler != nil {
return t.unmarshalRowHandler(v, rows, false)
}
return unmarshalRow(v, rows, false)
}, q, args...)
}
Expand All @@ -123,6 +155,9 @@ func (t txSession) QueryRowsCtx(ctx context.Context, v any, q string, args ...an
}()

return query(ctx, t.Tx, func(rows *sql.Rows) error {
if t.unmarshalRowsHandler != nil {
return t.unmarshalRowsHandler(v, rows, true)
}
return unmarshalRows(v, rows, true)
}, q, args...)
}
Expand All @@ -139,19 +174,28 @@ func (t txSession) QueryRowsPartialCtx(ctx context.Context, v any, q string,
}()

return query(ctx, t.Tx, func(rows *sql.Rows) error {
if t.unmarshalRowsHandler != nil {
return t.unmarshalRowsHandler(v, rows, false)
}
return unmarshalRows(v, rows, false)
}, q, args...)
}

func begin(db *sql.DB) (trans, error) {
func begin(db *sql.DB, opts ...TxOption) (trans, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}

return txSession{
ts := txSession{
Tx: tx,
}, nil
}

for _, opt := range opts {
opt(&ts)
}

return ts, nil
}

func transact(ctx context.Context, db *commonSqlConn, b beginnable,
Expand All @@ -162,7 +206,12 @@ func transact(ctx context.Context, db *commonSqlConn, b beginnable,
return err
}

return transactOnConn(ctx, conn, b, fn)
return transactOnConn(ctx, conn, func(d *sql.DB, _ ...TxOption) (trans, error) {
return b(d,
WithTxUnmarshalRowHandler(db.unmarshalRowHandler),
WithTxUnmarshalRowsHandler(db.unmarshalRowsHandler),
)
}, fn)
}

func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
Expand Down
18 changes: 17 additions & 1 deletion core/stores/sqlx/tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (mt *mockTx) Rollback() error {
}

func beginMock(mock *mockTx) beginnable {
return func(*sql.DB) (trans, error) {
return func(*sql.DB, ...TxOption) (trans, error) {
return mock, nil
}
}
Expand Down Expand Up @@ -310,3 +310,19 @@ func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
f(conn, mock)
})
}

func TestWithTxUnmarshalRowHandler(t *testing.T) {
var handler UnmarshalRowHandler = func(v any, scanner rowsScanner, _ bool) error {
return nil
}
tx := NewSessionFromTx(nil, WithTxUnmarshalRowHandler(handler))
assert.NotNil(t, tx.(txSession).unmarshalRowHandler)
}

func TestWithTxUnmarshalsRowHandler(t *testing.T) {
var handler UnmarshalRowsHandler = func(v any, scanner rowsScanner, _ bool) error {
return nil
}
tx := NewSessionFromTx(nil, WithTxUnmarshalRowsHandler(handler))
assert.NotNil(t, tx.(txSession).unmarshalRowsHandler)
}