Skip to content

Commit

Permalink
Consolidate fetching of MySQL server info
Browse files Browse the repository at this point in the history
Signed-off-by: Tim Vaillancourt <[email protected]>
  • Loading branch information
timvaillancourt committed Jan 15, 2024
1 parent 59fd18d commit 06079a9
Show file tree
Hide file tree
Showing 14 changed files with 261 additions and 130 deletions.
14 changes: 3 additions & 11 deletions go/base/context.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand Down Expand Up @@ -163,18 +163,15 @@ type MigrationContext struct {

Hostname string
AssumeMasterHostname string
ApplierTimeZone string
TableEngine string
RowsEstimate int64
RowsDeltaEstimate int64
UsedRowsEstimateMethod RowsEstimateMethod
HasSuperPrivilege bool
OriginalBinlogFormat string
OriginalBinlogRowImage string
InspectorConnectionConfig *mysql.ConnectionConfig
InspectorMySQLVersion string
InspectorServerInfo *mysql.ServerInfo
ApplierConnectionConfig *mysql.ConnectionConfig
ApplierMySQLVersion string
ApplierServerInfo *mysql.ServerInfo
StartTime time.Time
RowCopyStartTime time.Time
RowCopyEndTime time.Time
Expand Down Expand Up @@ -359,11 +356,6 @@ func (this *MigrationContext) GetVoluntaryLockName() string {
return fmt.Sprintf("%s.%s.lock", this.DatabaseName, this.OriginalTableName)
}

// RequiresBinlogFormatChange is `true` when the original binlog format isn't `ROW`
func (this *MigrationContext) RequiresBinlogFormatChange() bool {
return this.OriginalBinlogFormat != "ROW"
}

// GetApplierHostname is a safe access method to the applier hostname
func (this *MigrationContext) GetApplierHostname() string {
if this.ApplierConnectionConfig == nil {
Expand Down
41 changes: 13 additions & 28 deletions go/base/utils.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand All @@ -12,8 +12,6 @@ import (
"strings"
"time"

gosql "database/sql"

"github.com/github/gh-ost/go/mysql"
)

Expand Down Expand Up @@ -61,35 +59,22 @@ func StringContainsAll(s string, substrings ...string) bool {
return nonEmptyStringsFound
}

func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) (string, error) {
versionQuery := `select @@global.version`
var port, extraPort int
var version string
if err := db.QueryRow(versionQuery).Scan(&version); err != nil {
return "", err
}
extraPortQuery := `select @@global.extra_port`
if err := db.QueryRow(extraPortQuery).Scan(&extraPort); err != nil { //nolint:staticcheck
// swallow this error. not all servers support extra_port
}
// ValidateConnection confirms the database server info matches the provided connection config.
func ValidateConnection(serverInfo *mysql.ServerInfo, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) error {
// AliyunRDS set users port to "NULL", replace it by gh-ost param
// GCP set users port to "NULL", replace it by gh-ost param
// Azure MySQL set users port to a different value by design, replace it by gh-ost para
// Azure MySQL set users port to a different value by design, replace it by gh-ost param
if migrationContext.AliyunRDS || migrationContext.GoogleCloudPlatform || migrationContext.AzureMySQL {
port = connectionConfig.Key.Port
} else {
portQuery := `select @@global.port`
if err := db.QueryRow(portQuery).Scan(&port); err != nil {
return "", err
}
serverInfo.Port.Int64 = connectionConfig.Key.Port
serverInfo.Port.Valid = connectionConfig.Key.Port > 0
}

if connectionConfig.Key.Port == port || (extraPort > 0 && connectionConfig.Key.Port == extraPort) {
migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
return version, nil
} else if extraPort == 0 {
return "", fmt.Errorf("Unexpected database port reported: %+v", port)
} else {
return "", fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", port, extraPort)
if !serverInfo.Port.Valid && !serverInfo.ExtraPort.Valid {
return fmt.Errorf("Unexpected database port reported: %+v", serverInfo.Port.Int64)
} else if connectionConfig.Key.Port != serverInfo.Port.Int64 && connectionConfig.Key.Port != serverInfo.ExtraPort.Int64 {
return fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", serverInfo.Port.Int64, serverInfo.ExtraPort.Int64)
}

migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
return nil
}
85 changes: 84 additions & 1 deletion go/base/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
/*
Copyright 2016 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

package base

import (
gosql "database/sql"
"testing"

"github.com/github/gh-ost/go/mysql"
"github.com/openark/golib/log"
test "github.com/openark/golib/tests"
)
Expand All @@ -16,6 +18,10 @@ func init() {
log.SetLevel(log.ERROR)
}

func newMysqlPort(port int64) gosql.NullInt64 {
return gosql.NullInt64{Int64: port, Valid: port > 0}
}

func TestStringContainsAll(t *testing.T) {
s := `insert,delete,update`

Expand All @@ -27,3 +33,80 @@ func TestStringContainsAll(t *testing.T) {
test.S(t).ExpectTrue(StringContainsAll(s, "insert", ""))
test.S(t).ExpectTrue(StringContainsAll(s, "insert", "update", "delete"))
}

func TestValidateConnection(t *testing.T) {
connectionConfig := &mysql.ConnectionConfig{
Key: mysql.InstanceKey{
Hostname: t.Name(),
Port: mysql.DefaultInstancePort,
},
}

// check valid port matching connectionConfig validates
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
Port: newMysqlPort(mysql.DefaultInstancePort),
ExtraPort: newMysqlPort(mysql.DefaultInstancePort + 1),
}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check NULL port validates when AliyunRDS=true
{
migrationContext := &MigrationContext{
Log: NewDefaultLogger(),
AliyunRDS: true,
}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check NULL port validates when AzureMySQL=true
{
migrationContext := &MigrationContext{
Log: NewDefaultLogger(),
AzureMySQL: true,
}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check NULL port validates when GoogleCloudPlatform=true
{
migrationContext := &MigrationContext{
Log: NewDefaultLogger(),
GoogleCloudPlatform: true,
}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check extra_port validates when port=NULL
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
ExtraPort: newMysqlPort(mysql.DefaultInstancePort),
}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check extra_port validates when port does not match but extra_port does
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
Port: newMysqlPort(12345),
ExtraPort: newMysqlPort(mysql.DefaultInstancePort),
}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check validation fails when valid port does not match connectionConfig
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
Port: newMysqlPort(9999),
}
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check validation fails when port and extra_port are invalid
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
}
4 changes: 2 additions & 2 deletions go/cmd/gh-ost/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand Down Expand Up @@ -49,7 +49,7 @@ func main() {
migrationContext := base.NewMigrationContext()
flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unable to determine the master")
flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.Int64Var(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.Float64Var(&migrationContext.InspectorConnectionConfig.Timeout, "mysql-timeout", 0.0, "Connect, read and write timeout for MySQL")
flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user")
flag.StringVar(&migrationContext.CliPassword, "password", "", "MySQL password")
Expand Down
39 changes: 14 additions & 25 deletions go/logic/applier.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand Down Expand Up @@ -71,25 +71,24 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {
}
}

func (this *Applier) ServerInfo() *mysql.ServerInfo {
return this.migrationContext.ApplierServerInfo
}

func (this *Applier) InitDBConnections() (err error) {
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
return err
}
if this.migrationContext.ApplierServerInfo, err = mysql.GetServerInfo(this.db); err != nil {
return err
}
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil {
return err
}
this.singletonDB.SetMaxOpenConns(1)
version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name)
if err != nil {
return err
}
if _, err := base.ValidateConnection(this.singletonDB, this.connectionConfig, this.migrationContext, this.name); err != nil {
return err
}
this.migrationContext.ApplierMySQLVersion = version
if err := this.validateAndReadTimeZone(); err != nil {
if err = base.ValidateConnection(this.ServerInfo(), this.connectionConfig, this.migrationContext, this.name); err != nil {
return err
}
if !this.migrationContext.AliyunRDS && !this.migrationContext.GoogleCloudPlatform && !this.migrationContext.AzureMySQL {
Expand All @@ -102,18 +101,8 @@ func (this *Applier) InitDBConnections() (err error) {
if err := this.readTableColumns(); err != nil {
return err
}
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion)
return nil
}

// validateAndReadTimeZone potentially reads server time-zone
func (this *Applier) validateAndReadTimeZone() error {
query := `select /* gh-ost */ @@global.time_zone`
if err := this.db.QueryRow(query).Scan(&this.migrationContext.ApplierTimeZone); err != nil {
return err
}

this.migrationContext.Log.Infof("will use time_zone='%s' on applier", this.migrationContext.ApplierTimeZone)
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v (%+v)", this.connectionConfig.ImpliedKey,
this.ServerInfo().Version, this.ServerInfo().VersionComment)
return nil
}

Expand Down Expand Up @@ -238,7 +227,7 @@ func (this *Applier) CreateGhostTable() error {
}
defer tx.Rollback()

sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
Expand Down Expand Up @@ -279,7 +268,7 @@ func (this *Applier) AlterGhost() error {
}
defer tx.Rollback()

sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
Expand Down Expand Up @@ -640,7 +629,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
}
defer tx.Rollback()

sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
Expand Down
Loading

0 comments on commit 06079a9

Please sign in to comment.