💡

context.Context の Value のみを引き継いでコピーする方法

2021/12/24に公開

よろしくどーぞ。@knwoop です
この記事はGo Advent Calendar 2021の22日目の記事です。
https://qiita.com/advent-calendar/2021/go

この記事は Context をコピーするため Tips を紹介します。
ちょっと特殊なのは、Context にすでにセットされている timeout を無視して value のみを引き継ぐ方法を紹介したいと思います。

主なユースケース

複数の API に対して跨ぐ原子性を持つリクエストするケースがあるとします。以下の様なコードAPI2 がリクエスト失敗したときに API1 に Rollback します。

if err := api1.Request(ctx); err != nil {
  return fmt.Errorf("error api1 Request: %w", err)
}

if err := api2.Request(ctx); err != nil {
  if err := api1.Rollback(ctx); err != nil {
    return fmt.Errorf("error api1 Rollback: %w", err)
  }
  return fmt.Errorf("error api2 Request: %w", err)
}

api2 が失敗したときに api1 にロールバックを呼ぶことで原子性を担保しています。

しかし、この実装では問題があります。
api2 のリクエスト中に context が cancel になったとき Rollback をコールしても失敗してしまいます。
以下、サンプルコードを書いてみました。

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type Api1 struct {
}

func (a *Api1) Request(ctx context.Context) error {
	time.Sleep(3 * time.Second)
	return nil
}

func (a *Api1) Rollback(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return errors.New("error timeout")
	default:
	}

	fmt.Println("Rollback done")
	return nil
}

type Api2 struct {
}

func (a *Api2) Request(ctx context.Context) error {
	return errors.New("failed api request")
}

func main() {
	api1 := &Api1{}
	api2 := &Api2{}

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	if err := api1.Request(ctx); err != nil {
		fmt.Printf("error api1 Request: %s\n", err)
		return
	}

	if err := api2.Request(ctx); err != nil {
		if err := api1.Rollback(ctx); err != nil {
			fmt.Printf("error api1 Rollback: %s\n", err)
			return
		}
		fmt.Printf("error api2 Request: %s\n", err)
		return
	}
}

https://go.dev/play/p/FFD1KpaDc9q

ちょっと解説すると

  • context が 1 秒で設定されている
  • api1 のリクエストは 3 秒かかります
  • api2 のリクエストは必ず失敗するようになっています
  • そして必ず context が cancel になっている状態で Rollback が呼ばれる
  • context が cancel になっているため Rollback が失敗する

これを解決するには、Rollback を呼ぶ前に新しく context を生成するば解決します。

if err := api2.Request(ctx); err != nil {
	// 今回追加したコード
	childCtx, childCancel := context.WithTimeout(context.Background(), time.Second)
	defer childCancel()
	if err := api1.Rollback(ctx); err != nil {
		fmt.Printf("error api1 Rollback: %s\n", err)
		return
	}
	fmt.Printf("error api2 Request: %s\n", err)
	return
}

https://go.dev/play/p/Zl4WJ6SDaHl
無事 Rollback が成功しました。

Rollback done
error api1 Request: failed api request

しかしここで問題があります。実際のプロダクションのコードでは、context の中には以下のようなさまざまな metadata を格納することがほとんどです。

  • trace id
  • 認証トークン
  • トランザクションキー
  • etc...

このままでは、これらの値をうまく渡せなくなってしまいます。確かに context パッケージの func WithValue で1つずつ値をセットすることができますが、少し冗長です。それに metadata が増えたときに、都度追加する必要があります。

やっと本題に入りましたが、この問題を解決するための方法を紹介したいと思います。

実装してみる

まず、context.Context をフィールドをもつ struct を作成します。それから Value(key interface{}) 以外は、ゼロ値を返すような context.Context インタフェースの実装を実装したカスタム Context を作成します。 Value を返すメソッドだけは、context.Context のフィールドから返すようにします。
以下の様になります。

type xcontext struct {
	ctx context.Context
}

