package controllers

import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"time"

	"github.com/go-oauth2/oauth2/v4"
	"github.com/go-oauth2/oauth2/v4/manage"
	"github.com/go-oauth2/oauth2/v4/models"
	"github.com/go-oauth2/oauth2/v4/server"
	"github.com/go-oauth2/oauth2/v4/store"
	"github.com/go-session/session"

	"nucleus/internal"
)

const (
	// oauth2SessionUserKey is the session key holding the ID of the logged-in user.
	oauth2SessionUserKey = "LoggedInUserId"
	// oauth2LoginPath is where unauthenticated users are redirected to.
	oauth2LoginPath = "/oauth2/login"
	// oauth2AgreeAuthPath is the consent page shown after a successful login.
	oauth2AgreeAuthPath = "/oauth2/agree-auth"
	// oauth2UserInfoScope is the only scope that grants access to the user info endpoint.
	oauth2UserInfoScope = "read_user_info"
	// oauth2TokenLookupTimeout bounds the access-token lookup in HandleUserInfoRequest.
	oauth2TokenLookupTimeout = time.Second
)

// UserInfo is the user profile returned by the user info endpoint.
type UserInfo struct {
	Username string `json:"username"`
	Gender   string `json:"gender"`
}

var (
	// manager validates authorization-code and access-token requests.
	manager *manage.Manager
	// srv is the OAuth2 server shared by all handlers; it is built in Init.
	srv *server.Server
	// userInfoMap maps user IDs to profiles. It is written only in Init and
	// read by HandleUserInfoRequest.
	userInfoMap = make(map[string]UserInfo)
)

// Oauth2Controller exposes the OAuth2 authorization-code flow over HTTP.
type Oauth2Controller struct {
	internal.Controller
}

// Init builds the package-level OAuth2 manager and server and registers the
// demo client and the demo user. It must run before any handler is served,
// because the handlers dereference srv.
func (c *Oauth2Controller) Init() {
	clientStore := store.NewClientStore()
	// NOTE(review): hard-coded demo client credentials; load them from
	// configuration before using this outside a demo.
	clientStore.Set("juejin", &models.Client{ID: "juejin", Secret: "xxxxx", Domain: "http://juejin.com"})

	// The manager takes part in validating code / access-token requests.
	manager = manage.NewDefaultManager()
	// NOTE(review): redirect_uri is deliberately not checked against the
	// client's Domain ("for simplicity" in the original). That accepts any
	// redirect URI and should be tightened for real deployments.
	manager.SetValidateURIHandler(func(baseURI, redirectURI string) error {
		return nil
	})
	manager.MustTokenStorage(store.NewMemoryTokenStore())
	// The manager also owns the client registry.
	manager.MapClientStorage(clientStore)

	srv = server.NewServer(server.NewConfig(), manager)

	// Resolve the client from the client_id query parameter; used while
	// validating access-token requests.
	// NOTE(review): the secret returned here comes from the client store, not
	// from the request, so the caller never has to prove knowledge of the
	// secret. Kept as in the original; confirm this is intended.
	srv.SetClientInfoHandler(func(r *http.Request) (string, string, error) {
		clientInfo, err := srv.Manager.GetClient(r.Context(), r.URL.Query().Get("client_id"))
		if err != nil {
			log.Println(err)
			return "", "", err
		}
		return clientInfo.GetID(), clientInfo.GetSecret(), nil
	})

	// Authorization-code grant only: the client first obtains a code and then
	// exchanges it for an access token.
	srv.SetAllowedGrantType(oauth2.AuthorizationCode)
	srv.SetAllowedResponseType(oauth2.Code)

	// Decides which user is authorizing; redirects to the login page and
	// returns "", nil when nobody is logged in.
	srv.SetUserAuthorizationHandler(handleUserAuthorization)

	// Credential check used by HandleLoginRequest. For simplicity a single
	// hard-coded user is accepted.
	srv.SetPasswordAuthorizationHandler(func(ctx context.Context, clientID, username, password string) (string, error) {
		if username == "Tom" && password == "123456" {
			return "0001", nil
		}
		return "", errors.New("username or password error")
	})

	// Allow the token endpoint to be called with GET.
	srv.SetAllowGetAccessRequest(true)

	// Profile of the single demo user.
	userInfoMap["0001"] = UserInfo{
		Username: "Tom",
		Gender:   "Male",
	}
}

// HandleAuthorizeRequest is the authorization endpoint. It delegates to the
// OAuth2 server and answers 400 with the error text when the server reports
// an error.
func (c *Oauth2Controller) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) {
	if err := srv.HandleAuthorizeRequest(w, r); err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusBadRequest)
	}
}

