🏕️

GoのAPIサーバーでミドルウェアを使用した共通処理の実装を理解する

2024/02/19に公開

はじめに

APIサーバーでミドルウェアを使用した共通処理の実装がいまいち理解できていなかったので、
1.22のバージョンアップでルーティング周りに改善が入ったGoで理解を深めていきたいと思います。

対象読者

  • APIサーバーにおけるミドルウェアの実装を理解したい方

GoのContextとは?

以下の神本をお読みください。
無料なので。

https://zenn.dev/hsaki/books/golang-context

GoでのContextは以下の機能を持ちます。

  • 値を伝搬可能
  • キャンセルを伝搬可能
  • エラーを伝搬可能
  • goroutine間での伝搬可能

使い方

  1. 初期Contextを生成
context
ctx := context.Background()
  1. 値を付加
context
newCtx := context.WithValue(ctx,"key","id")
  1. タイムアウトを設定
context
newctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

httpハンドラーにミドルウェアを設定する

サンプルコード

サンプルコード
main.go
package main

import "github.com/o-ga09/go122rcsample/api/internal/presenter"

func main() {
	s := presenter.NewServer("8080")
	s.Run()
}
api/presenter/server.go
package presenter

import (
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"os"
	"os/signal"

	"github.com/o-ga09/go122rcsample/api/internal/controller"
	"github.com/o-ga09/go122rcsample/api/internal/middleware"
)

type Server struct {
	port string
}

func NewServer(port string) *Server {
	return &Server{port: port}
}

func (s *Server) Run() {
	listen, err := net.Listen("tcp", fmt.Sprintf(":%s", s.port))
	if err != nil {
		panic(err)
	}

	mux := http.NewServeMux()

	HealthCheckhandler := http.HandlerFunc(controller.Health)
	HealthCheckhandler = middleware.UseMiddleware(HealthCheckhandler)

	GetUserhandler := http.HandlerFunc(controller.GetUsers)
	GetUserhandler = middleware.UseMiddleware(GetUserhandler)

	CreateUserhandler := http.HandlerFunc(controller.CreateUser)
	CreateUserhandler = middleware.UseMiddleware(CreateUserhandler)

	mux.HandleFunc("GET /", HealthCheckhandler)
	mux.HandleFunc("GET /users/{id}", GetUserhandler)
	mux.HandleFunc("POST /users", CreateUserhandler)

	slog.Info("starting server")
	go func() {
		err = http.Serve(listen, mux)
		if err != nil {
			panic(err)
		}
	}()

	quit := make(chan os.Signal, 1)
	signal.Notify(quit, os.Interrupt)
	<-quit
	slog.Info("stopping srever")
}
api/controller/user.go
package controller

import (
	"database/sql"
	"encoding/json"
	"log/slog"
	"net/http"

	_ "github.com/go-sql-driver/mysql"
	"github.com/o-ga09/go122rcsample/api/internal/config"
	"github.com/o-ga09/go122rcsample/api/internal/middleware"
)

func GetUsers(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	cfg, _ := config.New()
	db, err := sql.Open("mysql", cfg.Database_url)
	if err != nil {
		slog.Log(r.Context(), middleware.SeverityError, "con not get environment value")
		return
	}
	defer func() {
		slog.Log(r.Context(), middleware.SeverityInfo, "db disconnect ....", "requestId", middleware.GetRequestID(r.Context()))
		db.Close()
	}()

	query := "SELECT * FROM users WHERE id = ?"
	var uid int
	var name string
	var email string
	err = db.QueryRow(query, id).Scan(&uid, &email, &name)

	if err != nil {
		if err == sql.ErrNoRows {
			slog.Log(r.Context(), middleware.SeverityInfo, "no rows", "requestId", middleware.GetRequestID(r.Context()))
			return
		}
		slog.Log(r.Context(), middleware.SeverityError, "panic error", "error message", err, "requestId", middleware.GetRequestID(r.Context()))
	}
	result := struct {
		ID    int    `json:"id"`
		Name  string `json:"name"`
		Email string `json:"email"`
	}{
		ID:    uid,
		Name:  name,
		Email: email,
	}
	slog.Log(r.Context(), middleware.SeverityInfo, "result", "data", result, "requestId", middleware.GetRequestID(r.Context()))
	middleware.Response(&w, http.StatusOK, result)
}

func CreateUser(w http.ResponseWriter, r *http.Request) {
	reqBody := struct {
		Name  string `json:"name"`
		Email string `json:"email"`
	}{}

	err := json.NewDecoder(r.Body).Decode(&reqBody)
	if err != nil {
		slog.Log(r.Context(), middleware.SeverityInfo, "can not get request body", "requestId", middleware.GetRequestID(r.Context()))
		return
	}

	cfg, _ := config.New()
	db, err := sql.Open("mysql", cfg.Database_url)
	if err != nil {
		slog.Log(r.Context(), middleware.SeverityError, "db connect error...")
		panic(err)
	}
	defer func() {
		slog.Log(r.Context(), middleware.SeverityInfo, "db disconnect ....", "requestId", middleware.GetRequestID(r.Context()))
		db.Close()
	}()

	name := reqBody.Name
	email := reqBody.Email
	sql := "INSERT INTO users (name, email) VALUES (?, ?)"
	_, err = db.Exec(sql, name, email)
	if err != nil {
		slog.Log(r.Context(), middleware.SeverityError, "can not insert", "error message", err, "requestId", middleware.GetRequestID(r.Context()))
		return
	}
	middleware.Response(&w, http.StatusCreated, reqBody)
}
api/controller/health_check.go
package controller

