diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index b3f9d9c26c..b9bf19c04d 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -55,6 +55,17 @@ type Driver interface { Open(name string) (Conn, error) } +// DriverContext enhances the Driver interface by returning a Connector +// rather then a single Conn. +// It separates out the name parsing step from actually connecting to the +// database. It also gives dialers access to the context by using the +// Connector. +type DriverContext interface { + // OpenConnector must parse the name in the same format that Driver.Open + // parses the name parameter. + OpenConnector(name string) (Connector, error) +} + // Connector is an optional interface that drivers can implement. // It allows drivers to provide more flexible methods to open // database connections without requiring the use of a DSN string. diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 31e22a7a74..e795412de0 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -71,6 +71,16 @@ func (c *fakeConnector) Driver() driver.Driver { return fdriver } +type fakeDriverCtx struct { + fakeDriver +} + +var _ driver.DriverContext = &fakeDriverCtx{} + +func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { + return &fakeConnector{name: name}, nil +} + type fakeDB struct { name string diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 30b4ad3609..1192eaae26 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -662,6 +662,14 @@ func Open(driverName, dataSourceName string) (*DB, error) { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } + if driverCtx, ok := driveri.(driver.DriverContext); ok { + connector, err := driverCtx.OpenConnector(dataSourceName) + if err != nil { + return nil, err + } + return OpenDB(connector), nil + } + return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index f7b7d988e1..8137eff82b 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -3523,6 +3523,19 @@ func TestNamedValueCheckerSkip(t *testing.T) { } } +func TestOpenConnector(t *testing.T) { + Register("testctx", &fakeDriverCtx{}) + db, err := Open("testctx", "people") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, is := db.connector.(*fakeConnector); !is { + t.Fatal("not using *fakeConnector") + } +} + type ctxOnlyDriver struct { fakeDriver }