// HandleLoginRequest is the login endpoint. A POST checks the submitted
// credentials, stores the user ID in the session and redirects to the consent
// page; any other method renders the login page.
func (c *Oauth2Controller) HandleLoginRequest(w http.ResponseWriter, r *http.Request) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if r.Method != http.MethodPost {
		c.Render(w, r, "static/login.html")
		return
	}

	userID, err := srv.PasswordAuthorizationHandler(
		r.Context(),
		r.FormValue("client_id"),
		r.FormValue("username"),
		r.FormValue("password"),
	)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	// Persist the login state. A failed save must not be followed by a
	// redirect: the consent page would find no login and bounce the user
	// straight back here.
	sess.Set(oauth2SessionUserKey, userID)
	if err := sess.Save(); err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	http.Redirect(w, r, oauth2AgreeAuthPath, http.StatusFound)
}

// HandleAgreeAuthRequest renders the consent page for a logged-in user and
// redirects everyone else to the login page.
func (c *Oauth2Controller) HandleAgreeAuthRequest(w http.ResponseWriter, r *http.Request) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if _, loggedIn := sess.Get(oauth2SessionUserKey); !loggedIn {
		http.Redirect(w, r, oauth2LoginPath, http.StatusFound)
		return
	}

	c.Render(w, r, "static/agree-auth.html")
}

// HandleAccessTokenRequest is the token endpoint: it exchanges an
// authorization code for an access token by delegating to the OAuth2 server,
// and answers 400 with the error text when the server reports an error.
func (c *Oauth2Controller) HandleAccessTokenRequest(w http.ResponseWriter, r *http.Request) {
	if err := srv.HandleTokenRequest(w, r); err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusBadRequest)
	}
}

// HandleUserInfoRequest returns the profile of the user an access token was
// issued for, encoded as JSON.
//
// Responses:
//   - 401 when the request carries no access token or the token cannot be loaded
//   - 403 when the token was not granted the user-info scope
//   - 404 when no profile exists for the token's user ID
//   - 500 when the profile cannot be encoded
func (c *Oauth2Controller) HandleUserInfoRequest(w http.ResponseWriter, r *http.Request) {
	accessToken, ok := srv.BearerAuth(r)
	if !ok {
		log.Println("Failed to get access token from request")
		http.Error(w, "missing access token", http.StatusUnauthorized)
		return
	}

	// Derive from the request context so the lookup is abandoned when the
	// client disconnects, while still being bounded by the timeout.
	ctx, cancel := context.WithTimeout(r.Context(), oauth2TokenLookupTimeout)
	defer cancel()

	tokenInfo, err := srv.Manager.LoadAccessToken(ctx, accessToken)
	if err != nil {
		log.Println(err)
		http.Error(w, "invalid access token", http.StatusUnauthorized)
		return
	}

	// The granted scope decides whether user information may be read at all.
	if tokenInfo.GetScope() != oauth2UserInfoScope {
		log.Println("invalid grant scope")
		http.Error(w, "invalid grant scope", http.StatusForbidden)
		return
	}

	userInfo, found := userInfoMap[tokenInfo.GetUserID()]
	if !found {
		log.Println("user not found")
		http.Error(w, "user not found", http.StatusNotFound)
		return
	}

	resp, err := json.Marshal(userInfo)
	if err != nil {
		log.Println(err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(resp); err != nil {
		log.Println(err)
	}
}

// handleUserAuthorization reports which user is authorizing the current
// request. It is registered on the server in Init.
//
// It returns the logged-in user's ID, or "" with a nil error after redirecting
// to the login page when nobody is logged in. When an error is returned no
// response is written here; producing the error response is left to the
// caller, so the response is not written twice.
func handleUserAuthorization(w http.ResponseWriter, r *http.Request) (userID string, err error) {
	sess, err := session.Start(r.Context(), w, r)
	if err != nil {
		log.Println(err)
		return "", err
	}

	uid, loggedIn := sess.Get(oauth2SessionUserKey)
	if !loggedIn {
		if r.Form == nil {
			// Best effort, as in the original: the parsed form is not used
			// or stored here, so a parse error does not affect the redirect.
			_ = r.ParseForm()
		}
		http.Redirect(w, r, oauth2LoginPath, http.StatusFound)
		return "", nil
	}

	// Guard the assertion: a session value of an unexpected type must not
	// panic the handler.
	id, isString := uid.(string)
	if !isString {
		return "", errors.New("invalid user id in session")
	}
	return id, nil
}