Open21

goでmiddlewareの後に実行される関数を注入したい

podhmopodhmo

以下のようなことがしたい

  • middlewareで特定のapi用の基本的なfixtureを設定したい
  • httptest.Recorder越しにhandlerを直接呼ぶタイミングでrequestにちょっとした装飾を加えたい

ここで通常はmiddlewareが優先される。handlerの中でmiddlewareが呼ばれるので順序的にはそれはそうなのだけど、後者の方を優先するような仕組みを作りたい。

----------------------------------------
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8

{"value": "", "ok": false} <nil>
----------------------------------------
PASS
ok      github.com/podhmo/individual-sandbox/daily/20230629/example_go/00test   0.008s

repository: https://github.com/podhmo/individual-sandbox/tree/master/daily/20230629/example_go

code

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"testing"
)

type ctxKey string

const (
	ctxValueKey ctxKey = "value"
)

func Handler(w http.ResponseWriter, req *http.Request) {
	v, ok := req.Context().Value(ctxValueKey).(string)
	fmt.Println("get ", v, ok)
	fmt.Fprintf(w, `{"value": %q, "ok": %t}`, v, ok)
}

func TestIt(t *testing.T) {
	middleware := func(inner http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			inner.ServeHTTP(w, req)
		})
	}
	h := middleware(http.HandlerFunc(Handler))

	rec := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/", nil)
	h.ServeHTTP(rec, req)
	res := rec.Result()
	if want, got := http.StatusOK, res.StatusCode; want != got {
		t.Fatalf("unexpected status code: want=%d, but got=%d", want, got)
	}
	b, err := httputil.DumpResponse(rec.Result(), true)
	fmt.Println("----------------------------------------")
	fmt.Println(string(b), err)
	fmt.Println("----------------------------------------")
}
podhmopodhmo

とりあえず、contextへのinjectで設定を済ませることにする。

podhmopodhmo

以下のようにやるとmiddlewareが優先される(それはそう)

diff --git a/daily/20230629/example_go/00test/main_test.go b/daily/20230629/example_go/00test/main_test.go
index 0f4a1dc1..7b406177 100644
--- a/daily/20230629/example_go/00test/main_test.go
+++ b/daily/20230629/example_go/00test/main_test.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"net/http/httptest"
@@ -23,6 +24,7 @@ func Handler(w http.ResponseWriter, req *http.Request) {
 func TestIt(t *testing.T) {
 	middleware := func(inner http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, "MIDDLEWARE"))
 			inner.ServeHTTP(w, req)
 		})
 	}
@@ -30,6 +32,7 @@ func TestIt(t *testing.T) {
 
 	rec := httptest.NewRecorder()
 	req := httptest.NewRequest("GET", "/", nil)
+	req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, "REQUEST"))
 	h.ServeHTTP(rec, req)
 	res := rec.Result()
 	if want, got := http.StatusOK, res.StatusCode; want != got {
podhmopodhmo

ここで本当はMIDDLEWAREではなくREQUESTが帰ってくるようになって欲しい。

get  MIDDLEWARE true
----------------------------------------
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8

{"value": "MIDDLEWARE", "ok": true} <nil>
----------------------------------------
PASS
ok  	github.com/podhmo/individual-sandbox/daily/20230629/example_go/00test	0.006s
podhmopodhmo

どのような方法があるか?

podhmopodhmo

必ずmiddlewareなりhookを呼ばなくてはいけない部分が面倒。

podhmopodhmo

気持ちとしてget側では以下のような記述を壊したくない。

v, ok := ctx.Value(ctxValueKey).(string)
podhmopodhmo

requestに直接注入するのではなくmiddlewareにする

ちょっとめんどくさいのは、ginやechoなどの場合はleft to rightにmiddlewareが登録されるが、net/httpの場合には単にネストした関数呼び出しになるだけなのでright to leftに実行されてく。というわけでframeworkによって挙動が不安定になるかもしれない。

podhmopodhmo

例えば、以下のように変更して呼び出し関係をわかりやすくする。

diff --git a/daily/20230629/example_go/00test/main_test.go b/daily/20230629/example_go/00test/main_test.go
index 0f4a1dc1..aef74b65 100644
--- a/daily/20230629/example_go/00test/main_test.go
+++ b/daily/20230629/example_go/00test/main_test.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"net/http/httptest"
@@ -15,21 +16,32 @@ const (
 )
 
 func Handler(w http.ResponseWriter, req *http.Request) {
-	v, ok := req.Context().Value(ctxValueKey).(string)
-	fmt.Println("get ", v, ok)
+	v, ok := req.Context().Value(ctxValueKey).([]string)
 	fmt.Fprintf(w, `{"value": %q, "ok": %t}`, v, ok)
 }
 
 func TestIt(t *testing.T) {
 	middleware := func(inner http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			v, _ := req.Context().Value(ctxValueKey).([]string)
+			req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "MIDDLEWARE")))
 			inner.ServeHTTP(w, req)
 		})
 	}
 	h := middleware(http.HandlerFunc(Handler))
