GoのAPIサーバーでミドルウェアを使用した共通処理の実装を理解する
はじめに
APIサーバーでミドルウェアを使用した共通処理の実装がいまいち理解できていなかったので、
1.22のバージョンアップでルーティング周りに改善が入ったGoで理解を深めていきたいと思います。
対象読者
- APIサーバーにおけるミドルウェアの実装を理解したい方
GoのContextとは?
以下の神本をお読みください。
無料なので。
GoでのContextは以下の機能を持ちます。
- 値を伝搬可能
- キャンセルを伝搬可能
- エラーを伝搬可能
- goroutine間での伝搬可能
使い方
- 初期Contextを生成
ctx := context.Background()
- 値を付加
newCtx := context.WithValue(ctx,"key","id")
- タイムアウトを設定
newctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
httpハンドラーにミドルウェアを設定する
サンプルコード
サンプルコード
package main
import "github.com/o-ga09/go122rcsample/api/internal/presenter"
func main() {
s := presenter.NewServer("8080")
s.Run()
}
package presenter
import (
"fmt"
"log/slog"
"net"
"net/http"
"os"
"os/signal"
"github.com/o-ga09/go122rcsample/api/internal/controller"
"github.com/o-ga09/go122rcsample/api/internal/middleware"
)
type Server struct {
port string
}
func NewServer(port string) *Server {
return &Server{port: port}
}
func (s *Server) Run() {
listen, err := net.Listen("tcp", fmt.Sprintf(":%s", s.port))
if err != nil {
panic(err)
}
mux := http.NewServeMux()
HealthCheckhandler := http.HandlerFunc(controller.Health)
HealthCheckhandler = middleware.UseMiddleware(HealthCheckhandler)
GetUserhandler := http.HandlerFunc(controller.GetUsers)
GetUserhandler = middleware.UseMiddleware(GetUserhandler)
CreateUserhandler := http.HandlerFunc(controller.CreateUser)
CreateUserhandler = middleware.UseMiddleware(CreateUserhandler)
mux.HandleFunc("GET /", HealthCheckhandler)
mux.HandleFunc("GET /users/{id}", GetUserhandler)
mux.HandleFunc("POST /users", CreateUserhandler)
slog.Info("starting server")
go func() {
err = http.Serve(listen, mux)
if err != nil {
panic(err)
}
}()
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt)
<-quit
slog.Info("stopping srever")
}
package controller
import (
"database/sql"
"encoding/json"
"log/slog"
"net/http"
_ "github.com/go-sql-driver/mysql"
"github.com/o-ga09/go122rcsample/api/internal/config"
"github.com/o-ga09/go122rcsample/api/internal/middleware"
)
func GetUsers(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
cfg, _ := config.New()
db, err := sql.Open("mysql", cfg.Database_url)
if err != nil {
slog.Log(r.Context(), middleware.SeverityError, "con not get environment value")
return
}
defer func() {
slog.Log(r.Context(), middleware.SeverityInfo, "db disconnect ....", "requestId", middleware.GetRequestID(r.Context()))
db.Close()
}()
query := "SELECT * FROM users WHERE id = ?"
var uid int
var name string
var email string
err = db.QueryRow(query, id).Scan(&uid, &email, &name)
if err != nil {
if err == sql.ErrNoRows {
slog.Log(r.Context(), middleware.SeverityInfo, "no rows", "requestId", middleware.GetRequestID(r.Context()))
return
}
slog.Log(r.Context(), middleware.SeverityError, "panic error", "error message", err, "requestId", middleware.GetRequestID(r.Context()))
}
result := struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}{
ID: uid,
Name: name,
Email: email,
}
slog.Log(r.Context(), middleware.SeverityInfo, "result", "data", result, "requestId", middleware.GetRequestID(r.Context()))
middleware.Response(&w, http.StatusOK, result)
}
func CreateUser(w http.ResponseWriter, r *http.Request) {
reqBody := struct {
Name string `json:"name"`
Email string `json:"email"`
}{}
err := json.NewDecoder(r.Body).Decode(&reqBody)
if err != nil {
slog.Log(r.Context(), middleware.SeverityInfo, "can not get request body", "requestId", middleware.GetRequestID(r.Context()))
return
}
cfg, _ := config.New()
db, err := sql.Open("mysql", cfg.Database_url)
if err != nil {
slog.Log(r.Context(), middleware.SeverityError, "db connect error...")
panic(err)
}
defer func() {
slog.Log(r.Context(), middleware.SeverityInfo, "db disconnect ....", "requestId", middleware.GetRequestID(r.Context()))
db.Close()
}()
name := reqBody.Name
email := reqBody.Email
sql := "INSERT INTO users (name, email) VALUES (?, ?)"
_, err = db.Exec(sql, name, email)
if err != nil {
slog.Log(r.Context(), middleware.SeverityError, "can not insert", "error message", err, "requestId", middleware.GetRequestID(r.Context()))
return
}
middleware.Response(&w, http.StatusCreated, reqBody)
}
package controller
import (
"fmt"
"net/http"
)
func Health(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hellow World go 1.22 ! from GET\n")
}
package middleware
import (
"context"
"net/http"
"time"
"github.com/o-ga09/go122rcsample/api/pkg"
)
type RequestId string
// AddIDはリクエスト毎にIDを付与するmiddlewareです。
func AddID(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// IDを生成してcontextに保存
id := pkg.GenerateID()
ctx := context.WithValue(r.Context(), RequestId("requestId"), id)
// 次のハンドラーに渡す
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// WithTimeoutはIDを追加するmiddlewareです。
func WithTimeout(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Context() == nil {
r = r.WithContext(context.Background())
}
// タイムアウトを設定
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel() // 処理が終了したらキャンセルする
// 次のハンドラーを実行し、タイムアウトが発生した場合はエラーメッセージを出力
done := make(chan struct{})
go func() {
defer close(done)
next.ServeHTTP(w, r)
}()
select {
case <-done:
// ハンドラーが正常に終了した場合は何もしない
return
case <-ctx.Done():
http.Error(w, "Timeout", http.StatusRequestTimeout)
}
})
}
func GetRequestID(ctx context.Context) string {
return ctx.Value(RequestId("requestId")).(string)
}
func UseMiddleware(handler http.HandlerFunc) http.HandlerFunc {
handler = WithTimeout(handler)
handler = RequestLogger(handler)
handler = AddID(handler)
handler = Logger(handler)
return handler
}
package middleware
import (
"context"
"fmt"
"log/slog"
"net/http"
"os"
"cloud.google.com/go/logging"
"github.com/o-ga09/go122rcsample/api/internal/config"
"go.opentelemetry.io/otel/trace"
)
// cloud logging の Log level 定義
var (
Severitydefault = slog.Level(logging.Default)
SeverityInfo = slog.Level(logging.Info)
SeverityWarn = slog.Level(logging.Warning)
SeverityError = slog.Level(logging.Error)
SeverityNotice = slog.Level(logging.Notice)
)
// traceId , spanId 追加
type traceHandler struct {
slog.Handler
projectID string
}
// traceHandler 実装
func (h *traceHandler) Enabled(ctx context.Context, l slog.Level) bool {
return h.Handler.Enabled(ctx, l)
}
func (h *traceHandler) Handle(ctx context.Context, r slog.Record) error {
if sc := trace.SpanContextFromContext(ctx); sc.IsValid() {
trace := fmt.Sprintf("projects/%s/traces/%s", h.projectID, sc.TraceID().String())
r.AddAttrs(slog.String("logging.googleapis.com/trace", trace),
slog.String("logging.googleapis.com/spanId", sc.SpanID().String()))
}
return h.Handler.Handle(ctx, r)
}
func (h *traceHandler) WithAttr(attrs []slog.Attr) slog.Handler {
return &traceHandler{h.Handler.WithAttrs(attrs), h.projectID}
}
func (h *traceHandler) WithGroup(g string) slog.Handler {
return h.Handler.WithGroup(g)
}
// logger 生成関数
func Logger(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
replacer := func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.MessageKey {
a.Key = "message"
}
if a.Key == slog.LevelKey {
a.Key = "severity"
a.Value = slog.StringValue(logging.Severity(a.Value.Any().(slog.Level)).String())
}
if a.Key == slog.SourceKey {
a.Key = "logging.googleapis.com/sourceLocation"
}
return a
}
cfg, _ := config.New()
projectID := cfg.ProjectID
h := traceHandler{slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: true, ReplaceAttr: replacer}), projectID}
newh := h.WithAttr([]slog.Attr{
slog.Group("logging.googleapis.com/labels",
slog.String("app", "MH-API"),
slog.String("env", cfg.Env),
),
})
logger := slog.New(newh)
slog.SetDefault(logger)
next.ServeHTTP(w, r)
})
}
package middleware
import (
"log/slog"
"net/http"
"time"
)
type RequestInfo struct {
ContentsLength int64
Path string
SourceIP string
Query string
UserAgent string
Errors string
Elapsed time.Duration
}
func (r *RequestInfo) LogValue() interface{} { // Assuming slog expects an interface{}
return map[string]interface{}{
"contents_length": r.ContentsLength,
"path": r.Path,
"sourceIP": r.SourceIP,
"query": r.Query,
"user_agent": r.UserAgent,
"errors": r.Errors,
"elapsed": r.Elapsed.String(),
}
}
func RequestLogger(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Log(r.Context(), SeverityInfo, "処理開始", "requestId", GetRequestID(r.Context()))
start := time.Now()
next.ServeHTTP(w, r)
req := RequestInfo{
ContentsLength: r.ContentLength,
Path: r.RequestURI,
SourceIP: r.RemoteAddr,
Query: r.URL.RawQuery,
UserAgent: r.UserAgent(),
Errors: "errors",
Elapsed: time.Since(start),
}
slog.Log(r.Context(), SeverityInfo, "処理終了", "Request", req.LogValue(), "requestId", GetRequestID(r.Context())) // Adjust logging context as needed
})
}
package middleware
import (
"encoding/json"
"net/http"
)
func Response(w *http.ResponseWriter, status int, message interface{}) {
json, err := json.Marshal(message)
if err != nil {
(*w).Header().Set("Content-Type", "application/json")
(*w).WriteHeader(http.StatusInternalServerError)
(*w).Write([]byte(`{"message": "error marshalling json"}`))
return
}
(*w).Header().Set("Content-Type", "application/json")
(*w).WriteHeader(status)
(*w).Write(json)
}
package config
import "github.com/caarlos0/env"
type Config struct {
Env string `env:"ENV" envDefault:"dev"`
Port string `env:"PORT" envDefault:"80"`
Database_url string `env:"DATABASE_URL" envDefult:""`
ProjectID string `env:"PROJECTID" envDefault:""`
}
func New() (*Config, error) {
cfg := &Config{}
if err := env.Parse(cfg); err != nil {
return nil, err
}
return cfg, nil
}
解説
例として、requestIdを付与するミドルウェアを見てみましょう。
func AddID(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// IDを生成してcontextに保存
id := pkg.GenerateID()
ctx := context.WithValue(r.Context(), RequestId("requestId"), id)
// 次のハンドラーに渡す
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func AddID(next http.HandlerFunc) http.HandlerFunc
は、引数に次のハンドラーに渡すためのhttp.handler
型の変数を、
戻り値にhttp.HandlerFunc
型を持ちます。
処理の最後に、next.ServeHTTP
で次のハンドラに処理を渡します。
next.ServeHTTP
を渡してミドルウェアを繋いで最後にアプリケーションハンドラーを呼び出します。
また、以下のようにすることでnext.ServeHTTP
で処理を次のハンドラーに渡すけど、アプリケーションハンドラーまで行ったところで、
また戻って処理するようにできます。
func RequestLogger(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Log(r.Context(), SeverityInfo, "処理開始", "requestId", GetRequestID(r.Context()))
start := time.Now()
next.ServeHTTP(w, r)
req := RequestInfo{
ContentsLength: r.ContentLength,
Path: r.RequestURI,
SourceIP: r.RemoteAddr,
Query: r.URL.RawQuery,
UserAgent: r.UserAgent(),
Errors: "errors",
Elapsed: time.Since(start),
}
slog.Log(r.Context(), SeverityInfo, "処理終了", "Request", req.LogValue(), "requestId", GetRequestID(r.Context()))
})
}
ハンドラーにミドルウェアを登録するには以下のようにします。
UseMiddleware
は、ただ、一つにまとめただけです。
本当は、ginのように、r.Use
のようにしたかったのですが、
これからやってみたいところではあります。
なので、ハンドラー毎に、UseMiddleware
を呼び出さないといけないです。
呼び出し順は、Logger
→AddId
→RequestLogger
→WithTimeout
→HealthCheck
→RequestLogger
(nextの後ろの処理)となります。
HealthCheckhandler := http.HandlerFunc(controller.Health)
HealthCheckhandler = middleware.UseMiddleware(HealthCheckhandler)
func UseMiddleware(handler http.HandlerFunc) http.HandlerFunc {
handler = WithTimeout(handler)
handler = RequestLogger(handler)
handler = AddID(handler)
handler = Logger(handler)
return handler
}
まとめ
APIのリクエストを処理するミドルウェアの部分は理解が曖昧でしたが、自分で実装することで、
かなり理解が進みました。
言語はなんでも良いので、一度は、フレームワークを使わずに実装してみてはいかがでしょうか。
よろしければいいねをよろしくお願いいたします。
Discussion