func (xcontext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (xcontext) Done() <-chan struct{}       { return nil }
func (xcontext) Err() error                  { return nil }
func (x xcontext) Value(key interface{}) interface{} {
	return x.ctx.Value(key)
}

次に context.Context を引数に カスタム Context を返す Detach 関数を作成します。

func Detach(ctx context.Context) context.Context {
	return xcontext{ctx}
}

実際に試してみると、いい感じに timeout は切り離されて Value を引き継つけていることが確認できます。

func main() {
	parentCtx := context.Background()
	parentCtx = context.WithValue(parentCtx, "trace-id", "req123")
	parentCtx = context.WithValue(parentCtx, "authenticate", "xxxxxxxx.yyyyyyyyy.zzzzzzzz")

	ctx, cancel := context.WithCancel(parentCtx)

	dtx := Detach(ctx)

	cancel()
	fmt.Println("ctx.Err():", ctx.Err())

	fmt.Println("dtx.Err():", dtx.Err())
	fmt.Println(`dtx.Value("trace-id"):`, dtx.Value("trace-id"))

	ttx, cancel := context.WithTimeout(dtx, time.Millisecond)
	cancel()
	deadline, ok := ttx.Deadline()
	fmt.Println("ttx.Deadline():", deadline, ok)
	fmt.Println("ttx.Err():", ttx.Err())
	fmt.Println("dtx.Err():", dtx.Err())
}

https://go.dev/play/p/jflykwHjvC4

このように出力されます

ctx.Err(): context canceled
dtx.Err(): <nil>
dtx.Value("trace-id"): req123
ttx.Deadline(): 2009-11-10 23:00:00.001 +0000 UTC m=+0.001000001 true
ttx.Err(): context canceled
dtx.Err(): <nil>

Program exited.

最初のコードに追加してみます。

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type Api1 struct {
}

func (a *Api1) Request(ctx context.Context) error {
	time.Sleep(3 * time.Second)
	return nil
}

func (a *Api1) Rollback(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return errors.New("error timeout")
	default:
	}

	fmt.Println(`ctx.Value("trace-id"):`, ctx.Value("trace-id"))
	fmt.Println(`ctx.Value("authenticate"):`, ctx.Value("authenticate"))
	fmt.Println("Rollback done")
	return nil
}

type Api2 struct {
}

func (a *Api2) Request(ctx context.Context) error {
	return errors.New("failed api request")
}

func main() {
	api1 := &Api1{}
	api2 := &Api2{}

	parentCtx := context.Background()
	parentCtx = context.WithValue(parentCtx, "trace-id", "req123")
	parentCtx = context.WithValue(parentCtx, "authenticate", "xxxxxxxx.yyyyyyyyy.zzzzzzzz")

	ctx, cancel := context.WithTimeout(parentCtx, 1*time.Second)
	defer cancel()
	if err := api1.Request(ctx); err != nil {
		fmt.Printf("error api1 Request: %s\n", err)
		return
	}

	if err := api2.Request(ctx); err != nil {
		childCtx := Detach(ctx)
		if err := api1.Rollback(childCtx); err != nil {
			fmt.Printf("error api1 Rollback: %s\n", err)
			return
		}
		fmt.Printf("error api2 Request: %s\n", err)
		return
	}
}

func Detach(ctx context.Context) context.Context {
	return xcontext{ctx}
}

type xcontext struct {
	ctx context.Context
}

func (xcontext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (xcontext) Done() <-chan struct{}       { return nil }
func (xcontext) Err() error                  { return nil }
func (x xcontext) Value(key interface{}) interface{} {
	return x.ctx.Value(key)
}

https://go.dev/play/p/Z2trW_ciYZO

無事、ロールバックが成功して、trace id と トークンを取得できたことを確認できたと思います。

ctx.Value("trace-id"): req123
ctx.Value("authenticate"): xxxxxxxx.yyyyyyyyy.zzzzzzzz
Rollback done
error api1 Request: failed api request

Program exited.

以上、context の value のみをコピーする方法でした!

Discussion