+	middleware2 := func(inner http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			v, _ := req.Context().Value(ctxValueKey).([]string)
+			req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "MIDDLEWARE2")))
+			inner.ServeHTTP(w, req)
+		})
+	}
+	h = middleware2(h)
 
 	rec := httptest.NewRecorder()
 	req := httptest.NewRequest("GET", "/", nil)
+	v, _ := req.Context().Value(ctxValueKey).([]string)
+	req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "REQUEST")))
 	h.ServeHTTP(rec, req)
 	res := rec.Result()
 	if want, got := http.StatusOK, res.StatusCode; want != got {

帰ってくるのは

----------------------------------------
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8

{"value": ["REQUEST" "MIDDLEWARE2" "MIDDLEWARE"], "ok": true} <nil>
----------------------------------------
PASS
ok  	github.com/podhmo/individual-sandbox/daily/20230629/example_go/00test	0.005s

素直にset/getしたときにはこのsliceの末尾が取得される。今回の例ではMIDDLEWAREが返ってくる。

podhmopodhmo

値の取得時に優先順位をつけることにする

これとの組み合わせで、先頭を取り出すというルールにしてあげれば、見かけ上LIFOのような形で扱えなくもない(実際のところは真逆)。

しかし、これはこれで、実行時の処理を書き換えているところが気持ち悪い。

podhmopodhmo

middlewareの中でcontextに注入するように変えてみる。これは他のmiddlewareの影響を壊す。

--- 00test/main_test.go	2023-06-29 13:52:49.137498322 +0900
+++ 01with-middleware/main_test.go	2023-06-29 13:53:12.049987533 +0900
@@ -21,10 +21,16 @@
 }
 
 func TestIt(t *testing.T) {
+	hooks := []func(context.Context) context.Context{}
 	middleware := func(inner http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-			v, _ := req.Context().Value(ctxValueKey).([]string)
-			req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "MIDDLEWARE")))
+			ctx := req.Context()
+			v, _ := ctx.Value(ctxValueKey).([]string)
+			ctx = context.WithValue(ctx, ctxValueKey, append(v, "MIDDLEWARE"))
+			for _, m := range hooks {
+				ctx = m(ctx)
+			}
+			req = req.WithContext(ctx)
 			inner.ServeHTTP(w, req)
 		})
 	}
@@ -40,8 +46,11 @@
 
 	rec := httptest.NewRecorder()
 	req := httptest.NewRequest("GET", "/", nil)
-	v, _ := req.Context().Value(ctxValueKey).([]string)
-	req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "REQUEST")))
+
+	hooks = append(hooks, func(ctx context.Context) context.Context {
+		v, _ := ctx.Value(ctxValueKey).([]string)
+		return context.WithValue(ctx, ctxValueKey, append(v, "REQUEST"))
+	})
 	h.ServeHTTP(rec, req)
 	res := rec.Result()
 	if want, got := http.StatusOK, res.StatusCode; want != got {
podhmopodhmo
----------------------------------------
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8

{"value": ["MIDDLEWARE2" "MIDDLEWARE" "REQUEST"], "ok": true} <nil>
----------------------------------------
PASS
ok  	github.com/podhmo/individual-sandbox/daily/20230629/example_go/01with-middleware	0.006s
podhmopodhmo

何らかのhandlerを別途用意してそこに注入できるようにする

どうせ、これが必要になるのはcontextの注入と取り出しの時だけなのだから、そういう境界を作ってしまう。

  • InjectComponentsMiddleware
  • InjectComponentsFunction (for request)
podhmopodhmo
--- 00test/main_test.go	2023-06-29 14:09:25.738699941 +0900
+++ 02custom-func/main_test.go	2023-06-29 14:09:32.650363344 +0900
@@ -12,7 +12,8 @@
 type ctxKey string
 
 const (
-	ctxValueKey ctxKey = "value"
+	ctxValueKey  ctxKey = "value"
+	ctxInjectKey ctxKey = "inject"
 )
 
 func Handler(w http.ResponseWriter, req *http.Request) {
@@ -20,15 +21,38 @@
 	fmt.Fprintf(w, `{"value": %q, "ok": %t}`, v, ok)
 }
 
-func TestIt(t *testing.T) {
-	middleware := func(inner http.Handler) http.Handler {
+func InjectComponentsMiddlware(hooks ...func(context.Context) context.Context) func(http.Handler) http.Handler {
+	return func(inner http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-			v, _ := req.Context().Value(ctxValueKey).([]string)
-			req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "MIDDLEWARE")))
+			ctx := req.Context()
+			for _, inject := range hooks {
+				ctx = inject(ctx)
+			}
+			innerHooks, ok := ctx.Value(ctxInjectKey).([]func(context.Context) context.Context)
+			if ok {
+				for _, inject := range innerHooks {
+					ctx = inject(ctx)
+				}
+			}
+			req = req.WithContext(ctx)
 			inner.ServeHTTP(w, req)
 		})
 	}
