TIL: Creating Transient Databases for Testing
This post explores an approach to unit testing database-dependent code: a single backing Postgres instance, with a disposable database created for each test case.
Unit testing code that depends on a database is a recurring challenge for me. In the past, I’ve typically resorted to an ORM that can be configured to use an in-memory database. While that approach is great for fast iteration, it doesn’t accurately represent how your code will behave in production, especially if you rely on features specific to a database like Postgres.
In light of this, I’d like to share the solution I devised. It uses a single backing Postgres instance and creates a transient database for each test case. This lets you run your tests in parallel and give each test case a fresh database whenever needed.
Here’s a code snippet of how I implemented this approach:
package testdb
import (
	"database/sql"
	"fmt"
	"math/rand"
	"os"
	"strings"
	"testing"

	_ "github.com/lib/pq"
)
// OptionsFunc configures the test database created by New. Pass any number of
// them to New to override its defaults.
type OptionsFunc func(*options)

// options holds the settings for a test database. Tests override the defaults
// by passing OptionsFunc values to New, e.g. to use a specific database name.
//
// The functional options pattern lets new settings be added later without
// breaking existing callers.
type options struct {
	database string
}

// WithRandomDatabase returns an option that sets the database name to 10 random
// lowercase ASCII letters, so each test case gets its own database. This is
// useful when running tests in parallel or isolating them from each other.
//
// Names come from the math/rand global source, which Go 1.20+ seeds
// automatically. On older Go versions, call rand.Seed once (for example with
// time.Now().UnixNano()) or the same names will repeat across runs.
func WithRandomDatabase() OptionsFunc {
	return func(o *options) {
		const (
			letters = "abcdefghijklmnopqrstuvwxyz"
			length  = 10
		)
		name := make([]byte, length)
		for i := range name {
			name[i] = letters[rand.Intn(len(letters))]
		}
		o.database = string(name)
	}
}
// envOrDefault returns the value of the environment variable key, or
// defaultValue when that variable is not set at all. A variable that is set to
// the empty string yields "" rather than the default.
func envOrDefault(key, defaultValue string) string {
	value, found := os.LookupEnv(key)
	if !found {
		return defaultValue
	}
	return value
}
// New returns a *sql.DB connected to a Postgres database for the calling test.
//
// Connection settings come from the API_POSTGRES_USER, API_POSTGRES_PASSWORD,
// API_POSTGRES_SSL_MODE, API_POSTGRES_HOST and API_POSTGRES_PORT environment
// variables, each with a fallback default. The database name defaults to
// API_POSTGRES_DATABASE (or "postgres") and can be overridden with options such
// as WithRandomDatabase.
//
// If the database does not exist, New creates it. Cleanups registered on t
// close the returned handle and, only for a database New itself created, drop
// it afterwards; pre-existing databases (such as the default "postgres") are
// never dropped. Any setup failure aborts the test via t.Fatal.
func New(t *testing.T, fns ...OptionsFunc) *sql.DB {
	t.Helper()

	opts := &options{
		database: envOrDefault("API_POSTGRES_DATABASE", "postgres"),
	}
	for _, fn := range fns {
		fn(opts)
	}
	if opts.database == "" {
		t.Fatal("testdb: database name must not be empty")
	}

	// Connection string _WITHOUT_ a database name, used to check for, create
	// and later drop the per-test database. Values are quoted so that, e.g., a
	// password containing spaces or quotes is passed through intact.
	rootConnStr := fmt.Sprintf("user=%s password=%s sslmode=%s host=%s port=%s",
		quoteConnValue(envOrDefault("API_POSTGRES_USER", "postgres")),
		quoteConnValue(envOrDefault("API_POSTGRES_PASSWORD", "postgres")),
		quoteConnValue(envOrDefault("API_POSTGRES_SSL_MODE", "disable")),
		quoteConnValue(envOrDefault("API_POSTGRES_HOST", "postgres")),
		quoteConnValue(envOrDefault("API_POSTGRES_PORT", "5432")),
	)

	created, err := createDatabaseIfMissing(rootConnStr, opts.database)
	if err != nil {
		t.Fatal(err)
	}
	if created {
		// Only drop what we created: dropping a shared, pre-existing database
		// would destroy data other tests (or people) rely on.
		t.Cleanup(func() {
			if err := dropDatabase(rootConnStr, opts.database); err != nil {
				t.Error(err)
			}
		})
	}

	db, err := sql.Open("postgres", rootConnStr+" dbname="+quoteConnValue(opts.database))
	if err != nil {
		t.Fatal(err)
	}
	// Cleanups run last-added-first, so this Close runs before the drop above.
	// That ordering matters: Postgres refuses to drop a database that still
	// has open connections.
	t.Cleanup(func() {
		if err := db.Close(); err != nil {
			t.Errorf("testdb: closing database %q: %v", opts.database, err)
		}
	})
	return db
}

// createDatabaseIfMissing connects with connStr (which names no database),
// creates the database called name unless it already exists, and reports
// whether it created it.
func createDatabaseIfMissing(connStr, name string) (created bool, err error) {
	root, err := sql.Open("postgres", connStr)
	if err != nil {
		return false, fmt.Errorf("testdb: opening root connection: %w", err)
	}
	defer func() {
		if closeErr := root.Close(); err == nil && closeErr != nil {
			err = fmt.Errorf("testdb: closing root connection: %w", closeErr)
		}
	}()

	var exists bool
	query := "SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname = $1)"
	if scanErr := root.QueryRow(query, name).Scan(&exists); scanErr != nil {
		// An error whose text mentions a database that does not exist is
		// treated as "not found" (exists stays false) and creation is
		// attempted; anything else is returned. This matches on the message
		// text rather than on the error's type.
		// NOTE(review): with no dbname in connStr, this error most likely
		// refers to the connection's default database, in which case the
		// CREATE below will fail too — confirm this branch is still wanted.
		msg := scanErr.Error()
		if !(strings.Contains(msg, "does not exist") && strings.Contains(msg, "database")) {
			return false, fmt.Errorf("testdb: checking for database %q: %w", name, scanErr)
		}
	}
	if exists {
		return false, nil
	}

	if _, execErr := root.Exec("CREATE DATABASE " + quoteIdentifier(name)); execErr != nil {
		return false, fmt.Errorf("testdb: creating database %q: %w", name, execErr)
	}
	return true, nil
}

// dropDatabase connects with connStr (which names no database) and drops the
// database called name.
func dropDatabase(connStr, name string) error {
	root, err := sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("testdb: dropping database %q: %w", name, err)
	}
	defer root.Close() // best-effort: the DROP result below is what matters

	if _, err := root.Exec("DROP DATABASE " + quoteIdentifier(name)); err != nil {
		return fmt.Errorf("testdb: dropping database %q: %w", name, err)
	}
	return nil
}

// quoteIdentifier double-quotes name as a Postgres identifier so it is used
// verbatim (no lowercase folding, matching the datname lookup and the dbname
// connection parameter) and cannot break out of the surrounding statement.
func quoteIdentifier(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// quoteConnValue single-quotes v for a key=value connection string, escaping
// backslashes and single quotes with a backslash, so values containing spaces
// or quotes are not split or truncated.
func quoteConnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}