Goのジェネリクスを使って、テーブル駆動テスト(TDT)に統一性を持たせる
私は普段Goのテストを書く際に、テーブル駆動テスト(TDT)を用いることが多いです。
しかし、チームで開発をしていると、メンバーによってテストの書き方が異なることがしばしばあり、それにより、テストを実装する際に書き方に迷ったり、レビューがしづらくなったりと、テストの実装において問題が生じていました。
そこで、ジェネリクスを活用することで、汎用的に使えるテストケースの構造体を定義し、それをプロジェクト全体で共通で使用することで、テーブル駆動テストの書き方に統一性を持たせることができたので、今回はその方法を紹介したいと思います!
テーブル駆動テストとは
テーブル駆動テストとは、複数のテストケースをスライスで定義し、ループすることで一つのテスト関数で複数のテストケースを実行する方法です。
package main
func Sum(numbers []int) int {
total := 0
for _, n := range numbers {
total += n
}
return total
}
func TestSum(t *testing.T) {
cases := []struct {
name string
numbers []int
want int
}{
{"空のスライス", []int{}, 0},
{"1つの要素", []int{42}, 42},
{"複数の要素", []int{1, 2, 3, 4}, 10},
{"負の数を含む", []int{-1, -2, -3, -4}, -10},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := Sum(tc.numbers)
if got != tc.want {
t.Errorf("Sum(%v) = %d, want %d", tc.numbers, got, tc.want)
}
})
}
}
テーブル駆動テストには、下記のようなメリットがあります。
- テストケースの追加・編集・削除がしやすい
- テストの期待値が明確になる
- 繰り返し実行されるコードをRunに集約できる
テーブル駆動テストの記法がバラバラになる問題
上記の例のように、シンプルなロジックのテストであれば、テーブル駆動テストを用いて、テストコードをシンプルに保つことができます。
ただ、Webアプリケーションのロジックのテストの際には、返り値の検証に加えて、DBやモックなどの準備や検証も行いたい場面があります。
それらを取り入れる際に、テストケースのフィールドに関数を追加してロジックを書いたり、Runの中で実行したりなど、実現方法が複数存在し、それによって記法がバラバラになる問題が発生していました。
そこで、何のテストを書くかや誰が書くかによらず、記法を統一するため、プロジェクト内で共通で使用するテーブル駆動テスト用の構造体を定義することにしました。
テーブル駆動テスト用の構造体
構造体の定義は下記のようになります。
フィールドの型定義にジェネリクスを活用することで、引数や返り値の型をテストに応じて変えられるので、汎用的に使うことができます。
type TestCase[TArgs, TResult, TMockFields any] struct {
// テスト名
Name string
// テストする関数の引数
Args TArgs
// DBなどを操作する準備
Prepare func()
// モックの準備
PrepareMock func(m *TMockFields)
// DBなどのアサーション
Assert func()
// 返り値のアサーション
AssertResult func(r TResult)
// モックのアサーション
AssertMock func(m *TMockFields)
}
使用例
上記の構造体を使ったテスト実装の例をいくつか紹介します。
構造体の全てのフィールドを使用する必要はないので、テストによっては使わないフィールドは省略しても大丈夫です。
シンプルなロジックのテスト
package main
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSumWithTestCase(t *testing.T) {
type args struct {
numbers []int
}
type result struct {
sum int
}
testcases := []TestCase[args, result, any]{
{
Name: "空のスライス",
Args: args{
numbers: []int{},
},
AssertResult: func(r result) {
assert.Equal(t, 0, r.sum)
},
},
{
Name: "1つの要素",
Args: args{
numbers: []int{42},
},
AssertResult: func(r result) {
assert.Equal(t, 42, r.sum)
},
},
}
for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
got := Sum(tc.Args.numbers)
tc.AssertResult(result{sum: got})
})
}
}
モックを使用するテスト
テスト対象のコード
package main
import "errors"
type Interface1 interface {
Run() string
}
type Interface2 interface {
Run() string
}
type SomeService struct {
i1 Interface1
i2 Interface2
}
func (s *SomeService) Run(str string) (string, error) {
str1 := s.i1.Run()
str2 := s.i2.Run()
if str == "" && str1 == "" && str2 == "" {
return "", errors.New("error")
}
return str + str1 + str2, nil
}
func NewSomeService(i1 Interface1, i2 Interface2) *SomeService {
return &SomeService{i1: i1, i2: i2}
}
モック定義
package main
import (
"github.com/stretchr/testify/mock"
)
type MockInterface1 struct {
mock.Mock
}
func (m *MockInterface1) Run() string {
return m.Called().String(0)
}
type MockInterface2 struct {
mock.Mock
}
func (m *MockInterface2) Run() string {
return m.Called().String(0)
}
package main
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSomeService_Run(t *testing.T) {
type args struct {
str string
}
type result struct {
str string
err error
}
type mockFields struct {
i1 *MockInterface1
i2 *MockInterface2
}
testcases := []TestCase[args, result, mockFields]{
{
Name: "正常系",
Args: args{
str: "test",
},
PrepareMock: func(m *mockFields) {
m.i1.On("Run").Return("a")
m.i2.On("Run").Return("b")
},
AssertResult: func(r result) {
assert.NoError(t, r.err)
assert.Equal(t, "testab", r.str)
},
AssertMock: func(m *mockFields) {
m.i1.AssertCalled(t, "Run")
m.i2.AssertCalled(t, "Run")
},
},
{
Name: "異常系",
Args: args{
str: "",
},
PrepareMock: func(m *mockFields) {
m.i1.On("Run").Return("")
m.i2.On("Run").Return("")
},
AssertResult: func(r result) {
assert.EqualError(t, r.err, "error")
},
AssertMock: func(m *mockFields) {
m.i1.AssertCalled(t, "Run")
m.i2.AssertCalled(t, "Run")
},
},
}
for _, tc := range testcases {
t.Run(tc.Name, func(t *testing.T) {
t.Log(tc.Name)
mockFields := &mockFields{
i1: new(MockInterface1),
i2: new(MockInterface2),
}
tc.PrepareMock(mockFields)
svc := NewSomeService(mockFields.i1, mockFields.i2)
got, err := svc.Run(tc.Args.str)
tc.AssertResult(result{str: got, err: err})
tc.AssertMock(mockFields)
})
}
}
DB操作を行うテスト
テスト対象のコード
package main
import (
"database/sql"
"errors"
)
type User struct {
ID string
Name string
}
func UpdateUser(db *sql.DB, user User) error {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE id = $1", user.ID).Scan(&count)
if err != nil {
return err
}
if count == 0 {
return errors.New("user not found")
}
_, err = db.Exec("UPDATE users SET name = $1 WHERE id = $2", user.Name, user.ID)
return err
}
package main
import (
"database/sql"
"testing"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
)
func TestUpdateUser(t *testing.T) {
var userID = "1"
db, err := sql.Open("postgres", "...")
if err != nil {
assert.Fail(t, err.Error())
}
defer db.Close()
type args struct {
user User
}
type result struct {
err error
}
testcases := []TestCase[args, result, any]{
{
Name: "正常系",
Args: args{
user: User{
ID: userID,
Name: "test-update",
},
},
Prepare: func() {
_, err := db.Exec("INSERT INTO users (id, name) VALUES ($1, 'test')", userID)
if err != nil {
assert.Fail(t, err.Error())
}
},
AssertResult: func(r result) {
assert.NoError(t, r.err)
},
Assert: func() {
user, err := selectUser(db, userID)
if err != nil {
assert.Fail(t, err.Error())
}
assert.Equal(t, "test-update", user.Name)
},
},
{
Name: "異常系",
Args: args{
user: User{
ID: "not-exist-id",
Name: "test-update",
},
},
Prepare: func() {},
AssertResult: func(r result) {
assert.EqualError(t, r.err, "user not found")
},
Assert: func() {
user, err := selectUser(db, "not-exist-id")
assert.Error(t, err)
assert.Empty(t, user)
},
},
}
for _, tc := range testcases {
t.Run(tc.Name, func(t *testing.T) {
tc.Prepare()
err := UpdateUser(db, tc.Args.user)
tc.AssertResult(result{err: err})
tc.Assert()
})
}
}
func selectUser(db *sql.DB, id string) (User, error) {
var user User
err := db.QueryRow("SELECT id, name FROM users WHERE id = $1", id).Scan(&user.ID, &user.Name)
return user, err
}
まとめ
今回はGoのジェネリクスを使って、テーブル駆動テストに統一性を持たせる方法を紹介しました。
これにより、私のチームではテストの実装方法に迷うことが減り、実装スピード向上にも繋がっています。
また、構造体の型定義はチームの方針に合わせて定義することで、より効果的に使えると思います。
一方で、今回紹介した構造体のフィールドを全て使う必要がある場合などでは、テーブル駆動テストを採用することで、逆にテストケースが肥大化し、可読性が低下する場合もあります。
そのため、場合によってはテーブル駆動テストそのものを使わない方が良い場合もあると思うので、状況に応じて使い分けることも大事だと思います。
Discussion