SODA Engineering Blog
📖

AI使わずにgomockでGoテスト生成

2024/12/16に公開

この記事を質問から始めましょう:ユニットテストのテストケースを作成するのが好きですか?
問題ないと感じる人もいるかもしれませんが、大量のモックを必要とする場合、テストケースの作成を好まない人もいるでしょう。エンジニアなので、このプロセスをコンピューターに自動化させてみてはいかがでしょうか?この記事では、AIを使用せずに、常に決定論的な結果を得られる方法でユニットテストケースを生成するアプローチを探ります。

テスト生成とは

テストジェネレーターは、SEがテストケースを手動で作成する手間を大幅に軽減する便利な手段を提供します。これらのツールには、簡単な実装のものから非常に高度なソリューションまで、さまざまな種類があります。AIはテスト生成において非常に有用な役割を果たすことがありますが、AIに過度に依存することには問題もあります。AIが生成するテストは、結果が毎回異なることがあるため、一貫性や信頼性に欠ける場合があるのです。

背景

次の実装に対してテストケースを作成すると仮定してみましょう:

hoge.go
package hoge

import "context"

type HogeImpl struct {
    service Service
}

func NewHogeImpl(service Service) *HogeImpl {
    return &HogeImpl{
        service: service,
    }
}

func (u HogeImpl) Do(ctx context.Context) (int, error) {
    if err := u.service.A(ctx, 50); err != nil {
        return 0, err
    }

    b, err := u.service.B(ctx, 50)
    if err != nil {
        return 0, err
    }

    c := u.service.C()

    return b.Y + c, nil
}
service.go
package hoge

import "context"

type Out struct {
    X string
    Y int
}

//go:generate mockgen -typed -source=service.go -destination=service_mock.go -package=hoge
type Service interface {
    A(ctx context.Context, in int) error
    B(ctx context.Context, in int) (*Out, error)
    C() int
}

この例では、Serviceという3つの関数を持つインターフェースを受け取るHogeImplがあります。

その関数に対するユニットテストは次のようになります:

hoge_test.go
package hoge_test

import (
    "context"
    "errors"
    "strconv"
    "testing"

    "github.com/johndoe/hoge"
    "go.uber.org/mock/gomock"
)

func TestHogeImpl_Do(t *testing.T) {
    mockErr := errors.New("mock error")

    cases := []struct {
        setup   func(ctrl *gomock.Controller) *hoge.MockService
        want    int
        wantErr error
    }{
        {
            setup: func(ctrl *gomock.Controller) *hoge.MockService {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(mockErr)

                return mockMockService
            },
            wantErr: mockErr,
        },
        {
            setup: func(ctrl *gomock.Controller) *hoge.MockService {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(mockErr)

                return mockMockService
            },
            wantErr: mockErr,
        },
        {
            setup: func(ctrl *gomock.Controller) *hoge.MockService {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(nil)
                mockMockService.EXPECT().B(gomock.Any(), gomock.Any()).Return(nil, mockErr)

                return mockMockService
            },
            wantErr: mockErr,
        },
        {
            setup: func(ctrl *gomock.Controller) *hoge.MockService {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(nil)
                mockMockService.EXPECT().B(gomock.Any(), gomock.Any()).Return(&hoge.Out{X: "This is X", Y: 50}, nil)
                mockMockService.EXPECT().C().Return(35)

                return mockMockService
            },
            want: 85,
        },
    }

    for i, c := range cases {
        c := c
        t.Run("Case #"+strconv.Itoa(i), func(t *testing.T) {
            ctrl := gomock.NewController(t)
            mockMockService := c.setup(ctrl)

            impl := hoge.NewHogeImpl(mockMockService)
            got, err := impl.Do(context.Background())
            if !errors.Is(err, c.wantErr) {
                t.Errorf("expected %v but got %v for error", c.wantErr, err)
            }

            if got != c.want {
                t.Errorf("expected %v but got %v", c.want, got)
            }
        })
    }
}

将来的に、HogeImpl.Doが行う呼び出しの数を増やす必要が出てくるかもしれないです。その場合、関数は次のようになる可能性があります:

hoge.go
package hoge

// ..

func (u HogeImpl) Do(ctx context.Context) (int, error) {
    if err := u.service.A(ctx, 50); err != nil {
        return 0, err
    }

    b, err := u.service.B(ctx, 50)
    if err != nil {
        return 0, err
    }

    c := u.service.C()

    d, err := u.service.D(ctx, b.Y + c)
    if err != nil {
        return 0, err
    }

    e, err := u.service.E(ctx, d)
    if err != nil {
        return 0, err
    }

    f, err := u.service.F(ctx, e)
    if err != nil {
        return 0, err
    }

    // So on...

    return f, nil
}

すべてのif err != nil条件に対してテストケースを作成するのは、非常に繰り返しが多く、時間のかかる作業になりがちですね。これらのテストを手動で作成する代わりに、そのプロセスを自動化できたら便利だと思わないですか?例えば、モックの呼び出し結果のすべての組み合わせを体系的に網羅するテストファイルを生成することができます。このアプローチにより、エラーハンドリングシナリオのテスト作成が大幅に効率化されるでしょう。

まず、この記事を「検知・分析」と「生成」の2つのパートに分けて進めていきます。