+}
+
+func InjectComponentsFunction(req *http.Request, hooks ...func(context.Context) context.Context) *http.Request {
+	ctx := req.Context()
+	innerHooks, _ := ctx.Value(ctxInjectKey).([]func(context.Context) context.Context)
+	return req.WithContext(context.WithValue(ctx, ctxInjectKey, append(innerHooks, hooks...)))
+}
+
+func TestIt(t *testing.T) {
+	middleware := InjectComponentsMiddlware(func(ctx context.Context) context.Context {
+		v, _ := ctx.Value(ctxValueKey).([]string)
+		return context.WithValue(ctx, ctxValueKey, append(v, "MIDDLEWARE"))
+	})
 	h := middleware(http.HandlerFunc(Handler))
+
 	// middleware2 := func(inner http.Handler) http.Handler {
 	// 	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 	// 		v, _ := req.Context().Value(ctxValueKey).([]string)
@@ -40,9 +64,12 @@
 
 	rec := httptest.NewRecorder()
 	req := httptest.NewRequest("GET", "/", nil)
-	v, _ := req.Context().Value(ctxValueKey).([]string)
-	req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "REQUEST")))
+	req = InjectComponentsFunction(req, func(ctx context.Context) context.Context {
+		v, _ := ctx.Value(ctxValueKey).([]string)
+		return context.WithValue(ctx, ctxValueKey, append(v, "REQUEST"))
+	})
 	h.ServeHTTP(rec, req)
+
 	res := rec.Result()
 	if want, got := http.StatusOK, res.StatusCode; want != got {
 		t.Fatalf("unexpected status code: want=%d, but got=%d", want, got)
podhmopodhmo
----------------------------------------
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8

{"value": ["MIDDLEWARE" "REQUEST"], "ok": true} <nil>
----------------------------------------
PASS
ok  	github.com/podhmo/individual-sandbox/daily/20230629/example_go/02custom-func	0.005s
podhmopodhmo

:warning: ここでInjectComponentsMiddlewareが呼ばれていない場合の挙動はどうなるか?(テストでハマりそう)

podhmopodhmo

値の取得時に優先順位をつけることにする

sliceで持つ以外にsetの意味をget or setのような形で定義するようにしてみる。これは個別のものに定義していけばそれぞれに対してのみうまくいく。めんどくさいし。main.goでの上書きをしようとして無視されるみたいな挙動になる危険性があり、あまりうれしくない気もする。

podhmopodhmo
--- ../00test/main_test.go	2023-06-29 14:09:25.738699941 +0900
+++ main_test.go	2023-06-29 14:29:48.431461970 +0900
@@ -16,15 +16,23 @@
 )
 
 func Handler(w http.ResponseWriter, req *http.Request) {
-	v, ok := req.Context().Value(ctxValueKey).([]string)
+	v, ok := req.Context().Value(ctxValueKey).(string)
 	fmt.Fprintf(w, `{"value": %q, "ok": %t}`, v, ok)
 }
 
+func GetOrSetValue(ctx context.Context, v string) context.Context {
+	_, ok := ctx.Value(ctxValueKey).(string)
+	if ok {
+		return ctx
+	}
+	return context.WithValue(ctx, ctxValueKey, v)
+}
+
 func TestIt(t *testing.T) {
 	middleware := func(inner http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
-			v, _ := req.Context().Value(ctxValueKey).([]string)
-			req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "MIDDLEWARE")))
+			ctx := GetOrSetValue(req.Context(), "MIDDLEWARE")
+			req = req.WithContext(ctx)
 			inner.ServeHTTP(w, req)
 		})
 	}
@@ -40,8 +48,7 @@
 
 	rec := httptest.NewRecorder()
 	req := httptest.NewRequest("GET", "/", nil)
-	v, _ := req.Context().Value(ctxValueKey).([]string)
-	req = req.WithContext(context.WithValue(req.Context(), ctxValueKey, append(v, "REQUEST")))
+	req = req.WithContext(GetOrSetValue(req.Context(), "REQUEST"))
 	h.ServeHTTP(rec, req)
 	res := rec.Result()
 	if want, got := http.StatusOK, res.StatusCode; want != got {
podhmopodhmo
----------------------------------------
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8

{"value": "REQUEST", "ok": true} <nil>
----------------------------------------
PASS
ok  	github.com/podhmo/individual-sandbox/daily/20230629/example_go/03get-or-set	0.007s
podhmopodhmo

期待通りの挙動を示すが何も解決してない気もする。requestへの装飾が優先されるような動きをして欲しい感覚があるがこれはデフォルトの挙動を記憶して利用してるだけ。