🎃
【Go】OAuth2.0で使用するstate, code_verifier, code_challengeを生成するサンプルコード
概要
- OAuth2.0で使用するstate, code_verifier, code_challengeを生成するサンプルコード
- code_verifier, code_challengeの生成はoauth2パッケージがある。自前で実装してしまったので供養のために記載
- 参考
サンプルコード
実装
//go:generate mockgen -source=oauth.go -destination=./mock/oauth_mock.go
package main
import (
"crypto/sha256"
"encoding/base64"
"errors"
"golang.org/x/oauth2"
)
type Oauth interface {
GenerateRandomString(length int, charSet string) (string, error)
GenerateState() (string, error)
GenerateCodeVerifier() (string, error)
GenerateCodeChallenge(verifier string) (string, error)
}
type oauth struct {
readRandomBytes func([]byte) (int, error)
}
func NewOauth(readRandomBytes func([]byte) (int, error)) Oauth {
return &oauth{
readRandomBytes: readRandomBytes,
}
}
const (
stateLength = 32
verifierLength = 128
charSet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~"
)
func (o *oauth) GenerateRandomString(length int, charSet string) (string, error) {
result := make([]byte, length)
for i := range result {
b := make([]byte, 1)
_, err := o.readRandomBytes(b)
if err != nil {
return "", err
}
result[i] = charSet[b[0]%byte(len(charSet))]
}
return string(result), nil
}
func (o *oauth) GenerateState() (string, error) {
return o.GenerateRandomString(stateLength, charSet)
}
func (o *oauth) GenerateCodeVerifier() (string, error) {
return o.GenerateRandomString(verifierLength, charSet)
}
func (o *oauth) GenerateCodeChallenge(verifier string) (string, error) {
if verifier == "" {
return "", errors.New("code_verifier cannot be empty")
}
hash := sha256.Sum256([]byte(verifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
return codeChallenge, nil
}
解説
- テストしやすいようにmock化
- stateの長さは仕様では特に定義されていなかったので「CSRF トークン 長さ」と検索してで出てきたこちらを参考に32とした
- stateに使用する文字は特に定義されていなかったので、code_verifierで定義されている文字列を使用
- code_verifierの長さは最長の128
- oauth2パッケージ使用の場合は長さ43の文字列が生成される模様
- GenerateVerifier():https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.25.0:pkce.go;l=26
- oauth2パッケージ使用の場合は長さ43の文字列が生成される模様
おまけ
使用例
package main
type Handler interface {
GenerateCodeChallenges() (*GenerateCodeChallengesResponse, error)
}
type handler struct {
oauth Oauth
}
func NewHandler(
oauth Oauth,
) Handler {
return &handler{
oauth: oauth,
}
}
type GenerateCodeChallengesResponse struct {
State string `json:"state" example:"state"`
CodeVerifier string `json:"code_verifier" example:"code_verifier"`
CodeChallenge string `json:"code_challenge" example:"code_challenge"`
}
func (h *handler) GenerateCodeChallenges() (*GenerateCodeChallengesResponse, error) {
state, err := h.oauth.GenerateState()
if err != nil {
return nil, err
}
codeVerifier, err := h.oauth.GenerateCodeVerifier()
if err != nil {
return nil, err
}
codeChallenge, err := h.oauth.GenerateCodeChallenge(codeVerifier)
if err != nil {
return nil, err
}
res := &GenerateCodeChallengesResponse{
State: state,
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
}
return res, nil
}
func main() {
o := NewOauth(rand.Read)
h := NewHandler(o)
res, err := h.GenerateCodeChallenges()
if err != nil {
panic(err)
}
fmt.Printf("State: %+v\n", res.State)
fmt.Printf("CodeVerifier: %+v\n", res.CodeVerifier)
fmt.Printf("CodeChallenge: %+v\n", res.CodeChallenge)
}
// 出力
// State: pDDUp-iuVF3AJs.L03u-KIYvZw0kBzvn
// CodeVerifier: 1ZV3COMU.9gGrC_PtkKIJUxg5qrSSZ7UhxRiqt1094aFpSVH49R2_SvwbgLyrIolgHH3ZNbvtbEr.z1klF-fjXxHaDef7NAzpqTJqstKpYlRlhjJS~IhwdMNKhH0BjTo
// CodeChallenge: 55Ajql4F8Ff9EsrfY3SRKHrRwbGQ4JGNoDeabSYjOHQ
テストコード
package main
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"strings"
"testing"
)
func TestGenerateRandomString(t *testing.T) {
length := 10
charSet := "ABCDEF"
// 正常系: 指定した文字数と文字列で値が生成されること
o := NewOauth(rand.Read)
result, err := o.GenerateRandomString(length, charSet)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// 生成された文字列の文字数をチェック
if len(result) != length {
t.Errorf("expected length %d, got %d", length, len(result))
}
// 生成された文字列の各文字が、charSetに存在するかチェック
for _, char := range result {
if !strings.ContainsRune(charSet, char) {
t.Errorf("generated character %c is not in the allowed charset %s", char, charSet)
}
}
// 異常系: ReadRandomBytes がエラーを返す場合
mockReadRandomBytes := func(_ []byte) (int, error) {
return 0, errors.New("mocked error")
}
mockOauth := NewOauth(mockReadRandomBytes)
_, err = mockOauth.GenerateRandomString(length, charSet)
if err == nil {
t.Error("expected error, but got nil")
} else if err.Error() != "mocked error" {
t.Errorf("expected mocked error, got %v", err)
}
}
func TestGenerateState(t *testing.T) {
// 指定した文字数と文字列でstateが生成されること
o := NewOauth(rand.Read)
state, err := o.GenerateState()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// 生成されたstateの文字数チェック
if len(state) != stateLength {
t.Errorf("expected state length of 16, got %d", len(state))
}
// 生成されたstateの各文字が、charSetに存在するかチェック
for _, char := range state {
if !containsRune(charSet, char) {
t.Errorf("character %c is not in the allowed charset", char)
}
}
}
func TestGenerateCodeVerifier(t *testing.T) {
// 指定した文字数と文字列でverifierが生成されること
o := NewOauth(rand.Read)
verifier, err := o.GenerateCodeVerifier()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// 生成されたverifierの文字数チェック
if len(verifier) != verifierLength {
t.Errorf("expected verifier length of 128, got %d", len(verifier))
}
// 生成されたverifierの各文字が、charSetに存在するかチェック
for _, char := range verifier {
if !containsRune(charSet, char) {
t.Errorf("character %c is not in the allowed charset", char)
}
}
}
func TestGenerateCodeChallenge(t *testing.T) {
o := NewOauth(rand.Read)
// 正常系: 生成されたcodeChallengeがexpectedChallengeと等しいこと
verifier := "test-code-verifier"
expectedHash := sha256.Sum256([]byte(verifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(expectedHash[:])
challenge, err := o.GenerateCodeChallenge(verifier)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if challenge != expectedChallenge {
t.Errorf("expected code challenge %s, got %s", expectedChallenge, challenge)
}
// 異常系: verifier が空の場合
_, err = o.GenerateCodeChallenge("")
if err == nil {
t.Error("expected error for empty code_verifier, but got nil")
}
expectedErrMsg := "code_verifier cannot be empty"
if err.Error() != expectedErrMsg {
t.Errorf("unexpected error message: %v", err)
}
}
func containsRune(charSet string, char rune) bool {
for _, c := range charSet {
if c == char {
return true
}
}
return false
}
Discussion