検知・分析

テストケースを、どの関数が呼び出されたかに基づいて生成するためには、実装をプログラム的に分析する必要があります。実装を分析することで、関数のリターン値のリストに基づいて、どの関数が呼び出されたのかを検出できるようになります。

方法

AST (Abstract Syntax Tree) + タイプ情報

この方法では、GoのASTライブラリとコンパイル前のタイプ情報(リフレクションの型情報とは異なる)を活用します。ソースファイルを直接分析するため、コンパイルプロセスを介さずにコードの解釈に集中することができます。でも、特に関数呼び出しの流れを分析する際の設定が複雑なため、本記事ではこの方法について詳しくは取り上げません。

ランタイムリフレクション

この方法では、モックを使用して実装をコンパイルし、実行時に関数呼び出しをキャプチャします。ただし、Goではランタイムでメソッドを定義することがサポートされていないため、特定のインターフェースを実装する型を動的に作成することはできません。この課題に対応するため、本記事ではモック生成のためにgomockライブラリを活用します。

関数を呼び出し分析の解説

生成のためのランタイムコード分析を理解するために、gomockを使用して呼び出しをキャプチャするシンプルな実装を作成してみましょう。この実装は、以下のリポジトリを基に再実装したものです:
https://github.com/ezraisw/test-gen

テスト生成のために関数をどのように分析するかに進む前に、まずリフレクションを使用してgomockのモックを操作する方法を理解することが重要です。

まずは、gomockが生成する型の構造を確認するところから始めましょう。次のインターフェースを例に考えてみます:

service.go
type Service interface {
    A(ctx context.Context, in int) error
    B(ctx context.Context, in int) (*Out, error)
    C() int
}

この構造を理解することは、gomockを使用する際にリフレクションを効果的に活用するために不可欠です。

MockService
    - EXPECT() *MockServiceMockRecorder
    - A(ctx context.Context, in int) error
    - B(ctx context.Context, in int) (*Out, error)
    - C() int

MockServiceMockRecorder
    - A(ctx, in any) *MockServiceACall
    - B(ctx, in any) *MockServiceBCall
    - C() *MockServiceCCall

インターフェース内の各関数には、対応するrecorderメソッドがあり、これを使用して関数が呼び出された際にモックがどのように動作するかを設定できます。重要な設定項目の一つは、リターン値(Return)を定義することや、関数の動作をカスタマイズすること(Do/DoAndReturn)です。

特に注目すべきはDoAndReturnメソッドで、これによりモックが呼び出されたときにカスタムの動作を指定できます。このメソッドを利用して、関数が呼ばれた際に記録し、特定の時点で呼び出しをキャプチャすることができます。

では、analyzerを作成していきます。まず、EXPECT関数のリフレクションを取得します。NumIn()は関数が受け取るパラメータの数で、NumOut()は関数が返すリターン値の数です。

Source (analyzer/analyzer.go)
analyzer/analyzer.go
package analyzer

import (
    "fmt"
    "reflect"
)

type Analyzer struct {
}

func NewAnalyzer() *Analyzer {
    return &Analyzer{}
}

func (a *Analyzer) AttachTrap(r any) {
    rv := reflect.ValueOf(r)

    // (*)
    expectMethodRv := a.getExpectMethod(rv)

    fmt.Println(expectMethodRv.Kind())
}

func (a *Analyzer) getExpectMethod(mockRv reflect.Value) reflect.Value {
    if mockRv.Kind() != reflect.Ptr {
        panic("given value is not a pointer type to a possible gomock mock")
    }

    expectMethodRv := mockRv.MethodByName("EXPECT")
    if !expectMethodRv.IsValid() {
        panic("mock does not have EXPECT method")
    }
    // EXPECT is guaranteed a function kind.

    expectMethodRt := expectMethodRv.Type()
    if expectMethodRt.NumIn() != 0 {
        panic("EXPECT method does not have exactly 0 parameters")
    }

    if expectMethodRt.NumOut() != 1 {
        panic("EXPECT method does not have exactly 1 return value")
    }

    return expectMethodRv
}

analyzerの出力を確認するために、main関数も作成しましょう。

Source (cmd/main/main.go)
cmd/main/main.go
package main

import (
    "github.com/johndoe/hoge/analyzer"
    "github.com/johndoe/hoge"
    "go.uber.org/mock/gomock"
)

func main() {
    ctrl := gomock.NewController(nil)
    mock := hoge.NewMockService(ctrl)

    azr := analyzer.NewAnalyzer()
    azr.AttachTrap(mock)
}
> go run ./cmd/main
func

EXPECTのリフレクションを正常に取得しました!次に、リフレクションを通じて呼び出し、リターン値を取得する必要があります。リターン値は、MockServiceMockRecorderのインスタンスへのポインタになります。

Source (analyzer/analyzer.go)
analyzer/analyzer.go
package analyzer

import (
    "fmt"
    "reflect"
)

// ..

func (a *Analyzer) AttachTrap(r any) {
    rv := reflect.ValueOf(r)

    expectMethodRv := a.getExpectMethod(rv)

    // (*)
    recorderRv := expectMethodRv.Call([]reflect.Value{})[0]
    if recorderRv.Kind() != reflect.Pointer {
        panic("return value from EXPECT is not a pointer to a possible recorder")
    }

    fmt.Println(recorderRv.Type())
}

