context.Context の Value のみを引き継いでコピーする方法
よろしくどーぞ。@knwoop です
この記事はGo Advent Calendar 2021の22日目の記事です。
この記事は Context をコピーするため Tips を紹介します。
ちょっと特殊なのは、Context にすでにセットされている timeout を無視して value のみを引き継ぐ方法を紹介したいと思います。
主なユースケース
複数の API に対して跨ぐ原子性を持つリクエストするケースがあるとします。以下の様なコードAPI2 がリクエスト失敗したときに API1 に Rollback します。
if err := api1.Request(ctx); err != nil {
return fmt.Errorf("error api1 Request: %w", err)
}
if err := api2.Request(ctx); err != nil {
if err := api1.Rollback(ctx); err != nil {
return fmt.Errorf("error api1 Rollback: %w", err)
}
return fmt.Errorf("error api2 Request: %w", err)
}
api2 が失敗したときに api1 にロールバックを呼ぶことで原子性を担保しています。
しかし、この実装では問題があります。
api2 のリクエスト中に context が cancel になったとき Rollback をコールしても失敗してしまいます。
以下、サンプルコードを書いてみました。
package main
import (
"context"
"errors"
"fmt"
"time"
)
type Api1 struct {
}
func (a *Api1) Request(ctx context.Context) error {
time.Sleep(3 * time.Second)
return nil
}
func (a *Api1) Rollback(ctx context.Context) error {
select {
case <-ctx.Done():
return errors.New("error timeout")
default:
}
fmt.Println("Rollback done")
return nil
}
type Api2 struct {
}
func (a *Api2) Request(ctx context.Context) error {
return errors.New("failed api request")
}
func main() {
api1 := &Api1{}
api2 := &Api2{}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
if err := api1.Request(ctx); err != nil {
fmt.Printf("error api1 Request: %s\n", err)
return
}
if err := api2.Request(ctx); err != nil {
if err := api1.Rollback(ctx); err != nil {
fmt.Printf("error api1 Rollback: %s\n", err)
return
}
fmt.Printf("error api2 Request: %s\n", err)
return
}
}
ちょっと解説すると
- context が 1 秒で設定されている
- api1 のリクエストは 3 秒かかります
- api2 のリクエストは必ず失敗するようになっています
- そして必ず context が cancel になっている状態で Rollback が呼ばれる
- context が cancel になっているため Rollback が失敗する
これを解決するには、Rollback を呼ぶ前に新しく context を生成するば解決します。
if err := api2.Request(ctx); err != nil {
// 今回追加したコード
childCtx, childCancel := context.WithTimeout(context.Background(), time.Second)
defer childCancel()
if err := api1.Rollback(ctx); err != nil {
fmt.Printf("error api1 Rollback: %s\n", err)
return
}
fmt.Printf("error api2 Request: %s\n", err)
return
}
無事 Rollback が成功しました。
Rollback done
error api1 Request: failed api request
しかしここで問題があります。実際のプロダクションのコードでは、context の中には以下のようなさまざまな metadata を格納することがほとんどです。
- trace id
- 認証トークン
- トランザクションキー
- etc...
このままでは、これらの値をうまく渡せなくなってしまいます。確かに context パッケージの func WithValue で1つずつ値をセットすることができますが、少し冗長です。それに metadata が増えたときに、都度追加する必要があります。
やっと本題に入りましたが、この問題を解決するための方法を紹介したいと思います。
実装してみる
まず、context.Context をフィールドをもつ struct を作成します。それから Value(key interface{})
以外は、ゼロ値を返すような context.Context インタフェースの実装を実装したカスタム Context を作成します。 Value を返すメソッドだけは、context.Context のフィールドから返すようにします。
以下の様になります。
type xcontext struct {
ctx context.Context
}
func (xcontext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (xcontext) Done() <-chan struct{} { return nil }
func (xcontext) Err() error { return nil }
func (x xcontext) Value(key interface{}) interface{} {
return x.ctx.Value(key)
}
次に context.Context を引数に カスタム Context を返す Detach 関数を作成します。
func Detach(ctx context.Context) context.Context {
return xcontext{ctx}
}
実際に試してみると、いい感じに timeout は切り離されて Value を引き継つけていることが確認できます。
func main() {
parentCtx := context.Background()
parentCtx = context.WithValue(parentCtx, "trace-id", "req123")
parentCtx = context.WithValue(parentCtx, "authenticate", "xxxxxxxx.yyyyyyyyy.zzzzzzzz")
ctx, cancel := context.WithCancel(parentCtx)
dtx := Detach(ctx)
cancel()
fmt.Println("ctx.Err():", ctx.Err())
fmt.Println("dtx.Err():", dtx.Err())
fmt.Println(`dtx.Value("trace-id"):`, dtx.Value("trace-id"))
ttx, cancel := context.WithTimeout(dtx, time.Millisecond)
cancel()
deadline, ok := ttx.Deadline()
fmt.Println("ttx.Deadline():", deadline, ok)
fmt.Println("ttx.Err():", ttx.Err())
fmt.Println("dtx.Err():", dtx.Err())
}
このように出力されます
ctx.Err(): context canceled
dtx.Err(): <nil>
dtx.Value("trace-id"): req123
ttx.Deadline(): 2009-11-10 23:00:00.001 +0000 UTC m=+0.001000001 true
ttx.Err(): context canceled
dtx.Err(): <nil>
Program exited.
最初のコードに追加してみます。
package main
import (
"context"
"errors"
"fmt"
"time"
)
type Api1 struct {
}
func (a *Api1) Request(ctx context.Context) error {
time.Sleep(3 * time.Second)
return nil
}
func (a *Api1) Rollback(ctx context.Context) error {
select {
case <-ctx.Done():
return errors.New("error timeout")
default:
}
fmt.Println(`ctx.Value("trace-id"):`, ctx.Value("trace-id"))
fmt.Println(`ctx.Value("authenticate"):`, ctx.Value("authenticate"))
fmt.Println("Rollback done")
return nil
}
type Api2 struct {
}
func (a *Api2) Request(ctx context.Context) error {
return errors.New("failed api request")
}
func main() {
api1 := &Api1{}
api2 := &Api2{}
parentCtx := context.Background()
parentCtx = context.WithValue(parentCtx, "trace-id", "req123")
parentCtx = context.WithValue(parentCtx, "authenticate", "xxxxxxxx.yyyyyyyyy.zzzzzzzz")
ctx, cancel := context.WithTimeout(parentCtx, 1*time.Second)
defer cancel()
if err := api1.Request(ctx); err != nil {
fmt.Printf("error api1 Request: %s\n", err)
return
}
if err := api2.Request(ctx); err != nil {
childCtx := Detach(ctx)
if err := api1.Rollback(childCtx); err != nil {
fmt.Printf("error api1 Rollback: %s\n", err)
return
}
fmt.Printf("error api2 Request: %s\n", err)
return
}
}
func Detach(ctx context.Context) context.Context {
return xcontext{ctx}
}
type xcontext struct {
ctx context.Context
}
func (xcontext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (xcontext) Done() <-chan struct{} { return nil }
func (xcontext) Err() error { return nil }
func (x xcontext) Value(key interface{}) interface{} {
return x.ctx.Value(key)
}
無事、ロールバックが成功して、trace id と トークンを取得できたことを確認できたと思います。
ctx.Value("trace-id"): req123
ctx.Value("authenticate"): xxxxxxxx.yyyyyyyyy.zzzzzzzz
Rollback done
error api1 Request: failed api request
Program exited.
以上、context の value のみをコピーする方法でした!
Discussion