Abstracting database/sql boilerplate with Go generics

James Kirk
Published in Eureka Engineering · Dec 10, 2022
An abstract vision of abstracting database/sql boilerplate

Introduction

This post is part of the Eureka Advent Calendar 2022.

database/sql provides everything needed to work with SQL in Go, yet there’s an ever-growing list of query builders, ORMs and database utilities being developed by the community. Is database/sql really so bad?

In this post we’ll explore what can be achieved using Go generics with no added dependencies. If you’re not familiar with database/sql basics, then check out Tutorial: Accessing a relational database.

The go-generic-dao repository

The code for this post is in the go-generic-dao repository. The vanilla branch contains the starting code, and setup instructions are in README.md.

The project is a minimal example of using Go to connect to a MySQL database. The interesting packages are shown below.

  • db manages an sql.DB instance
  • user and like each have a model and DAO, which depend on the db package to execute SQL

The code is standard database/sql, and depends on only go-sql-driver/mysql. We’ll take a closer look at how the user package is implemented, and explore how the code can be abstracted with generics.

The user package

Let’s take a look at user/user.go first. It has constants for building raw SQL, and a struct representing a row in the user table.

const (
	Table   = "`user`"
	Columns = "`id`, `nickname`, `bio`, `created_at`"
)

type User struct {
	ID        int64
	Nickname  string
	Bio       sql.Null[string]
	CreatedAt time.Time
}

// PtrFields is a convenience method for use with sql#Row.Scan.
func (u *User) PtrFields() []any {
	return []any{&u.ID, &u.Nickname, &u.Bio, &u.CreatedAt}
}

user/user_dao.go contains a struct for executing raw SQL strings via the db package. It defines the following methods.

GetByID(ctx context.Context, id int64) (User, error) 
Count(ctx context.Context) (int64, error)
FindByIDs(ctx context.Context, ids []int64) ([]User, error)
FindIDsWithBio(ctx context.Context) ([]int64, error)

The implementation of GetByID shows database/sql in action. It retrieves and parses a single row from the user table.

const getByIDQuery = "SELECT " + Columns +
	" FROM " + Table +
	" WHERE `id` = ?" +
	" LIMIT 1;"

func (DAO) GetByID(ctx context.Context, id int64) (User, error) {
	row := db.DB().QueryRowContext(ctx, getByIDQuery, id)
	var u User
	if err := row.Scan(u.PtrFields()...); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return u, db.ErrNotFound
		}
		return u, fmt.Errorf("user.DAO#GetByID row.Scan error: %w", err)
	}
	return u, nil
}

ℹ️ The receiver of PtrFields must be a pointer: (u *User). With a value receiver, (u User), PtrFields returns pointers into a copy of u, so Scan fills the copy and GetByID returns an empty User struct:

=== RUN   TestUserDAO_GetByID/found
user_dao_test.go:65: user.User{
- ID: 1,
+ ID: 0,
- Nickname: "Socks",
+ Nickname: "",
Bio: {},
- CreatedAt: s"2022-12-10 13:17:23 +0000 UTC",
+ CreatedAt: s"0001-01-01 00:00:00 +0000 UTC",
}

--- FAIL: TestUserDAO_GetByID/found (0.05s)
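The receiver problem can be reproduced without a database. In this minimal sketch, fakeScan is a hypothetical stand-in for row.Scan that writes through the pointers it receives, and byValue and byPointer are illustrative types, not code from the repository.

```go
package main

import "fmt"

// byValue and byPointer are illustrative stand-ins for user.User.
type byValue struct{ ID int64 }
type byPointer struct{ ID int64 }

// Value receiver: v is a copy, so &v.ID points into the copy.
func (v byValue) PtrFields() []any { return []any{&v.ID} }

// Pointer receiver: &p.ID points into the caller's struct.
func (p *byPointer) PtrFields() []any { return []any{&p.ID} }

// fakeScan is a hypothetical stand-in for row.Scan: it writes each
// src value through the matching *int64 destination.
func fakeScan(dest []any, src ...int64) {
	for i, d := range dest {
		*d.(*int64) = src[i]
	}
}

func main() {
	var a byValue
	fakeScan(a.PtrFields(), 1)
	var b byPointer
	fakeScan(b.PtrFields(), 1)
	fmt.Println(a.ID, b.ID) // 0 1
}
```

Only the pointer-receiver version sees the scanned value, which matches the zeroed fields in the failing test above.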

Where the DAO gets painful

GetByID is 11 lines, and doesn’t make such a strong case for abstraction. Implementing a WHERE IN query, on the other hand, takes a lot more effort.

const findByIDsQuery = "SELECT " + Columns +
	" FROM " + Table +
	" WHERE `id` IN (%s)" +
	" ORDER BY `id`;"

func (DAO) FindByIDs(ctx context.Context, ids []int64) ([]User, error) {
	args := make([]any, len(ids))
	for i, t := range ids {
		args[i] = t
	}
	placeholders := strings.Repeat("?,", len(args)-1) + "?"

	q := fmt.Sprintf(findByIDsQuery, placeholders)
	rows, err := db.DB().QueryContext(ctx, q, args...)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, fmt.Errorf("user.DAO#FindByIDs failed\n%s: %w", q, err)
	}
	defer func() { _ = rows.Close() }()

	var result []User
	for rows.Next() {
		var u User
		if err := rows.Scan(u.PtrFields()...); err != nil {
			return nil, fmt.Errorf("user.DAO#FindByIDs scan error: %w", err)
		}
		result = append(result, u)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("user.DAO#FindByIDs rows.Err(): %w", err)
	}
	return result, nil
}

This 30-line method highlights a few potential issues going forward:

  1. Code duplication — in many cases, only the SQL string and types will differ for any similar Get or Find method
  2. Lack of type safety — care needs to be taken using row.Scan with regard to types and argument ordering (PtrFields helps here)
  3. Inconsistent error handling — in particular, sql.ErrNoRows could cause unexpected errors if not handled consistently

Generics: a brief introduction

A common solution to code duplication until recently has been the use of reflection. However, reflection does little to help with lack of type safety, and tends to come at the cost of readability and performance.

Generics were introduced in Go 1.18 and offer new solutions to these problems. See An Introduction to Generics for a well-rounded introduction.

Enter the type parameter

To get started with generics in user_dao.go, let’s take another look at the way FindByIDs generates placeholders for the WHERE IN clause.

ids := []int64{1, 2, 3}

args := make([]any, len(ids))
for i, t := range ids {
	args[i] = t
}
placeholders := strings.Repeat("?,", len(args)-1) + "?"

It converts ids to type []any, and creates a string like ?,?,?. The only variable is the type of ids. If the dependency on int64 is removed, we can reuse the code for any type. This is exactly what type parameters are for.

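func InArgs[T any](tt []T) (string, []any) {
	// Guard against empty input: strings.Repeat panics on a negative
	// count, and "IN ()" is invalid SQL anyway, so callers should
	// skip the query when there is nothing to match.
	if len(tt) == 0 {
		return "", nil
	}
	args := make([]any, len(tt))
	for i, t := range tt {
		args[i] = t
	}
	return strings.Repeat("?,", len(args)-1) + "?", args
}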

Type parameter T allows us to abstract types in a similar way that function parameters allow us to abstract values. InArgs now works with any slice.

ids := []int64{1, 2, 3}
placeholders, args := InArgs(ids)
// "?,?,?", []any{int64(1), int64(2), int64(3)}

nicknames := []string{"Socks", "Darius"}
placeholders, args = InArgs(nicknames)
// "?,?", []any{"Socks", "Darius"}

ℹ️ InArgs has to be implemented as a function, because you can’t declare new type parameters on a method. A signature like below will not compile.

func (d *DAO) InArgs[T any](tt []T) (string, []any)

ℹ️ Type inference works from the function arguments. If a type parameter appears only in the result type, there is nothing to infer it from, so the type argument must be passed explicitly.

ids := []int64{1,2,3}
InArgs[int64](ids) // explicitly pass type (optional)
InArgs(ids) // int64 inferred from ids argument

// Argless takes a type argument
// but no function arguments.
func Argless[T any]() T {
	var t T
	return t
}

var i int
i = Argless() // compile error! cannot infer T
i = Argless[int]() // OK

Identifying type parameters

The methods of user.DAO have two common patterns in terms of return types, i.e. the types that row.Scan will populate.

  1. Basic (scalar) types like int64 that are passed to Scan directly, e.g. row.Scan(&value)
  2. Struct types like User that are passed to Scan field-by-field, e.g. row.Scan(&u.ID, &u.Nickname, …)

With the help of reflection, we could inspect types at runtime, and cover both patterns in a single method. With generics, however, we need to verify types explicitly at compile time, which requires some new techniques.

Abstracting a function for basic types

Here’s the code for the Count method, which returns a single int64.

func (DAO) Count(ctx context.Context) (int64, error) {
	row := db.DB().QueryRowContext(ctx, countQuery)
	var count int64
	if err := row.Scan(&count); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return 0, nil
		}
		return 0, fmt.Errorf("user.DAO#Count row.Scan error: %w", err)
	}
	return count, nil
}

row.Scan takes its destinations as ...any, so if we remove the dependency on int64, we can write a generic function for any scalar value. The process is almost the same as with InArgs earlier.

func GetColumn[T any](ctx context.Context, q string, args ...any) (T, error) {
	row := database.QueryRowContext(ctx, q, args...)
	var t T
	if err := row.Scan(&t); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return t, ErrNotFound
		}
		return t, fmt.Errorf("DAO#GetColumn row.Scan error: %w", err)
	}
	return t, nil
}

Again we define a type parameter T that replaces the use of int64. We can now call GetColumn from user.DAO#Count.

func (DAO) Count(ctx context.Context) (int64, error) {
	return db.GetColumn[int64](ctx, countQuery)
}

The result is concise and readable, even though we’re forced to pass the type argument [int64] explicitly. However, it will fail at runtime if we pass [User] or any other type not supported directly by Scan.

cannot convert 1 (untyped int constant) to 
struct{ID int64; Nickname string; Bio sql.Null[string]; CreatedAt time.Time}

Constraining for type safety

Luckily, Go provides constraints for restricting type arguments. Understanding constraints requires rethinking interfaces in Go.

  • Before generics: an interface is a set of methods that is implemented by one or more types
  • After generics: an interface is a set of types that implement common methods (if any)

These two definitions can be functionally equivalent. Many types implement the set of methods defined by sort.Interface, which makes those types sortable. We can also say that sort.Interface defines the set of types that are sortable.

Why do we care about these semantics? Go 1.18 added a new syntax for interfaces to allow explicit enumeration of a set of types. It can be used to define exactly which types represent a column in the context of row.Scan.

// Column is a constraint listing the basic types
// we allow as row.Scan destinations (a subset of
// what Scan actually supports).
type Column interface {
	~byte | ~int16 | ~int32 | ~int64 | ~float64 |
		~string | ~bool | time.Time
}

This constraint can be added to the signature of GetColumn as below. (Note: ~int64 matches every type whose underlying type is int64, so a defined type like type ID int64 satisfies Column too.)

func GetColumn[T Column](ctx context.Context, q string, args ...any) (T, error)

This ensures that passing User to GetColumn will fail at compile time instead of runtime.

User does not implement db.Column 
(User missing in ~byte | ~int16 | ~int32 | ~int64 | ~float64 |
~string | ~bool | time.Time)

We can now implement queries for basic types in one line of code, ensure compile-time type safety, and standardise error handling. That’s 3/3 for the problems mentioned above.

As a bonus, the Column constraint can prevent invalid types being passed to InArgs, further improving type safety.

func InArgs[T Column](tt []T) (string, []any)

InArgs([]MyStruct{{ID: 1}}) // compile error

Dealing with structs

Now that scalar types have been handled, let’s look at the second pattern: struct types. Here’s the GetByID method of user.DAO again.

func (DAO) GetByID(ctx context.Context, id int64) (User, error) {
	row := db.DB().QueryRowContext(ctx, getByIDQuery, id)
	var u User
	if err := row.Scan(u.PtrFields()...); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return u, db.ErrNotFound
		}
		return u, fmt.Errorf("user.DAO#GetByID row.Scan error: %w", err)
	}
	return u, nil
}

It’s essentially the same as GetColumn, except for the use of PtrFields. We wrote the Column constraint to explicitly exclude structs, so we need to write a new function, GetRow, with a constraint for valid structs.

ℹ️ Spoiler alert: writing a constraint for structs themselves isn’t possible.

Though we can’t constrain types to structs specifically, we can constrain to a set of types that implement PtrFields, which is just as good for our purposes. In other words, just write an old-fashioned interface.

type Row interface {
	PtrFields() []any
}

Now we can copy the GetColumn function, replace the Column constraint with the new Row constraint, and call PtrFields on t.

func GetRow[T Row](ctx context.Context, q string, args ...any) (T, error) {
	row := database.QueryRowContext(ctx, q, args...)
	var t T
	if err := row.Scan(t.PtrFields()...); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return t, ErrNotFound
		}
		return t, fmt.Errorf("DAO#GetRow row.Scan error\n%s: %w", q, err)
	}
	return t, nil
}

This allows for a concise implementation of user.DAO#GetByID.

func (DAO) GetByID(ctx context.Context, id int64) (User, error) {
	return db.GetRow[User](ctx, getByIDQuery, id)
}

Unfortunately, it leads to a compilation error.

User does not implement db.Row (PtrFields method has pointer receiver)

This is because User doesn’t implement the Row interface. User and *User are considered distinctly different types, and only *User has a PtrFields method.

If we change the type argument and return type to [*User], the code compiles, but then panics at runtime.

var t T
row.Scan(t.PtrFields()...) // <--- panic! t is nil

PtrFields can't do its job when t is nil: taking the address of a field through a nil pointer panics. We want to be able to receive and return type User, but invoke PtrFields on type *User.
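The panic can be reproduced in isolation. In this sketch, callPtrFields is a hypothetical helper that calls PtrFields on the zero value of its type argument, as the first draft of GetRow does, and recovers the panic so it can be reported as an error. User is cut down (Bio omitted).

```go
package main

import (
	"fmt"
	"time"
)

// User is a cut-down copy of user.User (Bio omitted) for illustration.
type User struct {
	ID        int64
	Nickname  string
	CreatedAt time.Time
}

func (u *User) PtrFields() []any {
	return []any{&u.ID, &u.Nickname, &u.CreatedAt}
}

// callPtrFields mimics the first draft of GetRow with T = *User: it calls
// PtrFields on the zero value of T and converts any panic into an error.
func callPtrFields[T interface{ PtrFields() []any }]() (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("recovered: %v", r)
		}
	}()
	var t T // for T = *User, t is a nil pointer
	_ = t.PtrFields()
	return nil
}

func main() {
	// Prints a recovered nil pointer dereference error.
	fmt.Println(callPtrFields[*User]())
}
```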

Constraining a pointer type

Luckily, we don’t need to work out a way around this problem, because a solution is already detailed in the Type Parameters Proposal document.

The idea is actually simple: pass both User and *User as type arguments. Two independent type parameters like [T any, PT Row] aren't enough, because nothing ties PT to T, so &t couldn't be converted to PT. We need to constrain PT to be a pointer to T specifically, which can be done by combining the ideas used for the Column and Row constraints.

type Row[T any] interface {
	PtrFields() []any
	*T
}

The new Row constraint ensures that a type:

  • Implements a PtrFields method
  • Is in the type set {*T}, i.e. it is exactly *T, a pointer to type T

In our case, if T is User, then we know *User meets the requirements of Row[User]. The new constraint can be used as below.

func GetRow[T any, PT Row[T]](ctx context.Context, q string, args ...any) (T, error)

The updated body of GetRow looks like this.

var t T
ptr := PT(&t)
row.Scan(ptr.PtrFields()...)

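This converts &t, which has type *T, to PT. The conversion is allowed because *T is the only type in PT's type set, and once the value has type PT, its PtrFields method is available. It's similar to using the original Row interface as below.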

var u User
ptr := Row(&u)
ptr.PtrFields()

We can now take another look at the concise GetByID method knowing that the tests are passing successfully.

func (DAO) GetByID(ctx context.Context, id int64) (User, error) {
	return db.GetRow[User](ctx, getByIDQuery, id)
}

Note that PT can be inferred from T thanks to constraint type inference, so it doesn’t need to be passed explicitly, though we could optionally do so as GetRow[User, *User].

Curing the pain

Now we’re ready to take on the 30-line FindByIDs method. All that’s needed is to replace any reference to User with a type parameter T, and constrain T in the same way as GetRow above.

func FindRows[T any, PT Row[T]](ctx context.Context, q string, args ...any) ([]T, error) {
	// sql.ErrNoRows is only returned by Row.Scan. With QueryContext, an
	// empty result set just means rows.Next returns false immediately.
	rows, err := database.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, fmt.Errorf("DAO#FindRows QueryContext failed: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var result []T
	for rows.Next() {
		var t T
		ptr := PT(&t)
		if err := rows.Scan(ptr.PtrFields()...); err != nil {
			return nil, fmt.Errorf("DAO#FindRows row.Scan error: %w", err)
		}
		result = append(result, t)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("DAO#FindRows rows.Err(): %w", err)
	}
	return result, nil
}

The code is still long and a little ugly, but calling it together with InArgs allows executing a WHERE IN query in a few lines of code.

func (DAO) FindByIDs(ctx context.Context, ids []int64) ([]User, error) {
	if len(ids) == 0 {
		return nil, nil // nothing to look up; "IN ()" would be invalid SQL
	}
	placeholders, args := db.InArgs(ids)
	q := fmt.Sprintf(findByIDsQuery, placeholders)
	return db.FindRows[User](ctx, q, args...)
}

But is FindRows really generic? Let’s implement like.DAO#FindByUser.

const findByUserQuery = "SELECT " + Columns +
	" FROM " + Table +
	" WHERE `user_id` = ?" +
	" ORDER BY `id`;"

func (DAO) FindByUser(ctx context.Context, userID int64) ([]Like, error) {
	return db.FindRows[Like](ctx, findByUserQuery, userID)
}

FindRows can be called in the same way for User, Like, and any other type that implements a PtrFields method.
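A sketch of FindRows driven by an in-memory result set follows. rowsIterator is an assumed interface covering the Next, Scan and Err methods of *sql.Rows (Close is left to the caller), fakeRows is a hypothetical test double, and User and Like are cut-down models: the post doesn't show Like's fields, so the ones here are made up.

```go
package main

import (
	"database/sql"
	"fmt"
	"reflect"
)

type Row[T any] interface {
	PtrFields() []any
	*T
}

// rowsIterator is an assumption of this sketch: the three *sql.Rows
// methods FindRows relies on.
type rowsIterator interface {
	Next() bool
	Scan(dest ...any) error
	Err() error
}

var _ rowsIterator = (*sql.Rows)(nil)

func FindRows[T any, PT Row[T]](rows rowsIterator) ([]T, error) {
	var result []T
	for rows.Next() {
		var t T
		if err := rows.Scan(PT(&t).PtrFields()...); err != nil {
			return nil, fmt.Errorf("FindRows rows.Scan error: %w", err)
		}
		result = append(result, t)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("FindRows rows.Err(): %w", err)
	}
	return result, nil
}

// User and Like are hypothetical, cut-down models.
type User struct {
	ID       int64
	Nickname string
}

func (u *User) PtrFields() []any { return []any{&u.ID, &u.Nickname} }

type Like struct {
	ID     int64
	UserID int64
}

func (l *Like) PtrFields() []any { return []any{&l.ID, &l.UserID} }

// fakeRows is a hypothetical in-memory result set.
type fakeRows struct {
	data [][]any
	cur  int // number of rows consumed by Next so far
}

func (r *fakeRows) Next() bool {
	if r.cur >= len(r.data) {
		return false
	}
	r.cur++
	return true
}

func (r *fakeRows) Scan(dest ...any) error {
	row := r.data[r.cur-1]
	if len(dest) != len(row) {
		return fmt.Errorf("expected %d destinations, got %d", len(row), len(dest))
	}
	for i, d := range dest {
		dv, sv := reflect.ValueOf(d).Elem(), reflect.ValueOf(row[i])
		if !sv.Type().AssignableTo(dv.Type()) {
			return fmt.Errorf("cannot store %T into %T", row[i], d)
		}
		dv.Set(sv)
	}
	return nil
}

func (r *fakeRows) Err() error { return nil }

func main() {
	users, _ := FindRows[User](&fakeRows{data: [][]any{{int64(1), "Socks"}, {int64(2), "Darius"}}})
	likes, _ := FindRows[Like](&fakeRows{data: [][]any{{int64(10), int64(1)}}})
	fmt.Println(len(users), users[1].Nickname, len(likes), likes[0].UserID) // 2 Darius 1 1
}
```

Because FindRows depends only on PtrFields, adding another model requires no changes to it.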

For a complete diff of everything implemented in this post versus the standard database/sql implementation, check out this pull request.

Summary

Here’s a review of what’s been achieved in terms of the original three problems.

  1. Code duplication: basic queries can now be written in a single line by calling GetRow and FindRows (or the Column equivalents)
  2. Lack of type safety: types are restricted by the Row and Column constraints, so many mistakes that previously surfaced at runtime are now compile errors
  3. Inconsistent error handling: GetColumn and GetRow translate sql.ErrNoRows into ErrNotFound in one place, and FindRows returns an empty result when nothing matches

Of course, the raw SQL queries aren’t type-safe, and the production-readiness of the generic functions remains to be tested, but hopefully this post gave you some insight and ideas on how to go about abstracting code in your own projects.