// ..
> go run ./cmd/main
*hoge.MockServiceMockRecorder

次に、各メソッドにフックして、呼び出されたタイミングを追跡する必要があります。recorder型でエクスポートされている関数は、インターフェースメソッドのレコーダーのみです。そのため、IsExported()チェックを前提に、すべてのメソッドをループで処理するのは安全です。

各メソッドにフックするためには、現時点ではすべての引数にgomock.Any()を使ってレコーダ関数を呼び出す必要があります。

Source (analyzer/analyzer.go)
analyzer/analyzer.go
package analyzer

import "reflect"

// ..

var condAnyRv = reflect.ValueOf(gomock.Any())

func (a *Analyzer) AttachTrap(r any) {
    rv := reflect.ValueOf(r)

    expectMethodRv := a.getExpectMethod(rv)

    recorderRv := expectMethodRv.Call([]reflect.Value{})[0]
    if recorderRv.Kind() != reflect.Pointer {
        panic("return value from EXPECT is not a pointer to a possible recorder")
    }
    recorderRt := recorderRv.Type()

    // (*)
    for i := 0; i < recorderRt.NumMethod(); i++ {
        method := recorderRt.Method(i)
        if !method.IsExported() {
            continue
        }

        methodRv := recorderRv.MethodByName(method.Name)
        methodRt := methodRv.Type()

        if methodRt.NumOut() != 1 {
            panic("recorder method does not have exactly 1 return value")
        }

        // Pass gomock.Any() to every argument of recorder function.
        argRvs := make([]reflect.Value, 0, methodRt.NumIn())
        for i := 0; i < methodRt.NumIn(); i++ {
            argRvs = append(argRvs, condAnyRv)
        }

        retRvs := methodRv.Call(argRvs)
        a.attachCall(retRvs[0])
    }
}

func (a *Analyzer) attachCall(callRv reflect.Value) {
    // To do...
}

// ..

呼び出しをアタッチするには、EXPECTを取得した方法と同じようにリフレクションメソッドを操作する必要があります。

その後、MakeFuncを使用して、reflect.Funcを型として持つreflect.Valueを作成する必要があります。現時点では、フックした関数に対してはゼロ値を返します。

リフレクションを通じて関数を作成したら、その関数のリフレクションをDoAndReturnの引数として渡します。

Source (analyzer/analyzer.go)
analyzer/analyzer.go
package analyzer

import (
    "fmt"
    "reflect"
)

// ..

func (a *Analyzer) attachCall(callRv reflect.Value) {
    if callRv.Kind() != reflect.Pointer {
        panic("return value from recorder method is not a pointer type to a possible gomock call")
    }

    doAndReturnMethodRv := callRv.MethodByName("DoAndReturn")
    doAndReturnMethodRt := doAndReturnMethodRv.Type()

    if doAndReturnMethodRt.NumIn() != 1 {
        panic("DoAndReturn does not have exactly 1 parameter")
    }

    if doAndReturnMethodRt.NumOut() != 1 {
        panic("DoAndReturn does not have exactly 1 return value")
    }

    realMethodRt := doAndReturnMethodRt.In(0)
    if realMethodRt.Kind() == reflect.Interface {
        panic("mock is generated without type information; cannot determine function type")
    }

    realMethodRv := reflect.MakeFunc(realMethodRt, func([]reflect.Value) []reflect.Value {
        retRvs := make([]reflect.Value, 0, realMethodRt.NumOut())
        // Return zero values for hooked functions.
        for i := 0; i < realMethodRt.NumOut(); i++ {
            // Create a zero value with reflection.
            retRvs = append(retRvs, reflect.New(realMethodRt.Out(i)).Elem())
        }

        // Print to see if this has been called.
        fmt.Println("Called")

        return retRvs
    })

    retRvs := doAndReturnMethodRv.Call([]reflect.Value{realMethodRv})
    callRv = retRvs[0]

    anyTimesMethodRv := callRv.MethodByName("AnyTimes")
    anyTimesMethodRt := anyTimesMethodRv.Type()

    if anyTimesMethodRt.NumIn() != 0 {
        panic("AnyTimes does not have exactly 0 parameters")
    }

    if anyTimesMethodRt.NumOut() != 1 {
        panic("AnyTimes does not have exactly 1 return value")
    }

    anyTimesMethodRv.Call([]reflect.Value{})
}

// ..

main関数も忘れずに調整しましょう。これで、フックしたモックを使って実装を呼び出すことができるようになります。

Source (cmd/main/main.go)
cmd/main/main.go
package main

import (
    "context"

    "github.com/johndoe/hoge"
    "github.com/johndoe/hoge/analyzer"
    "go.uber.org/mock/gomock"
)

func main() {
    ctrl := gomock.NewController(nil)
    mock := hoge.NewMockService(ctrl)

    azr := analyzer.NewAnalyzer()
    azr.AttachTrap(mock)

    // (*)
    h := hoge.NewHogeImpl(mock)
    h.Do(context.Background())
}
> go run ./cmd/main
go run ./cmd/main
Called
Called
Called
panic: runtime error: invalid memory address or nil pointer dereference
[signal SIGSEGV: segmentation violation code=0x2 addr=0x10 pc=0x1024bd2a4]

