You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
nucleus/internal/controllers/oauth2_controller.go

223 lines
6.2 KiB

package controllers
import (
"context"
"encoding/json"
"errors"
"github.com/go-oauth2/oauth2/v4"
"github.com/go-oauth2/oauth2/v4/manage"
"github.com/go-oauth2/oauth2/v4/models"
"github.com/go-oauth2/oauth2/v4/server"
"github.com/go-oauth2/oauth2/v4/store"
"github.com/go-session/session"
"log"
"net/http"
"nucleus/internal"
"time"
)
// UserInfo is the profile record stored per user ID in userInfoMap and
// serialized to JSON by HandleUserInfoRequest.
type UserInfo struct {
// Username is emitted under the JSON key "username".
Username string `json:"username"`
// Gender is emitted under the JSON key "gender".
Gender string `json:"gender"`
}
// Package-level OAuth2 state shared by the controller methods in this file.
var (
	// manager is the go-oauth2 manager; it is assigned in
	// (*Oauth2Controller).Init.
	manager *manage.Manager

	// srv is the OAuth2 server the request handlers delegate to; it is
	// assigned in (*Oauth2Controller).Init.
	srv *server.Server

	// userInfoMap maps a user ID to that user's profile. It starts empty and
	// is seeded in (*Oauth2Controller).Init.
	userInfoMap = map[string]UserInfo{}
)
// Oauth2Controller groups the HTTP handlers of the OAuth2 authorization-code
// flow implemented in this file. It embeds internal.Controller, which is
// presumably where the Render method used by the page handlers comes from
// (TODO confirm against the internal package).
type Oauth2Controller struct {
internal.Controller
}
// Init builds the package-level OAuth2 manager and server and seeds the demo
// data: one registered client ("juejin") and one user ("Tom", ID "0001").
// The server is restricted to the authorization-code grant with the "code"
// response type.
func (c *Oauth2Controller) Init() {
	// Register the single known client.
	clients := store.NewClientStore()
	clients.Set("juejin", &models.Client{
		ID:     "juejin",
		Secret: "xxxxx",
		Domain: "http://juejin.com",
	})

	// The manager takes part in validating code / access-token requests.
	manager = manage.NewDefaultManager()
	// Would normally check redirect_uri against the client's Domain; for
	// simplicity every URI is accepted.
	manager.SetValidateURIHandler(func(baseURI, redirectURI string) error { return nil })
	manager.MustTokenStorage(store.NewMemoryTokenStore())
	// The manager owns the client information ...
	manager.MapClientStorage(clients)
	// ... and the server in turn owns the manager.
	srv = server.NewServer(server.NewConfig(), manager)

	// Resolve the client from the client_id query parameter through the
	// manager; used while validating an access-token request.
	srv.SetClientInfoHandler(func(r *http.Request) (clientID, clientSecret string, err error) {
		id := r.URL.Query().Get("client_id")
		client, err := srv.Manager.GetClient(r.Context(), id)
		if err != nil {
			log.Println(err)
			return "", "", err
		}
		return client.GetID(), client.GetSecret(), nil
	})

	// Authorization-code mode: a code is obtained first and then exchanged for
	// an access token, rather than issuing the token directly.
	srv.SetAllowedGrantType(oauth2.AuthorizationCode)
	srv.SetAllowedResponseType(oauth2.Code)

	// Identifies the user behind an authorization request; when nobody is
	// logged in it redirects to the login page and returns "", nil.
	srv.SetUserAuthorizationHandler(handleUserAuthorization)

	// Verifies the username/password pair on behalf of HandleLoginRequest.
	// For simplicity exactly one user is allowed to authorize.
	srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) {
		if username != "Tom" || password != "123456" {
			return "", errors.New("username or password error")
		}
		return "0001", nil
	})

	// Allow the access request to be made with the GET method.
	srv.SetAllowGetAccessRequest(true)

	// Seed the in-memory user directory.
	userInfoMap["0001"] = UserInfo{Username: "Tom", Gender: "Male"}
}
// HandleAuthorizeRequest is the entry point of the authorization step. It
// passes the request to srv.HandleAuthorizeRequest and, if that returns an
// error, logs it and replies with 400 Bad Request carrying the error text.
func (c *Oauth2Controller) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) {
	if authErr := srv.HandleAuthorizeRequest(w, r); authErr != nil {
		log.Println(authErr)
		http.Error(w, authErr.Error(), http.StatusBadRequest)
	}
}
// HandleLoginRequest serves the login page and processes login submissions.
//
// Any non-POST request renders static/login.html. A POST passes the
// client_id, username and password form values to the server's password
// authorization handler; on success the returned user ID is stored in the
// session under "LoggedInUserId" and the browser is redirected to
// /oauth2/agree-auth.
//
// Error responses: 500 when the session cannot be started or saved, 401 when
// the credentials are rejected.
func (c *Oauth2Controller) HandleLoginRequest(w http.ResponseWriter, r *http.Request) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Anything other than a form submission just gets the login page.
	if r.Method != http.MethodPost {
		c.Render(w, r, "static/login.html")
		return
	}

	userID, err := srv.PasswordAuthorizationHandler(
		r.Context(),
		r.FormValue("client_id"),
		r.FormValue("username"),
		r.FormValue("password"),
	)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	// Persist the login state. If saving fails the login is not recorded, so
	// redirecting to the consent page would only bounce the user back to the
	// login page; report the failure instead.
	sess.Set("LoggedInUserId", userID)
	if err := sess.Save(); err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Continue to the consent ("agree to authorize") page.
	w.Header().Set("Location", "/oauth2/agree-auth")
	w.WriteHeader(http.StatusFound)
}
// HandleAgreeAuthRequest serves the consent ("agree to authorize") page.
//
// A session that carries a "LoggedInUserId" entry gets
// static/agree-auth.html; any other session is redirected (302) to
// /oauth2/login. A session that cannot be started yields a 500 response with
// the error text.
func (c *Oauth2Controller) HandleAgreeAuthRequest(w http.ResponseWriter, r *http.Request) {
	sess, startErr := session.Start(r.Context(), w, r)
	if startErr != nil {
		log.Println(startErr)
		http.Error(w, startErr.Error(), http.StatusInternalServerError)
		return
	}

	_, loggedIn := sess.Get("LoggedInUserId")
	if loggedIn {
		// Login state present: show the confirmation page.
		c.Render(w, r, "static/agree-auth.html")
		return
	}

	// No login state: send the browser to the login page.
	w.Header().Set("Location", "/oauth2/login")
	w.WriteHeader(http.StatusFound)
}
// HandleAccessTokenRequest exchanges an authorization code for an access
// token by delegating to srv.HandleTokenRequest. An error returned by the
// server is logged and answered with 400 Bad Request carrying the error text.
//
// NOTE(review): if HandleTokenRequest has already written an error response
// by the time it returns, the http.Error call below adds a second write —
// confirm against the go-oauth2 server implementation.
func (c *Oauth2Controller) HandleAccessTokenRequest(w http.ResponseWriter, r *http.Request) {
	tokenErr := srv.HandleTokenRequest(w, r)
	if tokenErr == nil {
		return
	}
	log.Println(tokenErr)
	http.Error(w, tokenErr.Error(), http.StatusBadRequest)
}
// HandleUserInfoRequest returns the profile of the user an access token was
// issued for.
//
// The bearer token is read from the request, loaded through the server's
// manager (bounded by a one-second timeout that also honours cancellation of
// the incoming request), and must carry the scope "read_user_info". The
// matching entry of userInfoMap is then written as JSON.
//
// Responses:
//   - 200 with the JSON-encoded UserInfo on success
//   - 401 when no bearer token is present or the token cannot be loaded
//   - 403 when the token's scope is not "read_user_info"
//   - 404 when no profile exists for the token's user ID
//   - 500 when the profile cannot be encoded
func (c *Oauth2Controller) HandleUserInfoRequest(w http.ResponseWriter, r *http.Request) {
	accessToken, ok := srv.BearerAuth(r)
	if !ok {
		log.Println("Failed to get access token from request")
		http.Error(w, "missing access token", http.StatusUnauthorized)
		return
	}

	// Derive the timeout from the request context so the lookup is abandoned
	// when the client goes away.
	ctx, cancel := context.WithTimeout(r.Context(), time.Second)
	defer cancel()

	tokenInfo, err := srv.Manager.LoadAccessToken(ctx, accessToken)
	if err != nil {
		log.Println(err)
		http.Error(w, "invalid access token", http.StatusUnauthorized)
		return
	}

	// The granted scope decides which user data may be returned; only
	// "read_user_info" is supported.
	if tokenInfo.GetScope() != "read_user_info" {
		log.Println("invalid grant scope")
		http.Error(w, "invalid grant scope", http.StatusForbidden)
		return
	}

	// A missing map entry would otherwise yield a zero-value profile that
	// looks like a real (empty) user.
	userInfo, found := userInfoMap[tokenInfo.GetUserID()]
	if !found {
		log.Println("no user info for the user id bound to the access token")
		http.Error(w, "user not found", http.StatusNotFound)
		return
	}

	resp, err := json.Marshal(userInfo)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(resp); err != nil {
		// The status line is already sent; all that is left is to record it.
		log.Println(err)
	}
}
// handleUserAuthorization is the server's user-authorization handler (it is
// registered in Init). It reports which user, if any, is logged in on the
// current session.
//
// Returns the user ID stored under "LoggedInUserId" when the session holds a
// non-empty string there. In every other case the handler writes the response
// itself and returns "", nil to signal that the request has been dealt with:
//   - no usable login state: 302 redirect to /oauth2/login
//   - session cannot be started: 500 with the error text
func handleUserAuthorization(w http.ResponseWriter, r *http.Request) (userId string, err error) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		// The response is already written. Report "handled" the same way the
		// login redirect below does, rather than returning an error that
		// HandleAuthorizeRequest would answer with a second http.Error.
		return "", nil
	}

	if uid, ok := sess.Get("LoggedInUserId"); ok {
		// Checked assertion: a value of an unexpected type is treated as
		// "not logged in" instead of panicking.
		if id, isString := uid.(string); isString && id != "" {
			return id, nil
		}
	}

	// No login state: send the browser to the login page.
	w.Header().Set("Location", "/oauth2/login")
	w.WriteHeader(http.StatusFound)
	return "", nil
}