You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
223 lines
6.2 KiB
223 lines
6.2 KiB
package controllers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"github.com/go-oauth2/oauth2/v4"
|
|
"github.com/go-oauth2/oauth2/v4/manage"
|
|
"github.com/go-oauth2/oauth2/v4/models"
|
|
"github.com/go-oauth2/oauth2/v4/server"
|
|
"github.com/go-oauth2/oauth2/v4/store"
|
|
"github.com/go-session/session"
|
|
"log"
|
|
"net/http"
|
|
"nucleus/internal"
|
|
"time"
|
|
)
|
|
|
|
// UserInfo is the user profile served by the user-info endpoint. It is
// encoded as JSON with the keys "username" and "gender", in that order.
type UserInfo struct {
	Username string `json:"username"`
	Gender   string `json:"gender"`
}
|
|
|
|
var manager *manage.Manager
|
|
var srv *server.Server
|
|
var userInfoMap = make(map[string]UserInfo)
|
|
|
|
type Oauth2Controller struct {
|
|
internal.Controller
|
|
}
|
|
|
|
// Init builds the package-level OAuth2 manager (manager) and server (srv) and
// seeds the in-memory user profile map. It must run before any handler of
// Oauth2Controller is served, because the handlers read srv and userInfoMap.
//
// The configuration is deliberately minimal (demo setup):
//   - a single registered client "juejin" with a hard-coded secret;
//   - redirect_uri is accepted without checking it against the client Domain;
//   - only the authorization-code grant and the "code" response type;
//   - a single hard-coded user "Tom" / "123456" with user id "0001".
func (c *Oauth2Controller) Init() {
	// NOTE(review): hard-coded client credentials; load them from configuration
	// before using this outside a demo.
	clientStore := store.NewClientStore()
	clientStore.Set("juejin", &models.Client{ID: "juejin", Secret: "xxxxx", Domain: "http://juejin.com"})

	// The manager takes part in validating code and access-token requests.
	manager = manage.NewDefaultManager()

	// Accept every redirect_uri regardless of the client's Domain (kept simple
	// on purpose). SECURITY NOTE: a real deployment must validate this, or
	// authorization codes can be delivered to attacker-controlled URIs.
	manager.SetValidateURIHandler(func(baseURI, redirectURI string) error {
		return nil
	})

	manager.MustTokenStorage(store.NewMemoryTokenStore())

	// The manager resolves client info from clientStore.
	manager.MapClientStorage(clientStore)

	// The server wraps the manager (and through it the client store).
	srv = server.NewServer(server.NewConfig(), manager)

	// Take client_id and client_secret from the request form (query string or
	// POST body) during token requests. Returning the secret the caller sent,
	// rather than the one stored for the client, is what allows the manager to
	// actually authenticate the client before issuing an access token.
	srv.SetClientInfoHandler(server.ClientFormHandler)

	// Authorization-code flow only: the client first obtains a code and then
	// exchanges it for an access token, instead of receiving a token directly.
	srv.SetAllowedGrantType(oauth2.AuthorizationCode)
	srv.SetAllowedResponseType(oauth2.Code)

	// Resolves the logged-in user for an authorize request; unauthenticated
	// users are redirected to the login page and ("", nil) is returned.
	srv.SetUserAuthorizationHandler(handleUserAuthorization)

	// Checks username/password for HandleLoginRequest. For simplicity only a
	// single user can authorize.
	srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (userID string, err error) {
		if username == "Tom" && password == "123456" {
			return "0001", nil
		}
		return "", errors.New("username or password error")
	})

	// Allow access-token requests to use GET as well as POST.
	srv.SetAllowGetAccessRequest(true)

	// Profile served by the user-info endpoint for user id "0001".
	userInfoMap["0001"] = UserInfo{Username: "Tom", Gender: "Male"}
}
|
|
|
|
// HandleAuthorizeRequest is the OAuth2 authorization entry point. The request
// is delegated to srv; if srv reports an error, it is logged and sent back to
// the client as 400 Bad Request with the error text as the body.
func (c *Oauth2Controller) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) {
	if authErr := srv.HandleAuthorizeRequest(w, r); authErr != nil {
		log.Println(authErr)
		http.Error(w, authErr.Error(), http.StatusBadRequest)
	}
}
|
|
|
|
// HandleLoginRequest is the login entry point.
//
// For non-POST requests it renders static/login.html. For POST it checks the
// "username" and "password" form fields (plus "client_id") with
// srv.PasswordAuthorizationHandler:
//   - bad credentials: 401 with the error text;
//   - success: the user id is stored in the session under "LoggedInUserId"
//     and the browser is redirected (302) to /oauth2/agree-auth.
//
// Session start or save failures are reported as 500.
func (c *Oauth2Controller) HandleLoginRequest(w http.ResponseWriter, r *http.Request) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Anything other than a form submission just gets the login page.
	if r.Method != http.MethodPost {
		c.Render(w, r, "static/login.html")
		return
	}

	userID, err := srv.PasswordAuthorizationHandler(
		r.Context(),
		r.FormValue("client_id"),
		r.FormValue("username"),
		r.FormValue("password"),
	)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	// Persist the login state. If saving fails the user is not actually logged
	// in, so report it instead of redirecting into a login loop.
	sess.Set("LoggedInUserId", userID)
	if err := sess.Save(); err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Continue to the consent page.
	http.Redirect(w, r, "/oauth2/agree-auth", http.StatusFound)
}
|
|
|
|
// HandleAgreeAuthRequest serves the consent ("agree to authorize") page.
// A session without "LoggedInUserId" is redirected (302) to /oauth2/login;
// otherwise static/agree-auth.html is rendered. A failure to start the session
// is logged and reported as 500.
func (c *Oauth2Controller) HandleAgreeAuthRequest(w http.ResponseWriter, r *http.Request) {
	sess, sessErr := session.Start(r.Context(), w, r)
	if sessErr != nil {
		log.Println(sessErr)
		http.Error(w, sessErr.Error(), http.StatusInternalServerError)
		return
	}

	if _, loggedIn := sess.Get("LoggedInUserId"); loggedIn {
		c.Render(w, r, "static/agree-auth.html")
		return
	}

	// No login state yet: send the browser to the login page first.
	w.Header().Set("Location", "/oauth2/login")
	w.WriteHeader(http.StatusFound)
}
|
|
|
|
// HandleAccessTokenRequest is the token endpoint: it exchanges an
// authorization code for an access token by delegating to srv. An error
// returned by srv is logged and written as 400 Bad Request.
func (c *Oauth2Controller) HandleAccessTokenRequest(w http.ResponseWriter, r *http.Request) {
	tokenErr := srv.HandleTokenRequest(w, r)
	if tokenErr == nil {
		return
	}
	log.Println(tokenErr)
	http.Error(w, tokenErr.Error(), http.StatusBadRequest)
}
|
|
|
|
// HandleUserInfoRequest exchanges a bearer access token for the profile of the
// user the token was issued to.
//
// Responses:
//   - 401 if the request has no bearer token or the token cannot be loaded;
//   - 403 "invalid grant scope" if the token's scope is not "read_user_info";
//   - 404 if no profile is stored for the token's user id;
//   - 500 if the profile cannot be encoded;
//   - 200 with the UserInfo as JSON otherwise.
func (c *Oauth2Controller) HandleUserInfoRequest(w http.ResponseWriter, r *http.Request) {
	accessToken, ok := srv.BearerAuth(r)
	if !ok {
		log.Println("Failed to get access token from request")
		http.Error(w, "missing bearer access token", http.StatusUnauthorized)
		return
	}

	// Bound the token lookup to one second, and also stop if the client
	// disconnects (hence deriving from the request context).
	ctx, cancel := context.WithTimeout(r.Context(), time.Second)
	defer cancel()

	tokenInfo, err := srv.Manager.LoadAccessToken(ctx, accessToken)
	if err != nil {
		log.Println(err)
		http.Error(w, "invalid access token", http.StatusUnauthorized)
		return
	}

	// The granted scope decides which user information may be returned; only
	// "read_user_info" is supported.
	if tokenInfo.GetScope() != "read_user_info" {
		log.Println("invalid grant scope")
		http.Error(w, "invalid grant scope", http.StatusForbidden)
		return
	}

	userInfo, found := userInfoMap[tokenInfo.GetUserID()]
	if !found {
		http.Error(w, "user not found", http.StatusNotFound)
		return
	}

	resp, err := json.Marshal(userInfo)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(resp); err != nil {
		// Headers are already sent; nothing more can be reported to the client.
		log.Println(err)
	}
}
|
|
|
|
// handleUserAuthorization is registered as srv's UserAuthorizationHandler and
// reports which user, if any, is logged in for the current authorize request.
//
// If the session holds a non-empty string under "LoggedInUserId", that user id
// is returned. Otherwise the browser is redirected (302) to /oauth2/login and
// ("", nil) is returned; go-oauth2 stops processing the authorize request when
// the handler returns an empty user id.
//
// If the session cannot be started, the error is logged and returned WITHOUT
// writing a response: go-oauth2 writes the OAuth2 error response for a
// returned error, so writing one here as well would produce two responses.
func handleUserAuthorization(w http.ResponseWriter, r *http.Request) (string, error) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		return "", err
	}

	// A missing key, a non-string value, or an empty id all mean "not logged
	// in"; the comma-ok assertion avoids panicking on unexpected session data.
	v, _ := sess.Get("LoggedInUserId")
	userID, ok := v.(string)
	if !ok || userID == "" {
		w.Header().Set("Location", "/oauth2/login")
		w.WriteHeader(http.StatusFound)
		return "", nil
	}

	return userID, nil
}
|
|
|