goroutine 1 [running]:
github.com/johndoe/hoge.HogeImpl.Do({{0x102529ee8?, 0x1400000c018?}}, {0x102529f78, 0x1026168c0})
        /Users/johndoe/hoge/hoge.go:27 +0x94
main.main()
        /Users/johndoe/hoge/cmd/main/main.go:19 +0xf4
exit status 2

これを実行すると、"nil pointer dereference"エラーが発生します。これは、Bの最初のリターン値としてnilを返したためです!ポインタのゼロ値はnilです。この問題を修正して、各関数のカスタムリターン値を定義できるようにしましょう。

AttachTrapmap[string][]any型のパラメータを追加し、各メソッドのリターン値を設定できるようにします。メソッドの名前に基づいてリターン値として[]anyの値を渡せます。配列の長さが関数の実際のリターン値の数と一致するかを確認するのを忘れないようにしましょう。

Source (analyzer/analyzer.go)
analyzer/analyzer.go
package analyzer

import (
    "fmt"
    "reflect"
)

// ..

// (*)
func (a *Analyzer) AttachTrap(r any, retsByMethodName map[string][]any) {
    rv := reflect.ValueOf(r)

    expectMethodRv := a.getExpectMethod(rv)

    recorderRv := expectMethodRv.Call([]reflect.Value{})[0]
    if recorderRv.Kind() != reflect.Pointer {
        panic("return value from EXPECT is not a pointer to a possible recorder")
    }
    recorderRt := recorderRv.Type()

    // (*) Validation to make sure all method names exist.
    for methodName := range retsByMethodName {
        _, ok := recorderRt.MethodByName(methodName)
        if !ok {
            panic("method name not found: " + methodName)
        }
    }

    for i := 0; i < recorderRt.NumMethod(); i++ {
        method := recorderRt.Method(i)
        if !method.IsExported() {
            continue
        }

        methodRv := recorderRv.MethodByName(method.Name)
        methodRt := methodRv.Type()

        if methodRt.NumOut() != 1 {
            panic("recorder method does not have exactly 1 return value")
        }

        // Pass gomock.Any() to every argument of recorder function.
        argRvs := make([]reflect.Value, 0, methodRt.NumIn())
        for i := 0; i < methodRt.NumIn(); i++ {
            argRvs = append(argRvs, condAnyRv)
        }

        // (*)
        rets := retsByMethodName[method.Name]

        retRvs := methodRv.Call(argRvs)
        a.attachCall(retRvs[0], rets)
    }
}

// (*)
func (a *Analyzer) attachCall(callRv reflect.Value, rets []any) {
    if callRv.Kind() != reflect.Pointer {
        panic("return value from recorder method is not a pointer type to a possible gomock call")
    }

    doAndReturnMethodRv := callRv.MethodByName("DoAndReturn")
    doAndReturnMethodRt := doAndReturnMethodRv.Type()

    if doAndReturnMethodRt.NumIn() != 1 {
        panic("DoAndReturn does not have exactly 1 parameter")
    }

    if doAndReturnMethodRt.NumOut() != 1 {
        panic("DoAndReturn does not have exactly 1 return value")
    }

    realMethodRt := doAndReturnMethodRt.In(0)
    if realMethodRt.Kind() == reflect.Interface {
        panic("mock is generated without type information; cannot determine function type")
    }

    // (*) Validate return value count.
    if rets != nil && realMethodRt.NumOut() != len(rets) {
        panic("number of return values does not match")
    }

    realMethodRv := reflect.MakeFunc(realMethodRt, func([]reflect.Value) []reflect.Value {
        retRvs := make([]reflect.Value, 0, realMethodRt.NumOut())
        // (*) Return zero values for hooked functions if return values are not registered.
        if rets == nil {
            for i := 0; i < realMethodRt.NumOut(); i++ {
                retRvs = append(retRvs, reflect.New(realMethodRt.Out(i)).Elem())
            }
        } else { // (*) Return the values as specified.
            for i, ret := range rets {
                retRv := reflect.ValueOf(ret)
                if retRv.Kind() == reflect.Invalid {
                    retRv = reflect.New(realMethodRt.Out(i)).Elem()
                }
                retRvs = append(retRvs, retRv)
            }
        }

        // Print to see if this has been called.
        fmt.Println("Called")

        return retRvs
    })

    retRvs := doAndReturnMethodRv.Call([]reflect.Value{realMethodRv})
    callRv = retRvs[0]

    anyTimesMethodRv := callRv.MethodByName("AnyTimes")
    anyTimesMethodRt := anyTimesMethodRv.Type()

    if anyTimesMethodRt.NumIn() != 0 {
        panic("AnyTimes does not have exactly 0 parameters")
    }

    if anyTimesMethodRt.NumOut() != 1 {
        panic("AnyTimes does not have exactly 1 return value")
    }

    anyTimesMethodRv.Call([]reflect.Value{})
}

// ..

AttachTrapの呼び出しを調整する必要があります。これにより、関数のリターン値を指定できるようになります。

Source (cmd/main/main.go)
cmd/main/main.go
package main