import (
	"fmt"
	"net/http"
)

func Health(w http.ResponseWriter, r *http.Request) {

	fmt.Fprint(w, "Hellow World go 1.22 ! from GET\n")
}
api/middleware/context.go
package middleware

import (
	"context"
	"net/http"
	"time"

	"github.com/o-ga09/go122rcsample/api/pkg"
)

type RequestId string

// AddIDはリクエスト毎にIDを付与するmiddlewareです。
func AddID(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// IDを生成してcontextに保存
		id := pkg.GenerateID()
		ctx := context.WithValue(r.Context(), RequestId("requestId"), id)
		// 次のハンドラーに渡す
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// WithTimeoutはIDを追加するmiddlewareです。
func WithTimeout(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Context() == nil {
			r = r.WithContext(context.Background())
		}

		// タイムアウトを設定
		ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
		defer cancel() // 処理が終了したらキャンセルする

		// 次のハンドラーを実行し、タイムアウトが発生した場合はエラーメッセージを出力
		done := make(chan struct{})
		go func() {
			defer close(done)
			next.ServeHTTP(w, r)
		}()
		select {
		case <-done:
			// ハンドラーが正常に終了した場合は何もしない
			return
		case <-ctx.Done():
			http.Error(w, "Timeout", http.StatusRequestTimeout)
		}
	})
}

func GetRequestID(ctx context.Context) string {
	return ctx.Value(RequestId("requestId")).(string)
}

func UseMiddleware(handler http.HandlerFunc) http.HandlerFunc {
	handler = WithTimeout(handler)
	handler = RequestLogger(handler)
	handler = AddID(handler)
	handler = Logger(handler)

	return handler
}
api/middleware/logger.go
package middleware

import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"os"

	"cloud.google.com/go/logging"
	"github.com/o-ga09/go122rcsample/api/internal/config"
	"go.opentelemetry.io/otel/trace"
)

// cloud logging の Log level 定義
var (
	Severitydefault = slog.Level(logging.Default)
	SeverityInfo    = slog.Level(logging.Info)
	SeverityWarn    = slog.Level(logging.Warning)
	SeverityError   = slog.Level(logging.Error)
	SeverityNotice  = slog.Level(logging.Notice)
)

// traceId , spanId 追加
type traceHandler struct {
	slog.Handler
	projectID string
}

// traceHandler 実装
func (h *traceHandler) Enabled(ctx context.Context, l slog.Level) bool {
	return h.Handler.Enabled(ctx, l)
}

func (h *traceHandler) Handle(ctx context.Context, r slog.Record) error {
	if sc := trace.SpanContextFromContext(ctx); sc.IsValid() {
		trace := fmt.Sprintf("projects/%s/traces/%s", h.projectID, sc.TraceID().String())
		r.AddAttrs(slog.String("logging.googleapis.com/trace", trace),
			slog.String("logging.googleapis.com/spanId", sc.SpanID().String()))
	}

	return h.Handler.Handle(ctx, r)
}

func (h *traceHandler) WithAttr(attrs []slog.Attr) slog.Handler {
	return &traceHandler{h.Handler.WithAttrs(attrs), h.projectID}
}

func (h *traceHandler) WithGroup(g string) slog.Handler {
	return h.Handler.WithGroup(g)
}

// logger 生成関数
func Logger(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		replacer := func(groups []string, a slog.Attr) slog.Attr {
			if a.Key == slog.MessageKey {
				a.Key = "message"
			}

			if a.Key == slog.LevelKey {
				a.Key = "severity"
				a.Value = slog.StringValue(logging.Severity(a.Value.Any().(slog.Level)).String())
			}

			if a.Key == slog.SourceKey {
				a.Key = "logging.googleapis.com/sourceLocation"
			}

			return a
		}
		cfg, _ := config.New()
		projectID := cfg.ProjectID
		h := traceHandler{slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{AddSource: true, ReplaceAttr: replacer}), projectID}
		newh := h.WithAttr([]slog.Attr{
			slog.Group("logging.googleapis.com/labels",
				slog.String("app", "MH-API"),
				slog.String("env", cfg.Env),
			),
		})
		logger := slog.New(newh)
		slog.SetDefault(logger)
		next.ServeHTTP(w, r)
	})
}
api/middleware/RequestInfo.go
package middleware

import (
	"log/slog"
	"net/http"
	"time"
)