import (
    "context"
    "fmt"

    "github.com/johndoe/hoge"
    "github.com/johndoe/hoge/analyzer"
    "go.uber.org/mock/gomock"
)

func main() {
    ctrl := gomock.NewController(nil)
    mock := hoge.NewMockService(ctrl)

    azr := analyzer.NewAnalyzer()
    // (*) Pass the specified return values for each function.
    azr.AttachTrap(mock, map[string][]any{
        "A": {nil},
        "B": {&hoge.Out{
            X: "foobar",
            Y: 16,
        }, nil},
        "C": {25},
    })

    h := hoge.NewHogeImpl(mock)
    res, err := h.Do(context.Background())

    fmt.Println(res, err)
}
> go run ./cmd/main
Called
Called
Called
41 <nil>

関数は設定したリターン値を受け取っています!

これで、必要に応じて自由にリターン値を定義できるようになりました。もしAをエラーを返すように調整すると、Aからエラーが発生した時点で即座に返るため、「Called」は1回だけ表示されます。

Source (cmd/main/main.go)
cmd/main/main.go
package main

import (
    "context"
    "errors"
    "fmt"

    "github.com/johndoe/hoge"
    "github.com/johndoe/hoge/analyzer"
    "go.uber.org/mock/gomock"
)

func main() {
    ctrl := gomock.NewController(nil)
    mock := hoge.NewMockService(ctrl)

    azr := analyzer.NewAnalyzer()
    azr.AttachTrap(mock, map[string][]any{
        "A": {errors.New("mock error")}, // (*)
        "B": {&hoge.Out{
            X: "foobar",
            Y: 16,
        }, nil},
        "C": {25},
    })

    h := hoge.NewHogeImpl(mock)
    res, err := h.Do(context.Background())

    fmt.Println(res, err)
}
> go run ./cmd/main/main.go
Called
0 mock error

呼び出しをプログラム的に記録する必要があります。呼び出しを保存するために、シンプルな配列を追加しましょう。また、競合状態を防ぐためにsync.Mutexを追加します。

DoAndReturnに渡された関数内では、ロックをかけて呼び出しを配列に追加するだけです。

Source (analyzer/cs.go)
analyzer/cs.go
package analyzer

import "reflect"

type CallSignature struct {
    Type       reflect.Type
    MethodName string
    MethodType reflect.Type
    Returns    []any
}

func (c CallSignature) String() string {
    return c.Type.PkgPath() + "." + c.Type.Name() + "." + c.MethodName
}
Source (analyzer/analyzer.go)
analyzer/analyzer.go
package hoge

import (
    // ..
    "reflect"
    "sync"
)

type Analyzer struct {
    mu            sync.Mutex
    capturedCalls []*CallSignature
}

// ..

func (a *Analyzer) AttachTrap(r any, retsByMethodName map[string][]any) {
    rv := reflect.ValueOf(r)

    expectMethodRv := a.getExpectMethod(rv)

    rt := rv.Type()

    recorderRv := expectMethodRv.Call([]reflect.Value{})[0]
    if recorderRv.Kind() != reflect.Pointer {
        panic("return value from EXPECT is not a pointer to a possible recorder")
    }
    recorderRt := recorderRv.Type()

    for methodName := range retsByMethodName {
        _, ok := recorderRt.MethodByName(methodName)
        if !ok {
            panic("method name not found: " + methodName)
        }
    }

    for i := 0; i < recorderRt.NumMethod(); i++ {
        method := recorderRt.Method(i)
        if !method.IsExported() {
            continue
        }

        methodRv := recorderRv.MethodByName(method.Name)
        methodRt := methodRv.Type()

        if methodRt.NumOut() != 1 {
            panic("recorder method does not have exactly 1 return value")
        }

        // Pass gomock.Any() to every argument of recorder function.
        argRvs := make([]reflect.Value, 0, methodRt.NumIn())
        for i := 0; i < methodRt.NumIn(); i++ {
            argRvs = append(argRvs, condAnyRv)
        }

        rets := retsByMethodName[method.Name]

        // (*)
        cs := &CallSignature{
            Type:       rt.Elem(),
            MethodName: method.Name,
            MethodType: methodRt,
            Returns:    rets,
        }

        retRvs := methodRv.Call(argRvs)

        // (*)
        a.attachCall(cs, retRvs[0], rets)
    }
}

// (*)
func (a *Analyzer) attachCall(cs *CallSignature, callRv reflect.Value, rets []any) {
    if callRv.Kind() != reflect.Pointer {
        panic("return value from recorder method is not a pointer type to a possible gomock call")
    }

    doAndReturnMethodRv := callRv.MethodByName("DoAndReturn")
    doAndReturnMethodRt := doAndReturnMethodRv.Type()

    if doAndReturnMethodRt.NumIn() != 1 {
        panic("DoAndReturn does not have exactly 1 parameter")
    }

    if doAndReturnMethodRt.NumOut() != 1 {
        panic("DoAndReturn does not have exactly 1 return value")
    }

    realMethodRt := doAndReturnMethodRt.In(0)
    if realMethodRt.Kind() == reflect.Interface {
        panic("mock is generated without type information; cannot determine function type")
    }

    // Validate return value count.
    if rets != nil && realMethodRt.NumOut() != len(rets) {
        panic("number of return values does not match")
    }

    realMethodRv := reflect.MakeFunc(realMethodRt, func([]reflect.Value) []reflect.Value {
        // (*) Record as a captured call.
        a.mu.Lock()
        a.capturedCalls = append(a.capturedCalls, cs)
        a.mu.Unlock()

        retRvs := make([]reflect.Value, 0, realMethodRt.NumOut())
        // Return zero values for hooked functions if return values are not registered.
        if rets == nil {
            for i := 0; i < realMethodRt.NumOut(); i++ {
                retRvs = append(retRvs, reflect.New(realMethodRt.Out(i)).Elem())
            }
        } else {
            for i, ret := range rets {
                retRv := reflect.ValueOf(ret)
                if retRv.Kind() == reflect.Invalid {
                    retRv = reflect.New(realMethodRt.Out(i)).Elem()
                }
                retRvs = append(retRvs, retRv)
            }
        }

        return retRvs
    })

    retRvs := doAndReturnMethodRv.Call([]reflect.Value{realMethodRv})
    callRv = retRvs[0]

    anyTimesMethodRv := callRv.MethodByName("AnyTimes")
    anyTimesMethodRt := anyTimesMethodRv.Type()

    if anyTimesMethodRt.NumIn() != 0 {
        panic("AnyTimes does not have exactly 0 parameters")
    }

    if anyTimesMethodRt.NumOut() != 1 {
        panic("AnyTimes does not have exactly 1 return value")
    }

    anyTimesMethodRv.Call([]reflect.Value{})
}

// ..

capturedCallsはエクスポートされていないので、これを取得するためのゲッタ関数も追加しましょう。

Source (analyzer/analyzer.go)
analyzer/analyzer.go
package analyzer

// ..

func (a *Analyzer) GetCapturedCalls() []*CallSignature {
    return a.capturedCalls
}

// ..

main関数を調整して結果が確認できるようにするのを忘れないでください。GetCapturedCalls()は、Doが呼び出された後のすべての呼び出し情報を持っているべきです。

Source (cmd/main/main.go)
cmd/main/main.go
package main

import (
    "context"
    "fmt"

    "github.com/johndoe/hoge"
    "github.com/johndoe/hoge/analyzer"
    "go.uber.org/mock/gomock"
)

func main() {
    ctrl := gomock.NewController(nil)
    mock := hoge.NewMockService(ctrl)

    azr := analyzer.NewAnalyzer()
    azr.AttachTrap(mock, map[string][]any{
        "A": {nil},
        "B": {&hoge.Out{
            X: "foobar",
            Y: 16,
        }, nil},
        "C": {25},
    })

    h := hoge.NewHogeImpl(mock)
    res, err := h.Do(context.Background())

    for _, cs := range azr.GetCapturedCalls() {
        fmt.Println(cs)
    }
    fmt.Println(res, err)
}
> go run ./cmd/main
github.com/johndoe/hoge.MockService.A
github.com/johndoe/hoge.MockService.B
github.com/johndoe/hoge.MockService.C
41 <nil>

リターン値組合せから決定木を生成する

先ほど説明したコード分析実装をもとに、いくつかの設定に基づいてリターン値の決定木を作成できるようになりました。

この生成における決定木がどのように機能するかを説明するために、HogeImpl.Doの可能なテストケースを見てみましょう:

  1. Aがerror A1を返す。
  2. Aがerror A2を返す。
  3. Aがerrorを返さず、Bnilとerror Bを返す。
  4. Aがerrorを返さず、B&Out{X: "This is X", Y: 50}とerrorなしを返し、C35を返す。

グラフとして描くと:

これはリターン値の決定木であり、関数呼び出しの決定木ではないことに注意してください!

コード的には、これは定義されたすべてのパスを探索する再帰のケースです。test-genでは、これがanalyzer.Multiply関数を通じて行われており、これはモック設定を受け取り、すべての可能なリターン値を定義し、すべての可能なリターン値(Varyとして表現)に対して再帰的に探索を行います。この記事では、analyzer.Multiplyの実装については詳細に触れません。

MockMethodReturnsには、PassStop、またはVaryを渡すことができます。

  • Passは、決定木が探索を続けることを許可します。
MockMethod{
    Name: "A",
    Returns: Pass{nil},
}
  • Stopは、決定木がさらに探索を続けるのを停止します。必ずしも「リターン値を返す」ことを意味するわけではなく、未探索の関数のリターン値をゼロ値として残します。
MockMethod{
    Name: "A",
    Returns: Stop{nil},
}
  • Varyは、可能なリターン値を定義することを許可します。これにより、決定木に分岐が作成されます。Stopがない場合は、通常通り探索を続けます。
MockMethod{
    Name: "A",
    Returns: Vary{Pass{errA1}, Pass{errA2}, Pass{nil}},
}
MockMethod{
    Name: "A",
    Returns: Vary{Stop{errA1}, Stop{errA2}, Pass{nil}},
}
Source (cmd/main/main.go)
cmd/main/main.go
package main

import (
    "context"
    "errors"
    "fmt"

    "github.com/ezraisw/test-gen/analyzer"
    "github.com/johndoe/hoge"
    "go.uber.org/mock/gomock"
)