type RequestInfo struct {
	ContentsLength int64
	Path           string
	SourceIP       string
	Query          string
	UserAgent      string
	Errors         string
	Elapsed        time.Duration
}

func (r *RequestInfo) LogValue() interface{} { // Assuming slog expects an interface{}
	return map[string]interface{}{
		"contents_length": r.ContentsLength,
		"path":            r.Path,
		"sourceIP":        r.SourceIP,
		"query":           r.Query,
		"user_agent":      r.UserAgent,
		"errors":          r.Errors,
		"elapsed":         r.Elapsed.String(),
	}
}

func RequestLogger(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		slog.Log(r.Context(), SeverityInfo, "処理開始", "requestId", GetRequestID(r.Context()))
		start := time.Now()

		next.ServeHTTP(w, r)

		req := RequestInfo{
			ContentsLength: r.ContentLength,
			Path:           r.RequestURI,
			SourceIP:       r.RemoteAddr,
			Query:          r.URL.RawQuery,
			UserAgent:      r.UserAgent(),
			Errors:         "errors",
			Elapsed:        time.Since(start),
		}

		slog.Log(r.Context(), SeverityInfo, "処理終了", "Request", req.LogValue(), "requestId", GetRequestID(r.Context())) // Adjust logging context as needed
	})
}
api/middleware/response.go
package middleware

import (
	"encoding/json"
	"net/http"
)

func Response(w *http.ResponseWriter, status int, message interface{}) {
	json, err := json.Marshal(message)
	if err != nil {
		(*w).Header().Set("Content-Type", "application/json")
		(*w).WriteHeader(http.StatusInternalServerError)
		(*w).Write([]byte(`{"message": "error marshalling json"}`))
		return
	}

	(*w).Header().Set("Content-Type", "application/json")
	(*w).WriteHeader(status)
	(*w).Write(json)
}
api/config/config.go
package config

import "github.com/caarlos0/env"

type Config struct {
	Env          string `env:"ENV" envDefault:"dev"`
	Port         string `env:"PORT" envDefault:"80"`
	Database_url string `env:"DATABASE_URL" envDefult:""`
	ProjectID    string `env:"PROJECTID" envDefault:""`
}

func New() (*Config, error) {
	cfg := &Config{}
	if err := env.Parse(cfg); err != nil {
		return nil, err
	}
	return cfg, nil
}

解説

例として、requestIdを付与するミドルウェアを見てみましょう。

context.go
func AddID(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// IDを生成してcontextに保存
		id := pkg.GenerateID()
		ctx := context.WithValue(r.Context(), RequestId("requestId"), id)
		// 次のハンドラーに渡す
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

func AddID(next http.HandlerFunc) http.HandlerFuncは、引数に次のハンドラーに渡すためのhttp.handler型の変数を、
戻り値にhttp.HandlerFunc型を持ちます。

処理の最後に、next.ServeHTTPで次のハンドラに処理を渡します。
next.ServeHTTPを渡してミドルウェアを繋いで最後にアプリケーションハンドラーを呼び出します。

また、以下のようにすることでnext.ServeHTTPで処理を次のハンドラーに渡すけど、アプリケーションハンドラーまで行ったところで、
また戻って処理するようにできます。

context.go
func RequestLogger(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		slog.Log(r.Context(), SeverityInfo, "処理開始", "requestId", GetRequestID(r.Context()))
		start := time.Now()

		next.ServeHTTP(w, r)

		req := RequestInfo{
			ContentsLength: r.ContentLength,
			Path:           r.RequestURI,
			SourceIP:       r.RemoteAddr,
			Query:          r.URL.RawQuery,
			UserAgent:      r.UserAgent(),
			Errors:         "errors",
			Elapsed:        time.Since(start),
		}

		slog.Log(r.Context(), SeverityInfo, "処理終了", "Request", req.LogValue(), "requestId", GetRequestID(r.Context())) 
	})
}

ハンドラーにミドルウェアを登録するには以下のようにします。
UseMiddlewareは、ただ、一つにまとめただけです。
本当は、ginのように、r.Useのようにしたかったのですが、
これからやってみたいところではあります。

なので、ハンドラー毎に、UseMiddlewareを呼び出さないといけないです。

呼び出し順は、LoggerAddIdRequestLoggerWithTimeoutHealthCheckRequestLogger(nextの後ろの処理)となります。

server.go
HealthCheckhandler := http.HandlerFunc(controller.Health)
HealthCheckhandler = middleware.UseMiddleware(HealthCheckhandler)
context.go
func UseMiddleware(handler http.HandlerFunc) http.HandlerFunc {
	handler = WithTimeout(handler)
	handler = RequestLogger(handler)
	handler = AddID(handler)
	handler = Logger(handler)

	return handler
}

まとめ

APIのリクエストを処理するミドルウェアの部分は理解が曖昧でしたが、自分で実装することで、
かなり理解が進みました。

言語はなんでも良いので、一度は、フレームワークを使わずに実装してみてはいかがでしょうか。

よろしければいいねをよろしくお願いいたします。

Discussion