func main() {
    errA1 := errors.New("mock error A 1")
    errA2 := errors.New("mock error A 2")
    mockErrB := errors.New("mock error B")

    res := analyzer.Multiply([]*analyzer.MockConfig{
        {
            New: func(ctrl *gomock.Controller) any { return hoge.NewMockService(ctrl) },
            Methods: []*analyzer.MockMethod{
                {Name: "A", Returns: analyzer.Vary{
                    analyzer.Stop{errA1},
                    analyzer.Stop{errA2},
                    analyzer.Pass{nil},
                }},
                {Name: "B", Returns: analyzer.Vary{
                    analyzer.Stop{nil, mockErrB},
                    analyzer.Pass{&hoge.Out{X: "This is X", Y: 50}, nil},
                }},
                {Name: "C", Returns: analyzer.Pass{35}},
            },
        },
    }, func(mocks []any) {
        mock := mocks[0].(*hoge.MockService)
        u := hoge.NewHogeImpl(mock)
        out, err := u.Do(context.Background())

        fmt.Println(out, err)
    })

    for i, calls := range res {
        fmt.Println(i, calls)
    }
}
> go run ./cmd/main
0 mock error A 1
0 mock error A 2
0 mock error B
85 <nil>
0 [github.com/johndoe/hoge.MockService.A]
1 [github.com/johndoe/hoge.MockService.A]
2 [github.com/johndoe/hoge.MockService.A github.com/johndoe/hoge.MockService.B]
3 [github.com/johndoe/hoge.MockService.A github.com/johndoe/hoge.MockService.B github.com/johndoe/hoge.MockService.C]

決定木の結果からコードを生成する

リターン値決定木からすべての可能な呼び出しを生成した後は、あとは手元にあるデータを基にコードを生成するだけです。

生成ロジックはシンプルで、テンプレートを定義し、それにデータを埋め込み、テンプレートを組み合わせ、最後に生成されたコードをgo/formatパッケージで整形するだけです。

テンプレートを定義する

テンプレートはこのように定義できます:

var fileTmpl = template.Must(template.New("file").Parse(
    `// Code generated by test-gen. DO NOT EDIT.
package {{ .packageName }}

import (
    {{ range .imports }}"{{ . }}"
    {{ end }}
)

{{ range .testFuncs }}{{ . }}
{{ end }}
`))

template.Newの最初の引数はテンプレート名です。なんでもに設定できます。その後、Parseを呼び出してテンプレートを解析します。この時点でテンプレートはコンパイルされ、使用できる状態になります。template.Mustは、テンプレートエラーが発生した場合にパニックを引き起こすためのグローバル変数定義のための糖衣構文です。

テンプレートを使用するには、Execute関数を呼び出します。

fileTmpl.Execute(&buf, map[string]any{
    "packageName": pkgName,
    "imports":     imports,
    "testFuncs":   testFuncs,
})

リターン値を表現する

CallSignatureには、テンプレートに貼り付ける必要があるリターン値が含まれています。でも、現在はセマンティックな形式であり、コード形式に戻す必要があります。そのためには、ライブラリ"github.com/sanity-io/litter"を使用して、セマンティックな値をコードに戻すことができます。このライブラリは元のGoコードの形式をほぼ完全に模倣しているため、コンパイルエラーは発生しません。

Source (cmd/littertest/main.go)
cmd/littertest/main.go
package main

import (
    "fmt"

    "github.com/sanity-io/litter"
)

type Hoge struct {
    Foo          *Foo
    IntArray     []int
    StringArray  []string
    StringIntMap map[string]int
    StringBazMap map[string]*Baz
}

type Foo struct {
    Bar    *Bar
    Int    int
    Long   int64
    Float  float32
    Double float64
    Str    string
}

type Bar struct {
    Str string
}

type Baz struct {
    Short int16
    Byte  int8
    Uint  uint
}

var sq = litter.Options{
    HidePrivateFields:         true,
    DisablePointerReplacement: true,
    HomePackage:               "main",
}

func main() {
    value := Hoge{
        Foo: &Foo{
            Bar: &Bar{},
        },
        StringArray: []string{"A", "B", "C", "D"},
        IntArray:    []int{1, 2, 3, 4},
        StringIntMap: map[string]int{
            "A": 1,
            "B": 2,
        },
        StringBazMap: map[string]*Baz{
            "X": {},
            "Y": {},
            "Z": {},
        },
    }

    str := sq.Sdump(value)

    fmt.Println(str)
}
> go run ./cmd/littertest
Hoge{
  Foo: &Foo{
    Bar: &Bar{
      Str: "",
    },
    Int: 0,
    Long: 0,
    Float: 0.0,
    Double: 0.0,
    Str: "",
  },
  IntArray: []int{
    1,
    2,
    3,
    4,
  },
  StringArray: []string{
    "A",
    "B",
    "C",
    "D",
  },
  StringIntMap: map[string]int{
    "A": 1,
    "B": 2,
  },
  StringBazMap: map[string]*Baz{
    "X": &Baz{
      Short: 0,
      Byte: 0,
      Uint: 0,
    },
    "Y": &Baz{
      Short: 0,
      Byte: 0,
      Uint: 0,
    },
    "Z": &Baz{
      Short: 0,
      Byte: 0,
      Uint: 0,
    },
  },
}

フォーマットする

ソースコードをフォーマットするには、コードをバイトの配列として保存する必要があります。その後、そのバイトの配列に対してformat.Formatを呼び出すだけです。出力もバイトの配列として返され、ファイルに出力することができます。

var buf bytes.Buffer
// ..

formatted, err := format.Source(buf.Bytes())
if err != nil {
    panic(err)
}

最終の設定

単一の引数[]anyを受け取るrunner関数を作成する必要があります。

hoge_test.go
package hoge_test

import (
    "context"

    "github.com/johndoe/hoge"
)

func runHogeImplDo(mocks []any) {
    mockService := mocks[0].(*hoge.MockService)
    u := hoge.NewHogeImpl(mockService)
    _, _ = u.Do(context.Background())
}

その後、testgenのbuild tagを持つ"testgen.go"ファイルを作成する必要があります。これにより、実装と一緒にコンパイルされることはありません。

Source (testgen.go)
testgen.go
//go:build testgen
// +build testgen

package hoge

import (
    "context"
    "errors"

    testgen "github.com/ezraisw/test-gen"
    "github.com/ezraisw/test-gen/analyzer"
    "go.uber.org/mock/gomock"
)

func Generate() {
    testgen.Generate("hoge_test", "hoge_cov_test.go", []*testgen.Test{
        {
            Name: "TestCoverage_HogeImpl_Do",
            MockConfigs: []*analyzer.MockConfig{
                {
                    New: func(ctrl *gomock.Controller) any { return NewMockService(ctrl) },
                    Methods: []*analyzer.MockMethod{
                        {Name: "A", Returns: analyzer.Vary{
                            analyzer.Stop{errors.New("mock error A 1")},
                            analyzer.Stop{errors.New("mock error A 2")},
                            analyzer.Pass{nil},
                        }},
                        {Name: "B", Returns: analyzer.Vary{
                            analyzer.Stop{nil, errors.New("mock error B")},
                            analyzer.Pass{&Out{X: "This is X", Y: 50}, nil},
                        }},
                        {Name: "C", Returns: analyzer.Pass{35}},
                    },
                },
            },
            Run: func(mocks []any) {
                mock := mocks[0].(*MockService)
                u := NewHogeImpl(mock)
                _, _ = u.Do(context.Background())
            },
            TestRun: "runHogeImplDo",
        },
    })
}

この生成を実行するためにmainパッケージコードを作成せずに実行する良い方法はないです。なので、それを実行するために同じbuild tagを持つファイルも作成しましょう。

Source (testgenrun/main.go)
testgenrun/main.go
//go:build testgen
// +build testgen

package main

import "github.com/johndoe/hoge"

func main() {
	hoge.Generate()
}

そのmainパッケージでgo runを実行することで、ジェネレーターを実行できます。

> go run ./testgenrun
Generated test (hoge_cov_test.go)
hoge_cov_test.go
// Code generated by test-gen. DO NOT EDIT.
package hoge_test

import (
    "errors"
    "github.com/johndoe/hoge"
    "go.uber.org/mock/gomock"
    "strconv"
    "testing"
)

func TestCoverage_HogeImpl_Do(t *testing.T) {
    cases := []struct {
        setup func(ctrl *gomock.Controller) []any
    }{
        {
            setup: func(ctrl *gomock.Controller) []any {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(errors.New("mock error A 1"))

                return []any{
                    mockMockService,
                }
            },
        },
        {
            setup: func(ctrl *gomock.Controller) []any {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(errors.New("mock error A 2"))

                return []any{
                    mockMockService,
                }
            },
        },
        {
            setup: func(ctrl *gomock.Controller) []any {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(nil)
                mockMockService.EXPECT().B(gomock.Any(), gomock.Any()).Return(nil, errors.New("mock error B"))

                return []any{
                    mockMockService,
                }
            },
        },
        {
            setup: func(ctrl *gomock.Controller) []any {
                mockMockService := hoge.NewMockService(ctrl)

                mockMockService.EXPECT().A(gomock.Any(), gomock.Any()).Return(nil)
                mockMockService.EXPECT().B(gomock.Any(), gomock.Any()).Return(&hoge.Out{X: "This is X", Y: 50}, nil)
                mockMockService.EXPECT().C().Return(35)

                return []any{
                    mockMockService,
                }
            },
        },
    }

    for i, c := range cases {
        c := c
        t.Run("Case #"+strconv.Itoa(i), func(t *testing.T) {
            ctrl := gomock.NewController(t)
            mocks := c.setup(ctrl)

            runHogeImplDo(mocks)
        })
    }
}

実用的な用途は何?

上記のように、これの実用的な使い方の一つは、大量のモックを使用した大規模なテストケースを書く手間を避けることです。初期の記述はジェネレーターに任せます。もしコードカバレッジが気になる場合、特にif err != nilのケースに対しては、これを使用してコードカバレッジを増やすことができます。

生成されたテストだけを使いたくなるかもしれませんが、機能のすべての入力と出力をテストする少なくとも1つの手動テストを書くことは依然として重要です!自分自身では、生成されたテストと手動で書かれたテストの両方を持つのが最良です。

https://github.com/ezraisw/test-gen

SODA Engineering Blog
SODA Engineering Blog

Discussion