commit 815fbd568fe404deb1fe730f49fd0fcd45a280bf Author: hupeh Date: Fri Sep 15 14:09:57 2023 +0800 tada(*): :tada: 基础库功能 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..a72359f --- /dev/null +++ b/.env.example @@ -0,0 +1,40 @@ +################ +# 应用基础配置 +################ + +# 设置运行环境,开设置的值有 dev、test、prod +APP_ENV=dev +# 配置时区(预留配置,暂未生效) +APP_TIMEZONE= + + +################ +# 数据库配置 +################ + +# 数据库驱动,支持 sqlite3、mysql、sqlserver、pgsql +DB_DRIVER= +# 数据库连接配置 +DB_DSN= +# 数据表前缀 +DB_PREFIX= +# MySQL存储引擎 +DB_STORE_ENGINE= +# 最大闲置连接数 +DB_MAX_IDLE_CONNS= +# 最大打开连接数 +DB_MAX_OPEN_CONNS= +# 连接最大有效时长 +DB_CONN_MAX_LIFETIME= +# 是不是代码优先模式,只有开启此选项后才能够 +# 使用自动迁移功能同步表结构到数据库 +DB_CODE_FIRST= + + + +################ +# 授权认证 +################ + +JWT_PRIVATE_KEY= +JWT_PUBLIC_KEY= diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2b276b4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +/.idea +/*.iml +*.db +.env +.env.* +!.env.example +/docs \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..242e9cc --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +## 开发接口文档 + +> [swag 文档地址](https://github.com/swaggo/swag) + +生成开发文档命令 + +```shell +swag init +``` + +注释文档格式化命令 + +```shell +swag fmt +``` + + +## 关于 Api + +参考 Restful 设计风格。 + +### 请求方法指南 + +> 五个常用的 HTTP 动词,用于表示对于资源的具体操作类型: +> +> * GET(SELECT):从服务器取出资源(一项或多项)。 +> * PUT(CREATE):在服务器新建一个资源。 +> * POST(UPDATE):在服务器更新资源(客户端提供改变后的完整资源)。 +> * PATCH(UPDATE):在服务器更新资源(客户端提供改变的属性)。 +> * DELETE(DELETE):从服务器删除资源。 +> +> 还有两个不常用的HTTP动词: +> +> * HEAD:获取资源的元数据。 +> * OPTIONS:获取信息,关于资源的哪些属性是客户端可以改变的。 + +### 分页查询 + +接口应该提供如下参数过滤返回结果: + +- **sort_by** 排序字段列表,可以为字段指定如下前缀实现排序: + - `+` 表示查询结果按该字段的升序进行排序,为默认前缀,可省略; + - `-` 表示查询结果按该字段的降序进行排序。 +- **page** 页码,可选,默认值为 1 +- **per_page** 页容量,可选,默认值为 30 + +客户端通过 QueryString 方式提交上述参数来分页查询。 \ No newline at end of file diff --git a/TODOs.md b/TODOs.md new file mode 100644 index 0000000..457f832 --- /dev/null +++ b/TODOs.md @@ -0,0 +1,9 @@ +## 实现 env 包 + +https://github.com/golobby/env/blob/master/env.go + +## 支持 HTTP2 +https://www.modb.pro/db/87148 + +## API文档 +https://golang2.eddycjy.com/posts/ch2/04-api-doc/ \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ddcd38b --- /dev/null +++ b/go.mod @@ -0,0 +1,60 @@ +module sorbet + +go 1.21 + +require ( + github.com/go-resty/resty/v2 v2.7.0 + github.com/golang-jwt/jwt/v5 v5.0.0 + github.com/joho/godotenv v1.5.1 + github.com/labstack/echo-jwt/v4 v4.2.0 + github.com/labstack/echo/v4 v4.11.1 + github.com/labstack/gommon v0.4.0 + github.com/mattn/go-isatty v0.0.19 + github.com/mattn/go-sqlite3 v1.14.17 + github.com/mitchellh/mapstructure v1.5.0 + github.com/rs/xid v1.5.0 + github.com/swaggo/echo-swagger v1.4.1 + github.com/swaggo/swag v1.16.2 + golang.org/x/time v0.3.0 + gorm.io/driver/mysql v1.5.1 + gorm.io/driver/postgres v1.5.2 + gorm.io/driver/sqlite v1.5.3 + gorm.io/driver/sqlserver v1.5.1 + gorm.io/gorm v1.25.4 + gorm.io/plugin/optimisticlock v1.1.1 +) + +require ( + github.com/KyleBanks/depth v1.2.1 // indirect + github.com/PuerkitoBio/purell v1.1.1 // indirect + github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect + github.com/ghodss/yaml v1.0.0 // indirect + github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-openapi/jsonreference v0.19.6 // indirect + github.com/go-openapi/spec v0.20.4 // indirect + github.com/go-openapi/swag v0.19.15 // indirect + github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.3.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/microsoft/go-mssqldb v1.1.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect + github.com/stretchr/testify v1.8.2 // indirect + github.com/swaggo/files/v2 v2.0.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect + golang.org/x/tools v0.7.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2f63c4d --- /dev/null +++ b/go.sum @@ -0,0 +1,230 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= +github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= +github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= +github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs= +github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= +github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M= +github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= +github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/labstack/echo-jwt/v4 v4.2.0 h1:odSISV9JgcSCuhgQSV/6Io3i7nUmfM/QkBeR5GVJj5c= +github.com/labstack/echo-jwt/v4 v4.2.0/go.mod h1:MA2RqdXdEn4/uEglx0HcUOgQSyBaTh5JcaHIan3biwU= +github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4= +github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ= +github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= +github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/microsoft/go-mssqldb v1.1.0 h1:jsV+tpvcPTbNNKW0o3kiCD69kOHICsfjZ2VcVu2lKYc= +github.com/microsoft/go-mssqldb v1.1.0/go.mod h1:LzkFdl4z2Ck+Hi+ycGOTbL56VEfgoyA2DvYejrNGbRk= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/swaggo/echo-swagger v1.4.1 h1:Yf0uPaJWp1uRtDloZALyLnvdBeoEL5Kc7DtnjzO/TUk= +github.com/swaggo/echo-swagger v1.4.1/go.mod h1:C8bSi+9yH2FLZsnhqMZLIZddpUxZdBYuNHbtaS1Hljc= +github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw= +github.com/swaggo/files/v2 v2.0.0/go.mod h1:24kk2Y9NYEJ5lHuCra6iVwkMjIekMCaFq/0JQj66kyM= +github.com/swaggo/swag v1.16.2 h1:28Pp+8DkQoV+HLzLx8RGJZXNGKbFqnuvSbAAtoxiY04= +github.com/swaggo/swag v1.16.2/go.mod h1:6YzXnDcpr0767iOejs318CwYkCQqyGer6BizOg03f+E= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= +golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= +gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw= +gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o= +gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= +gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= +gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= +gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/driver/sqlserver v1.5.1 h1:wpyW/pR26U94uaujltiFGXY7fd2Jw5hC9PB1ZF/Y5s4= +gorm.io/driver/sqlserver v1.5.1/go.mod h1:AYHzzte2msKTmYBYsSIq8ZUsznLJwBdkB2wpI+kt0nM= +gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= +gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/plugin/optimisticlock v1.1.1 h1:REWF26BNTIcLpgzp34EW1Mi9bPZpthBcwjBkOYINn5Q= +gorm.io/plugin/optimisticlock v1.1.1/go.mod h1:wFWgM/KsGEg+IoxgZAAVBP4OmaPfj337L/+T4AR6/hI= diff --git a/internal/entities/company.go b/internal/entities/company.go new file mode 100644 index 0000000..b9e0a86 --- /dev/null +++ b/internal/entities/company.go @@ -0,0 +1,24 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// Company 公司表 +type Company struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:公司编号"` + PrincipalID *uint `json:"principal_id" xml:"principal_id" gorm:"comment:负责人编号(员工)"` + Name string `json:"name" xml:"name" gorm:"size:25;not null;uniqueIndex;comment:公司名称"` + Logo string `json:"logo" xml:"logo" gorm:"not null;comment:形象徽标"` + Status bool `json:"status" xml:"status" gorm:"comment:状态"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Principal *CompanyStaff `json:"principal" xml:"principal"` + Staffs []*CompanyStaff `json:"staffs" xml:"staffs"` + Departments []*CompanyDepartment `json:"departments" xml:"departments"` +} diff --git a/internal/entities/company_department.go b/internal/entities/company_department.go new file mode 100644 index 0000000..b993b55 --- /dev/null +++ b/internal/entities/company_department.go @@ -0,0 +1,37 @@ +package entities + +import ( + "errors" + "gorm.io/gorm" + "sorbet/pkg/db" + "time" + "unicode/utf8" +) + +// CompanyDepartment 公司部门表 +type CompanyDepartment struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:部门编号"` + PID *uint `json:"pid" xml:"pid" gorm:"comment:上级部门编号"` + CompanyID uint `json:"company_id" xml:"company_id" gorm:"comment:所属公司编号"` + PrincipalID *uint `json:"principal_id" xml:"principal_id" gorm:"comment:负责人编号(员工)"` + Name string `json:"name" xml:"name" gorm:"size:50;not null;comment:部门名称"` + Sort int32 `json:"sort" xml:"sort" gorm:"default:0;comment:展示排序"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Staffs []*CompanyStaff `json:"staffs" xml:"staffs" gorm:"many2many:company_staff_to_department_relations"` + Courses []*FeatureContent `json:"courses" xml:"courses" gorm:"many2many:company_course_to_department_relations"` + Children []*CompanyDepartment `json:"children" xml:"children" gorm:"foreignKey:PID"` +} + +func (c *CompanyDepartment) BeforeCreate(tx *gorm.DB) error { + if c.CompanyID == 0 { + return errors.New("缺少所属公司编号") + } + if utf8.RuneCountInString(c.Name) < 6 { + return errors.New("部门名称至少两个字") + } + return nil +} diff --git a/internal/entities/company_staff.go b/internal/entities/company_staff.go new file mode 100644 index 0000000..2d651d3 --- /dev/null +++ b/internal/entities/company_staff.go @@ -0,0 +1,26 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// CompanyStaff 公司员工表 +type CompanyStaff struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:员工编号"` + CompanyID uint `json:"company_id" xml:"company_id" gorm:"comment:所属公司编号"` + Name string `json:"name" xml:"name" gorm:"size:20;not null;comment:员工姓名"` + Gender string `json:"gender" xml:"gender" gorm:"not null;default:'unknown';check:gender IN('female','male','unknown');comment:员工性别"` + Position string `json:"position" xml:"position" gorm:"size:100;not null;comment:员工职务"` + PhoneNumber string `json:"phone_number" xml:"phone_number" gorm:"size:11;not null;comment:手机号码"` + WechatOpenid string `json:"wechat_openid" xml:"wechat_openid" gorm:"size:100;not null;comment:微信号"` + WithoutStudy bool `json:"without_study" xml:"without_study" gorm:"comment:是否可以不用学习"` + IsAdmin bool `json:"is_admin" xml:"is_admin" gorm:"comment:是否管理员"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Departments []*CompanyDepartment `json:"departments" xml:"departments" gorm:"many2many:company_staff_to_department_relations"` +} diff --git a/internal/entities/config.go b/internal/entities/config.go new file mode 100644 index 0000000..e42b6bf --- /dev/null +++ b/internal/entities/config.go @@ -0,0 +1,24 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// Config 配置表 +type Config struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:配置编号"` + GroupID uint `json:"group_id" xml:"group_id" gorm:"comment:所属配置组编号"` + Name string `json:"name" xml:"name" gorm:"size:30;not null;comment:配置名称"` + Title string `json:"title" xml:"title" gorm:"size:50;not null;comment:配置标题"` + Description string `json:"description" xml:"description" gorm:"size:100;not null;comment:配置描述"` + DataType string `json:"data_type" xml:"data_type" gorm:"size:10;not null;comment:数据类型"` + Attributes map[string]any `json:"attributes" xml:"attributes" gorm:"serializer:json;comment:相关属性值"` + Value any `json:"value" xml:"value" gorm:"serializer:json;not null;comment:配置值"` + Sort int32 `json:"sort" xml:"sort" gorm:"size:4;default:0;comment:排序"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` +} diff --git a/internal/entities/config_group.go b/internal/entities/config_group.go new file mode 100644 index 0000000..2380d37 --- /dev/null +++ b/internal/entities/config_group.go @@ -0,0 +1,21 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// ConfigGroup 配置组表 +type ConfigGroup struct { + ID int64 `json:"id" xml:"id" gorm:"primaryKey;not null;comment:配置组编号"` + Name string `json:"name" xml:"name" gorm:"size:25;not null;uniqueIndex;comment:配置组名称"` + Description string `json:"description" xml:"description" gorm:"comment:配置组描述"` + Sort int32 `json:"sort" xml:"sort" gorm:"default:0;comment:排序"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Configs []*Config `json:"configs" xml:"configs" gorm:"foreignKey:GroupID"` +} diff --git a/internal/entities/feature.go b/internal/entities/feature.go new file mode 100644 index 0000000..7c5b181 --- /dev/null +++ b/internal/entities/feature.go @@ -0,0 +1,30 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// Feature 栏目表 +type Feature struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:栏目编号"` + Title string `json:"title" xml:"title" gorm:"size:25;not null;uniqueIndex;comment:栏目名称"` + Intro string `json:"intro" xml:"intro" gorm:"comment:栏目简介"` + Icon string `json:"icon" xml:"icon" gorm:"comment:栏目图标"` + Sort int32 `json:"sort" xml:"sort" gorm:"default:0;comment:排序"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Config *FeatureConfig `json:"config" xml:"config"` + Categories []*FeatureCategory `json:"categories" xml:"categories"` + Contents []*FeatureContent `json:"contents" xml:"contents"` +} + +// AfterDelete 将在对应的条件数据成功从数据库删除之后执行 +// todo(hupeh): 是否支持软删除 +func (f *Feature) AfterDelete(tx *gorm.DB) error { + return nil +} diff --git a/internal/entities/feature_category.go b/internal/entities/feature_category.go new file mode 100644 index 0000000..96fc337 --- /dev/null +++ b/internal/entities/feature_category.go @@ -0,0 +1,24 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// FeatureCategory 栏目分类表 +type FeatureCategory struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:栏目分类编号"` + PID *uint `json:"pid" xml:"pid" gorm:"comment:上级分类编号"` + FeatureID uint `json:"feature_id" xml:"feature_id" gorm:"index:,unique,composite:idx_title_with_feature;comment:所属栏目编号"` + Title string `json:"title" xml:"title" gorm:"size:25;not null;index:,unique,composite:idx_title_with_feature;comment:栏目分类标题"` + Description string `json:"description" xml:"description" gorm:"size:250;comment:栏目分类描述"` + Sort int32 `json:"sort" xml:"sort" gorm:"default:0;comment:排序"` + Status bool `json:"status" xml:"status" gorm:"comment:栏目分类功能"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Children []*FeatureCategory `json:"children" xml:"children" gorm:"foreignKey:PID"` +} diff --git a/internal/entities/feature_config.go b/internal/entities/feature_config.go new file mode 100644 index 0000000..ca9ecdd --- /dev/null +++ b/internal/entities/feature_config.go @@ -0,0 +1,21 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// FeatureConfig 栏目配置表 +type FeatureConfig struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:配置编号"` + FeatureID uint `json:"feature_id" xml:"feature_id" gorm:"comment:所属栏目编号"` + Status bool `json:"status" xml:"status" gorm:"comment:是否启用分类功能"` + Categorizable bool `json:"categorizable" xml:"categorizable" gorm:"comment:是否启用分类功能"` + CategoryDepth int64 `json:"category_depth" xml:"category_depth" gorm:"comment:最大分类层级数"` + ContentTypes []string `json:"content_types" xml:"content_types" gorm:"serializer:json;comment:支持的内容类型"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` +} diff --git a/internal/entities/feature_content.go b/internal/entities/feature_content.go new file mode 100644 index 0000000..670c8b7 --- /dev/null +++ b/internal/entities/feature_content.go @@ -0,0 +1,26 @@ +package entities + +import ( + "gorm.io/gorm" + "time" +) + +// FeatureContent 栏目内容表(文章、视频、课程) +type FeatureContent struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:内容编号"` + FeatureID uint `json:"feature_id" xml:"feature_id" gorm:"comment:所属栏目编号"` + CategoryID *uint `json:"category_id" xml:"category_id" gorm:"default:null;comment:所属分类编号"` + Type string `json:"type" xml:"type" gorm:"not null;comment:内容类型"` + Title string `json:"title" xml:"title" gorm:"size:100;not null;comment:内容标题"` + Intro string `json:"intro" xml:"intro" gorm:"size:250;comment:内容简介"` + Version int `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Chapters []*FeatureContentChapter `json:"chapters" xml:"chapters" gorm:"foreignKey:ContentID"` + Details []*FeatureContentDetail `json:"details" xml:"details" gorm:"foreignKey:ContentID"` + + // 只有当该记录作为课程时才有效,所以这里需要 check + Departments []*FeatureContent `json:"departments" xml:"departments" gorm:"many2many:company_course_to_department_relations"` +} diff --git a/internal/entities/feature_content_chapter.go b/internal/entities/feature_content_chapter.go new file mode 100644 index 0000000..47c2c7d --- /dev/null +++ b/internal/entities/feature_content_chapter.go @@ -0,0 +1,24 @@ +package entities + +import ( + "gorm.io/gorm" + "time" +) + +// FeatureContentChapter 栏目内容章回表 +type FeatureContentChapter struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:章回编号"` + PID *uint `json:"pid" xml:"pid" gorm:"comment:上级章回编号"` + FeatureID uint `json:"feature_id" xml:"feature_id" gorm:"comment:所属栏目编号"` + ContentID uint `json:"content_id" xml:"content_id" gorm:"comment:所属内容编号"` + Title string `json:"title" xml:"title" gorm:"size:100;not null;comment:章回标题"` + Intro string `json:"intro" xml:"intro" gorm:"size:250;comment:章回描述"` + Sort int32 `json:"sort" xml:"sort" gorm:"size:4;default:0;comment:排序"` + Version int `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Details []FeatureContentDetail `json:"details" xml:"details" gorm:"foreignKey:ChapterID"` + Children []*FeatureContentChapter `json:"children" xml:"children" gorm:"foreignKey:PID"` +} diff --git a/internal/entities/feature_content_detail.go b/internal/entities/feature_content_detail.go new file mode 100644 index 0000000..bb60a47 --- /dev/null +++ b/internal/entities/feature_content_detail.go @@ -0,0 +1,28 @@ +package entities + +import ( + "gorm.io/gorm" + "time" +) + +// FeatureContentDetail 内容详情表 +type FeatureContentDetail struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:内容详情编号"` + FeatureID uint `json:"feature_id" xml:"feature_id" gorm:"comment:所属栏目编号"` + ChapterID *uint `json:"chapter_id" xml:"chapter_id" gorm:"comment:所属章回编号"` + ContentID uint `json:"content_id" xml:"content_id" gorm:"comment:所属内容编号"` + Type string `json:"type" xml:"type" gorm:"not null;comment:内容类型"` + Title string `json:"title" xml:"title" gorm:"size:25;not null;comment:标题"` + Intro string `json:"intro" xml:"intro" gorm:"size:250;comment:简介"` + PosterUrl string `json:"poster_url" xml:"poster_url" gorm:"size:250;comment:封面链接"` + VideoUrl string `json:"video_url" xml:"video_url" gorm:"size:250;comment:视频描述"` + Text string `json:"text" xml:"text" gorm:"type:longtext;comment:具体内容"` + Attributes map[string]any `json:"attributes" xml:"attributes" gorm:"serializer:json;type:text;comment:相关属性值"` + Version int `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Chapter *FeatureContentChapter `json:"chapter" xml:"chapter" gorm:"foreignKey:ChapterID"` + Content *FeatureContent `json:"content" xml:"content" gorm:"foreignKey:ContentID"` +} diff --git a/internal/entities/resource.go b/internal/entities/resource.go new file mode 100644 index 0000000..7a60137 --- /dev/null +++ b/internal/entities/resource.go @@ -0,0 +1,25 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// Resource 资源表 +type Resource struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:资源编号"` + CategoryID uint `json:"category_id" xml:"category_id" gorm:"not null;comment:所属分类编号"` + Title string `json:"title" xml:"title" gorm:"size:25;not null;comment:资源名称"` + Path string `json:"path" xml:"path" gorm:"comment:资源访问路径"` + Width int32 `json:"width" xml:"width" gorm:"comment:资源宽度"` + Height int32 `json:"height" xml:"height" gorm:"comment:资源高度"` + Duration int32 `json:"duration" xml:"duration" gorm:"comment:播放时长(视频、gif动画)"` + MimeType string `json:"mime_type" xml:"mime_type" gorm:"comment:资源媒体类型"` + Extension string `json:"extension" xml:"extension" gorm:"comment:资源文件扩展名"` + Size int64 `json:"size" xml:"size" gorm:"comment:资源大小"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` +} diff --git a/internal/entities/resource_category.go b/internal/entities/resource_category.go new file mode 100644 index 0000000..4e1831c --- /dev/null +++ b/internal/entities/resource_category.go @@ -0,0 +1,23 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// ResourceCategory 资源分类表 +type ResourceCategory struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:资源分类编号"` + PID *uint `json:"pid" xml:"pid" gorm:"comment:上级分类编号"` + Title string `json:"title" xml:"title" gorm:"size:25;not null;uniqueIndex;comment:资源分类名称"` + Sort int32 `json:"sort" xml:"sort" gorm:"default:0;comment:排序"` + Status bool `json:"status" xml:"status" gorm:"comment:状态"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Resources []*Resource `json:"resources" xml:"resources" gorm:"foreignKey:CategoryID"` + Children []*ResourceCategory `json:"children" xml:"children" gorm:"foreignKey:PID"` +} diff --git a/internal/entities/system_log.go b/internal/entities/system_log.go new file mode 100644 index 0000000..e33518d --- /dev/null +++ b/internal/entities/system_log.go @@ -0,0 +1,19 @@ +package entities + +import "time" + +// SystemLog 系统日志表 +type SystemLog struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:系统用户编号"` + Table string `json:"table" xml:"table" gorm:"comment:被操作表名"` + RowID uint `json:"row_id" xml:"row_id" gorm:"comment:被操作的数据编号"` + Operation string `json:"operation" xml:"operation" gorm:"comment:操作类型,1查询、2新增、3编辑、4删除"` + IP string `json:"ip" xml:"IP" gorm:"comment:用户IP"` + Comment string `json:"comment" xml:"comment" gorm:"comment:操作描述"` + RequestID string `json:"request_id" xml:"request_id" gorm:"comment:请求编号"` + RequestInfo string `json:"request_info" xml:"request_info" gorm:"comment:请求信息"` + ColumnInfo string `json:"column_info" xml:"column_info" gorm:"comment:列变更信息"` + UserID int64 `json:"user_id" xml:"user_id" gorm:"comment:用户编号"` + UserType int64 `json:"user_type" xml:"user_type" gorm:"comment:用户类型"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` +} diff --git a/internal/entities/system_menu.go b/internal/entities/system_menu.go new file mode 100644 index 0000000..b3f672b --- /dev/null +++ b/internal/entities/system_menu.go @@ -0,0 +1,23 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// SystemMenu 系统菜单表 +type SystemMenu struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:菜单编号"` + PID *uint `json:"pid" xml:"pid" gorm:"comment:上级菜单编号"` + Title string `json:"title" xml:"title" gorm:"size:25;not null;comment:菜单标题"` + Icon string `json:"icon" xml:"icon" gorm:"comment:菜单图标"` + Sort int32 `json:"sort" xml:"sort" gorm:"default:0;comment:排序"` + Path string `json:"path" xml:"path" gorm:"comment:跳转链接"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Children []*SystemMenu `json:"children" xml:"children" gorm:"foreignKey:PID"` +} diff --git a/internal/entities/system_permission.go b/internal/entities/system_permission.go new file mode 100644 index 0000000..e650cad --- /dev/null +++ b/internal/entities/system_permission.go @@ -0,0 +1,19 @@ +package entities + +import ( + "gorm.io/gorm" + "time" +) + +type SystemPermission struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:权限编号"` + PID *uint `json:"pid" xml:"pid" gorm:"comment:上级权限编号"` + Name string `json:"name" xml:"name" gorm:"size:25;not null;comment:权限名称"` + Type string `json:"type" xml:"type" gorm:"size:25;not null;index;comment:权限类型"` + Identifier string `json:"identifier" xml:"identifier" gorm:"size:25;not null;uniqueIndex;comment:权限标识"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Children []*SystemPermission `json:"children" xml:"children" gorm:"foreignKey:PID"` +} diff --git a/internal/entities/system_role.go b/internal/entities/system_role.go new file mode 100644 index 0000000..62d3be3 --- /dev/null +++ b/internal/entities/system_role.go @@ -0,0 +1,20 @@ +package entities + +import ( + "gorm.io/gorm" + "sorbet/pkg/db" + "time" +) + +// SystemRole 系统用户角色表 +type SystemRole struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:角色编号"` + Name string `json:"name" xml:"name" gorm:"size:25;not null;uniqueIndex;comment:角色名称"` + Version db.Version `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Powers []*SystemRolePower `json:"powers" xml:"powers" gorm:"foreignKey:RoleID"` + Users []*SystemUser `json:"users" xml:"users" gorm:"many2many:system_user_to_role_relations"` +} diff --git a/internal/entities/system_role_power.go b/internal/entities/system_role_power.go new file mode 100644 index 0000000..e5f98b4 --- /dev/null +++ b/internal/entities/system_role_power.go @@ -0,0 +1,12 @@ +package entities + +import "time" + +// SystemRolePower 角色授权表 +type SystemRolePower struct { + ID uint `json:"id" xml:"id" gorm:"primaryKey;not null;comment:能力编号"` + RoleID uint `json:"role_id" xml:"role_id" gorm:"comment:关联角色编号"` + WithType string `json:"with_type" xml:"with_type" gorm:"size:25;not null;comment:关联类型,如:perm权限、menu菜单、data数据"` + WithID uint `json:"with_id" xml:"with_id" gorm:"size:25;not null;comment:关联编号"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` +} diff --git a/internal/entities/system_user.go b/internal/entities/system_user.go new file mode 100644 index 0000000..c7fc8dc --- /dev/null +++ b/internal/entities/system_user.go @@ -0,0 +1,20 @@ +package entities + +import ( + "gorm.io/gorm" + "time" +) + +// SystemUser 系统用户表 +type SystemUser struct { + ID int64 `json:"id" xml:"id" gorm:"primaryKey;not null;comment:系统用户编号"` + Username string `json:"username" xml:"username" gorm:"size:25;not null;uniqueIndex;comment:用户名"` + Password string `json:"password" xml:"password" gorm:"size:25;not null;comment:登录密码"` + Status bool `json:"status" xml:"status" gorm:"comment:状态"` + Version int `json:"-" xml:"-" gorm:"comment:乐观锁"` + CreatedAt time.Time `json:"create_time" xml:"create_time" gorm:"<-:false;comment:创建时间"` + UpdatedAt time.Time `json:"update_time" xml:"update_time" gorm:"<-:false;comment:更新时间"` + DeletedAt gorm.DeletedAt `json:"delete_time" xml:"delete_time" gorm:"comment:删除时间"` + + Roles []*SystemRole `json:"roles" xml:"roles" gorm:"many2many:system_user_to_role_relations"` +} diff --git a/internal/init.go b/internal/init.go new file mode 100644 index 0000000..22f642f --- /dev/null +++ b/internal/init.go @@ -0,0 +1,53 @@ +package internal + +import ( + "errors" + "sorbet/internal/entities" + "sorbet/internal/repositories" + "sorbet/pkg/db" + "sorbet/pkg/env" + "sorbet/pkg/ioc" + "sorbet/pkg/log" +) + +func Init() error { + ioc.Bind(db.DB()) // 注入数据库操作 + ioc.Bind(log.Default()) // 注入日志操作 + repositories.Init() // 注入数据仓库操作 + + // 同步数据库结构 + if err := syncEntities(); err != nil { + if !errors.Is(err, db.ErrNoCodeFirst) { + return err + } + if !env.IsEnv("prod") { + log.Error("同步数据表结构需要开启 [DB_CODE_FIRST],在生产模式下请务必关闭。") + } + } + + return nil +} + +func syncEntities() error { + return db.Sync( + &entities.Company{}, + &entities.CompanyDepartment{}, + &entities.CompanyStaff{}, + &entities.Config{}, + &entities.ConfigGroup{}, + &entities.Feature{}, + &entities.FeatureCategory{}, + &entities.FeatureConfig{}, + &entities.FeatureContent{}, + &entities.FeatureContentChapter{}, + &entities.FeatureContentDetail{}, + &entities.Resource{}, + &entities.ResourceCategory{}, + &entities.SystemLog{}, + &entities.SystemMenu{}, + &entities.SystemPermission{}, + &entities.SystemRole{}, + &entities.SystemRolePower{}, + &entities.SystemUser{}, + ) +} diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 0000000..975aa7e --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,134 @@ +package middleware + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "net/http" +) + +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowOrigins []string + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. + AllowOriginFunc func(origin string) (bool, error) + + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowMethods []string + + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // + // Optional. Default value []string{}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowHeaders []string + + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials bool + + // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials + // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. + // + // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) + // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. + // + // Optional. Default value is false. + UnsafeWildcardOriginWithAllowCredentials bool + + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header + ExposeHeaders []string + + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // + // Optional. Default value 0. The header is set only if MaxAge > 0. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + MaxAge int +} + +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: middleware.DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{ + http.MethodGet, + http.MethodHead, + http.MethodPut, + http.MethodPatch, + http.MethodPost, + http.MethodDelete, + }, +} + +func (c *CORSConfig) ToMiddleware() echo.MiddlewareFunc { + return middleware.CORSWithConfig(middleware.CORSConfig{ + Skipper: c.Skipper, + AllowOrigins: c.AllowOrigins, + AllowOriginFunc: c.AllowOriginFunc, + AllowMethods: c.AllowMethods, + AllowHeaders: c.AllowHeaders, + AllowCredentials: c.AllowCredentials, + UnsafeWildcardOriginWithAllowCredentials: c.UnsafeWildcardOriginWithAllowCredentials, + ExposeHeaders: c.ExposeHeaders, + MaxAge: c.MaxAge, + }) +} + +func CORS() echo.MiddlewareFunc { + return DefaultCORSConfig.ToMiddleware() +} diff --git a/internal/middleware/jwt.go b/internal/middleware/jwt.go new file mode 100644 index 0000000..05eda14 --- /dev/null +++ b/internal/middleware/jwt.go @@ -0,0 +1,140 @@ +package middleware + +import ( + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo-jwt/v4" + "github.com/labstack/echo/v4" +) + +// JWTConfig defines the config for JWT middleware. +type JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeFunc defines a function which is executed just before the middleware. + BeforeFunc BeforeFunc + + // SuccessHandler defines a function which is executed for a valid token. + SuccessHandler func(c echo.Context) + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom JWT error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain. + ErrorHandler func(c echo.Context, err error) error + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public JWT token value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. + ContinueOnIgnoredError bool + + // Context key to store user information from the token into context. + // Optional. Default value "user". + ContextKey string + + // Signing key to validate token. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKeys is provided. + SigningKey any + + // Map of signing keys to validate token with kid field usage. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKey is provided. + SigningKeys map[string]any + + // Signing method used to check the token's signing algorithm. + // Optional. Default value HS256. + SigningMethod string + + // KeyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + // Used by default ParseTokenFunc implementation. + // + // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither SigningKeys nor SigningKey is provided. + // Not used if custom ParseTokenFunc is set. + // Default to an internal implementation verifying the signing algorithm and selecting the proper key. + KeyFunc jwt.Keyfunc + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. + // If prefix is left empty the whole value is returned. + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiple sources example: + // - "header:Authorization:Bearer ,cookie:myowncookie" + TokenLookup string + + // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. + // This is one of the two options to provide a token extractor. + // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. + // You can also provide both if you want. + TokenLookupFuncs []ValuesExtractor + + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(c echo.Context, auth string) (any, error) + + // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. + // Not used if custom ParseTokenFunc is set. + // Optional. Defaults to function returning jwt.MapClaims + NewClaimsFunc func(c echo.Context) jwt.Claims +} + +// Errors +var ( + ErrJWTMissing = echojwt.ErrJWTMissing + ErrJWTInvalid = echojwt.ErrJWTInvalid +) + +func (config *JWTConfig) ToMiddleware() echo.MiddlewareFunc { + return echojwt.WithConfig(echojwt.Config{ + Skipper: config.Skipper, + BeforeFunc: config.BeforeFunc, + SuccessHandler: config.SuccessHandler, + ErrorHandler: config.ErrorHandler, + ContinueOnIgnoredError: config.ContinueOnIgnoredError, + ContextKey: config.ContextKey, + SigningKey: config.SigningKey, + SigningKeys: config.SigningKeys, + SigningMethod: config.SigningMethod, + KeyFunc: config.KeyFunc, + TokenLookup: config.TokenLookup, + TokenLookupFuncs: config.TokenLookupFuncs, + ParseTokenFunc: config.ParseTokenFunc, + NewClaimsFunc: nil, + }) +} + +// JWT returns a JSON Web Token (JWT) auth middleware. +// +// For valid token, it sets the user in context and calls next handler. +// For invalid token, it returns "401 - Unauthorized" error. +// For missing token, it returns "400 - Bad Request" error. +// +// See: https://jwt.io/introduction +// See `JWTConfig.TokenLookup` +// See https://github.com/labstack/echo-jwt +func JWT(signingKey any) echo.MiddlewareFunc { + return echojwt.JWT(signingKey) +} diff --git a/internal/middleware/key_auth.go b/internal/middleware/key_auth.go new file mode 100644 index 0000000..b5c8351 --- /dev/null +++ b/internal/middleware/key_auth.go @@ -0,0 +1,74 @@ +package middleware + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator = middleware.KeyAuthValidator + +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler = middleware.KeyAuthErrorHandler + +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed for an invalid key. + // It may be used to define a custom error. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} + +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", +} + +func (a *KeyAuthConfig) ToMiddleware() echo.MiddlewareFunc { + return middleware.KeyAuthWithConfig(middleware.KeyAuthConfig{ + Skipper: a.Skipper, + KeyLookup: a.KeyLookup, + AuthScheme: a.AuthScheme, + Validator: a.Validator, + ErrorHandler: a.ErrorHandler, + ContinueOnIgnoredError: a.ContinueOnIgnoredError, + }) +} + +func KeyAuth() echo.MiddlewareFunc { + return DefaultKeyAuthConfig.ToMiddleware() +} diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go new file mode 100644 index 0000000..6e25d64 --- /dev/null +++ b/internal/middleware/logger.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "fmt" + "github.com/labstack/echo/v4" + "github.com/rs/xid" + "net/http" + "sorbet/internal/util" + "sorbet/pkg/log" + "time" +) + +var color = log.NewColorer() + +func requestId(req *http.Request, res *echo.Response) string { + id := req.Header.Get(echo.HeaderXRequestID) + if id == "" { + id = xid.New().String() + res.Header().Set(echo.HeaderXRequestID, id) + } + return id +} + +// Logger 该日志中间件会自动获取或设置 RequestID +func Logger(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) (err error) { + req := c.Request() + res := c.Response() + start := time.Now() + id := requestId(req, res) + l := log.With(log.String("id", id)) + l.Info( + "Started %s %s for %s", + req.Method, req.RequestURI, c.RealIP(), + log.RawTime(start), + ) + c.SetLogger(util.NewCustomLogger(l)) + if err = next(c); err != nil { + c.Error(err) + } + stop := time.Now() + content := fmt.Sprintf( + "Completed %s %s %v %s in %v", + req.Method, req.RequestURI, res.Status, + http.StatusText(res.Status), stop.Sub(start), + ) + if res.Status >= 500 { + content = color.Cyan(content) + } else if res.Status >= 400 { + content = color.Red(content) + } else if res.Status >= 300 { + if res.Status == 304 { + content = color.Yellow(content) + } else { + content = color.White(content) + } + } else if res.Status >= 200 { + content = color.Green(content) + } + l.Info(content, log.RawTime(stop)) + return + } +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..27a5914 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +// Skipper defines a function to skip middleware. Returning true skips processing +// the middleware. +type Skipper = middleware.Skipper + +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc = middleware.BeforeFunc + +type ValuesExtractor = middleware.ValuesExtractor + +type ToMiddleware interface { + ToMiddleware() echo.MiddlewareFunc +} + +// DefaultSkipper returns false which processes the middleware. +func DefaultSkipper(echo.Context) bool { + return false +} diff --git a/internal/middleware/rate_limiter.go b/internal/middleware/rate_limiter.go new file mode 100644 index 0000000..4d87a9a --- /dev/null +++ b/internal/middleware/rate_limiter.go @@ -0,0 +1,268 @@ +package middleware + +import ( + "github.com/labstack/echo/v4/middleware" + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" + "golang.org/x/time/rate" +) + +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + // Allow Stores for the rate limiter have to implement the Allow method + Allow(identifier string) (bool, error) +} + +type ( + // RateLimiterConfig defines the configuration for the rate limiter + RateLimiterConfig struct { + Skipper Skipper + BeforeFunc middleware.BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error + } + // Extractor is used to extract data from echo.Context + Extractor func(context echo.Context) (string, error) +) + +// errors +var ( + // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded + ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + // ErrExtractorError denotes an error raised when extractor function is unsuccessful + ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") +) + +// DefaultRateLimiterConfig defines default values for RateLimiterConfig +var DefaultRateLimiterConfig = RateLimiterConfig{ + Skipper: middleware.DefaultSkipper, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return &echo.HTTPError{ + Code: ErrExtractorError.Code, + Message: ErrExtractorError.Message, + Internal: err, + } + }, + DenyHandler: func(context echo.Context, identifier string, err error) error { + return &echo.HTTPError{ + Code: ErrRateLimitExceeded.Code, + Message: ErrRateLimitExceeded.Message, + Internal: err, + } + }, +} + +/* +RateLimiter returns a rate limiting middleware + + e := echo.New() + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, RateLimiter(limiterStore)) +*/ +func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { + config := DefaultRateLimiterConfig + config.Store = store + return config.ToMiddleware() +} + +/* +ToMiddleware returns a rate limiting middleware + + e := echo.New() + + config := middleware.RateLimiterConfig{ + Skipper: DefaultSkipper, + Store: middleware.NewRateLimiterMemoryStore( + middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} + ) + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return context.JSON(http.StatusTooManyRequests, nil) + }, + DenyHandler: func(context echo.Context, identifier string) error { + return context.JSON(http.StatusForbidden, nil) + }, + } + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, middleware.RateLimiterWithConfig(config)) +*/ +func (config *RateLimiterConfig) ToMiddleware() echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + if config.IdentifierExtractor == nil { + config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler + } + if config.DenyHandler == nil { + config.DenyHandler = DefaultRateLimiterConfig.DenyHandler + } + if config.Store == nil { + panic("Store configuration must be provided") + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + if config.BeforeFunc != nil { + config.BeforeFunc(c) + } + + identifier, err := config.IdentifierExtractor(c) + if err != nil { + c.Error(config.ErrorHandler(c, err)) + return nil + } + + if allow, err := config.Store.Allow(identifier); !allow { + c.Error(config.DenyHandler(c, identifier, err)) + return nil + } + return next(c) + } + } +} + +type ( + // RateLimiterMemoryStore is the built-in store implementation for RateLimiter + RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + + burst int + expiresIn time.Duration + lastCleanup time.Time + + timeNow func() time.Time + } + // Visitor signifies a unique user's limiter details + Visitor struct { + *rate.Limiter + lastSeen time.Time + } +) + +/* +NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with +the provided rate (as req/s). +for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + +Burst and ExpiresIn will be set to default values. + +Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. + +Example (with 20 requests/sec): + + limiterStore := middleware.NewRateLimiterMemoryStore(20) +*/ +func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { + return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: rate, + }) +} + +/* +NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore +with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of +the configured rate if not provided or set to 0. + +The build-in memory store is usually capable for modest loads. For higher loads other +store implementations should be considered. + +Characteristics: +* Concurrency above 100 parallel requests may causes measurable lock contention +* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map +* A high number of requests from a single IP address may cause lock contention + +Example: + + limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute}, + ) +*/ +func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { + store = &RateLimiterMemoryStore{} + + store.rate = config.Rate + store.burst = config.Burst + store.expiresIn = config.ExpiresIn + if config.ExpiresIn == 0 { + store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn + } + if config.Burst == 0 { + store.burst = int(config.Rate) + } + store.visitors = make(map[string]*Visitor) + store.timeNow = time.Now + store.lastCleanup = store.timeNow() + return +} + +// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore +type RateLimiterMemoryStoreConfig struct { + Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. + ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up +} + +// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore +var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ + ExpiresIn: 3 * time.Minute, +} + +// Allow implements RateLimiterStore.Allow +func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + store.mutex.Lock() + limiter, exists := store.visitors[identifier] + if !exists { + limiter = new(Visitor) + limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + store.visitors[identifier] = limiter + } + now := store.timeNow() + limiter.lastSeen = now + if now.Sub(store.lastCleanup) > store.expiresIn { + store.cleanupStaleVisitors() + } + store.mutex.Unlock() + return limiter.AllowN(store.timeNow(), 1), nil +} + +/* +cleanupStaleVisitors helps manage the size of the visitors map by removing stale records +of users who haven't visited again after the configured expiry time has elapsed +*/ +func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { + for id, visitor := range store.visitors { + if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { + delete(store.visitors, id) + } + } + store.lastCleanup = store.timeNow() +} diff --git a/internal/middleware/recover.go b/internal/middleware/recover.go new file mode 100644 index 0000000..629368e --- /dev/null +++ b/internal/middleware/recover.go @@ -0,0 +1,135 @@ +package middleware + +import ( + "fmt" + "net/http" + "runtime" + "sorbet/internal/util" + "sorbet/pkg/log" + + "github.com/labstack/echo/v4" +) + +// LogErrorFunc defines a function for custom logging in the middleware. +type LogErrorFunc func(c echo.Context, err error, stack []byte) error + +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int + + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool + + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool + + // LogLevel is log level to printing stack trace. + // Optional. Default value 0 (Print). + LogLevel log.Level + + // LogErrorFunc defines a function for custom logging in the middleware. + // If it's set you don't need to provide LogLevel for config. + // If this function returns nil, the centralized HTTPErrorHandler will not be called. + LogErrorFunc LogErrorFunc + + // DisableErrorHandler disables the call to centralized HTTPErrorHandler. + // The recovered error is then passed back to upstream middleware, instead of swallowing the error. + // Optional. Default value false. + DisableErrorHandler bool +} + +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, + LogLevel: log.LevelDebug, + LogErrorFunc: nil, + DisableErrorHandler: false, +} + +func (config *RecoverConfig) ToMiddleware() echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultRecoverConfig.Skipper + } + if config.StackSize == 0 { + config.StackSize = DefaultRecoverConfig.StackSize + } + switch config.LogLevel { + case log.LevelTrace, log.LevelFatal, log.LevelPanic: + panic("不应该将 LevelTrace、LevelFatal 和 LevelPanic 这三个日志作用在错误恢复中间件上") + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) (returnErr error) { + if config.Skipper(c) { + return next(c) + } + defer func() { + if r := recover(); r != nil { + if r == http.ErrAbortHandler { + panic(r) + } + err, ok := r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + var stack []byte + var length int + if !config.DisablePrintStack { + stack = make([]byte, config.StackSize) + length = runtime.Stack(stack, !config.DisableStackAll) + stack = stack[:length] + } + if config.LogErrorFunc != nil { + err = config.LogErrorFunc(c, err, stack) + } else if !config.DisablePrintStack { + var i []any + if _, ok := c.Logger().(*util.EchoLogger); ok { + i = append(i, + fmt.Sprintf("%v %s\n", err, stack[:length]), + log.RawLevel("PANIC RECOVER"), + ) + } else { + i = append(i, fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])) + } + switch config.LogLevel { + case log.LevelDebug: + c.Logger().Debug(i...) + case log.LevelInfo: + c.Logger().Info(i...) + case log.LevelWarn: + c.Logger().Warn(i...) + case log.LevelError: + c.Logger().Error(i...) + case log.LevelOff: + // None. + default: + c.Logger().Print(i...) + } + } + if err != nil && !config.DisableErrorHandler { + c.Error(err) + } else { + returnErr = err + } + } + }() + return next(c) + } + } +} + +// Recover returns a middleware which recovers from panics anywhere in the chain +// and handles the control to the centralized HTTPErrorHandler. +func Recover() echo.MiddlewareFunc { + return DefaultRecoverConfig.ToMiddleware() +} diff --git a/internal/repositories/company.go b/internal/repositories/company.go new file mode 100644 index 0000000..21f3f55 --- /dev/null +++ b/internal/repositories/company.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type CompanyRepository struct { + *db.Repository[entities.Company] +} + +// NewCompanyRepository 创建公司仓库 +func NewCompanyRepository(orm *gorm.DB) *CompanyRepository { + return &CompanyRepository{ + db.NewRepositoryWith[entities.Company](orm, "id"), + } +} + diff --git a/internal/repositories/company_department.go b/internal/repositories/company_department.go new file mode 100755 index 0000000..2d127b2 --- /dev/null +++ b/internal/repositories/company_department.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type CompanyDepartmentRepository struct { + *db.Repository[entities.CompanyDepartment] +} + +// NewCompanyDepartmentRepository 创建公司部门仓库 +func NewCompanyDepartmentRepository(orm *gorm.DB) *CompanyDepartmentRepository { + return &CompanyDepartmentRepository{ + db.NewRepositoryWith[entities.CompanyDepartment](orm, "id"), + } +} + diff --git a/internal/repositories/company_staff.go b/internal/repositories/company_staff.go new file mode 100755 index 0000000..93cb3c8 --- /dev/null +++ b/internal/repositories/company_staff.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type CompanyStaffRepository struct { + *db.Repository[entities.CompanyStaff] +} + +// NewCompanyStaffRepository 创建公司员工仓库 +func NewCompanyStaffRepository(orm *gorm.DB) *CompanyStaffRepository { + return &CompanyStaffRepository{ + db.NewRepositoryWith[entities.CompanyStaff](orm, "id"), + } +} + diff --git a/internal/repositories/config.go b/internal/repositories/config.go new file mode 100755 index 0000000..8141718 --- /dev/null +++ b/internal/repositories/config.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type ConfigRepository struct { + *db.Repository[entities.Config] +} + +// NewConfigRepository 创建配置仓库 +func NewConfigRepository(orm *gorm.DB) *ConfigRepository { + return &ConfigRepository{ + db.NewRepositoryWith[entities.Config](orm, "id"), + } +} + diff --git a/internal/repositories/config_group.go b/internal/repositories/config_group.go new file mode 100755 index 0000000..73161a2 --- /dev/null +++ b/internal/repositories/config_group.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type ConfigGroupRepository struct { + *db.Repository[entities.ConfigGroup] +} + +// NewConfigGroupRepository 创建配置组仓库 +func NewConfigGroupRepository(orm *gorm.DB) *ConfigGroupRepository { + return &ConfigGroupRepository{ + db.NewRepositoryWith[entities.ConfigGroup](orm, "id"), + } +} + diff --git a/internal/repositories/feature.go b/internal/repositories/feature.go new file mode 100755 index 0000000..851ffe8 --- /dev/null +++ b/internal/repositories/feature.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type FeatureRepository struct { + *db.Repository[entities.Feature] +} + +// NewFeatureRepository 创建栏目仓库 +func NewFeatureRepository(orm *gorm.DB) *FeatureRepository { + return &FeatureRepository{ + db.NewRepositoryWith[entities.Feature](orm, "id"), + } +} + diff --git a/internal/repositories/feature_category.go b/internal/repositories/feature_category.go new file mode 100755 index 0000000..ba0fd45 --- /dev/null +++ b/internal/repositories/feature_category.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type FeatureCategoryRepository struct { + *db.Repository[entities.FeatureCategory] +} + +// NewFeatureCategoryRepository 创建栏目分类仓库 +func NewFeatureCategoryRepository(orm *gorm.DB) *FeatureCategoryRepository { + return &FeatureCategoryRepository{ + db.NewRepositoryWith[entities.FeatureCategory](orm, "id"), + } +} + diff --git a/internal/repositories/feature_config.go b/internal/repositories/feature_config.go new file mode 100755 index 0000000..f8be8cb --- /dev/null +++ b/internal/repositories/feature_config.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type FeatureConfigRepository struct { + *db.Repository[entities.FeatureConfig] +} + +// NewFeatureConfigRepository 创建栏目配置仓库 +func NewFeatureConfigRepository(orm *gorm.DB) *FeatureConfigRepository { + return &FeatureConfigRepository{ + db.NewRepositoryWith[entities.FeatureConfig](orm, "id"), + } +} + diff --git a/internal/repositories/feature_content.go b/internal/repositories/feature_content.go new file mode 100755 index 0000000..f3e40c1 --- /dev/null +++ b/internal/repositories/feature_content.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type FeatureContentRepository struct { + *db.Repository[entities.FeatureContent] +} + +// NewFeatureContentRepository 创建栏目内容仓库 +func NewFeatureContentRepository(orm *gorm.DB) *FeatureContentRepository { + return &FeatureContentRepository{ + db.NewRepositoryWith[entities.FeatureContent](orm, "id"), + } +} + diff --git a/internal/repositories/feature_content_chapter.go b/internal/repositories/feature_content_chapter.go new file mode 100755 index 0000000..9dc8832 --- /dev/null +++ b/internal/repositories/feature_content_chapter.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type FeatureContentChapterRepository struct { + *db.Repository[entities.FeatureContentChapter] +} + +// NewFeatureContentChapterRepository 创建栏目内容章回仓库 +func NewFeatureContentChapterRepository(orm *gorm.DB) *FeatureContentChapterRepository { + return &FeatureContentChapterRepository{ + db.NewRepositoryWith[entities.FeatureContentChapter](orm, "id"), + } +} + diff --git a/internal/repositories/feature_content_detail.go b/internal/repositories/feature_content_detail.go new file mode 100755 index 0000000..6c67632 --- /dev/null +++ b/internal/repositories/feature_content_detail.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type FeatureContentDetailRepository struct { + *db.Repository[entities.FeatureContentDetail] +} + +// NewFeatureContentDetailRepository 创建栏目内容详情仓库 +func NewFeatureContentDetailRepository(orm *gorm.DB) *FeatureContentDetailRepository { + return &FeatureContentDetailRepository{ + db.NewRepositoryWith[entities.FeatureContentDetail](orm, "id"), + } +} + diff --git a/internal/repositories/ioc.go b/internal/repositories/ioc.go new file mode 100644 index 0000000..81b3fa7 --- /dev/null +++ b/internal/repositories/ioc.go @@ -0,0 +1,27 @@ +package repositories + +import ( + "sorbet/pkg/ioc" +) + +func Init() { + ioc.MustFactory(NewCompanyRepository) + ioc.MustFactory(NewCompanyDepartmentRepository) + ioc.MustFactory(NewCompanyStaffRepository) + ioc.MustFactory(NewConfigRepository) + ioc.MustFactory(NewConfigGroupRepository) + ioc.MustFactory(NewFeatureRepository) + ioc.MustFactory(NewFeatureCategoryRepository) + ioc.MustFactory(NewFeatureConfigRepository) + ioc.MustFactory(NewFeatureContentRepository) + ioc.MustFactory(NewFeatureContentChapterRepository) + ioc.MustFactory(NewFeatureContentDetailRepository) + ioc.MustFactory(NewResourceRepository) + ioc.MustFactory(NewResourceCategoryRepository) + ioc.MustFactory(NewSystemLogRepository) + ioc.MustFactory(NewSystemMenuRepository) + ioc.MustFactory(NewSystemPermissionRepository) + ioc.MustFactory(NewSystemRoleRepository) + ioc.MustFactory(NewSystemRolePowerRepository) + ioc.MustFactory(NewSystemUserRepository) +} diff --git a/internal/repositories/resource.go b/internal/repositories/resource.go new file mode 100755 index 0000000..a8230da --- /dev/null +++ b/internal/repositories/resource.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type ResourceRepository struct { + *db.Repository[entities.Resource] +} + +// NewResourceRepository 创建资源仓库 +func NewResourceRepository(orm *gorm.DB) *ResourceRepository { + return &ResourceRepository{ + db.NewRepositoryWith[entities.Resource](orm, "id"), + } +} + diff --git a/internal/repositories/resource_category.go b/internal/repositories/resource_category.go new file mode 100755 index 0000000..47e2202 --- /dev/null +++ b/internal/repositories/resource_category.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type ResourceCategoryRepository struct { + *db.Repository[entities.ResourceCategory] +} + +// NewResourceCategoryRepository 创建资源分类仓库 +func NewResourceCategoryRepository(orm *gorm.DB) *ResourceCategoryRepository { + return &ResourceCategoryRepository{ + db.NewRepositoryWith[entities.ResourceCategory](orm, "id"), + } +} + diff --git a/internal/repositories/system_log.go b/internal/repositories/system_log.go new file mode 100755 index 0000000..1988f4e --- /dev/null +++ b/internal/repositories/system_log.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type SystemLogRepository struct { + *db.Repository[entities.SystemLog] +} + +// NewSystemLogRepository 创建系统日志仓库 +func NewSystemLogRepository(orm *gorm.DB) *SystemLogRepository { + return &SystemLogRepository{ + db.NewRepositoryWith[entities.SystemLog](orm, "id"), + } +} + diff --git a/internal/repositories/system_menu.go b/internal/repositories/system_menu.go new file mode 100755 index 0000000..d8d1868 --- /dev/null +++ b/internal/repositories/system_menu.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type SystemMenuRepository struct { + *db.Repository[entities.SystemMenu] +} + +// NewSystemMenuRepository 创建系统菜单仓库 +func NewSystemMenuRepository(orm *gorm.DB) *SystemMenuRepository { + return &SystemMenuRepository{ + db.NewRepositoryWith[entities.SystemMenu](orm, "id"), + } +} + diff --git a/internal/repositories/system_permission.go b/internal/repositories/system_permission.go new file mode 100755 index 0000000..457ec68 --- /dev/null +++ b/internal/repositories/system_permission.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type SystemPermissionRepository struct { + *db.Repository[entities.SystemPermission] +} + +// NewSystemPermissionRepository 创建系统权限仓库 +func NewSystemPermissionRepository(orm *gorm.DB) *SystemPermissionRepository { + return &SystemPermissionRepository{ + db.NewRepositoryWith[entities.SystemPermission](orm, "id"), + } +} + diff --git a/internal/repositories/system_role.go b/internal/repositories/system_role.go new file mode 100755 index 0000000..7da9cde --- /dev/null +++ b/internal/repositories/system_role.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type SystemRoleRepository struct { + *db.Repository[entities.SystemRole] +} + +// NewSystemRoleRepository 创建系统用户角色仓库 +func NewSystemRoleRepository(orm *gorm.DB) *SystemRoleRepository { + return &SystemRoleRepository{ + db.NewRepositoryWith[entities.SystemRole](orm, "id"), + } +} + diff --git a/internal/repositories/system_role_power.go b/internal/repositories/system_role_power.go new file mode 100755 index 0000000..385ea29 --- /dev/null +++ b/internal/repositories/system_role_power.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type SystemRolePowerRepository struct { + *db.Repository[entities.SystemRolePower] +} + +// NewSystemRolePowerRepository 创建角色授权仓库 +func NewSystemRolePowerRepository(orm *gorm.DB) *SystemRolePowerRepository { + return &SystemRolePowerRepository{ + db.NewRepositoryWith[entities.SystemRolePower](orm, "id"), + } +} + diff --git a/internal/repositories/system_user.go b/internal/repositories/system_user.go new file mode 100755 index 0000000..98b537c --- /dev/null +++ b/internal/repositories/system_user.go @@ -0,0 +1,19 @@ +package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type SystemUserRepository struct { + *db.Repository[entities.SystemUser] +} + +// NewSystemUserRepository 创建系统用户仓库 +func NewSystemUserRepository(orm *gorm.DB) *SystemUserRepository { + return &SystemUserRepository{ + db.NewRepositoryWith[entities.SystemUser](orm, "id"), + } +} + diff --git a/internal/util/echo_context.go b/internal/util/echo_context.go new file mode 100644 index 0000000..3c2628c --- /dev/null +++ b/internal/util/echo_context.go @@ -0,0 +1,7 @@ +package util + +import "github.com/labstack/echo/v4" + +type EchoContext struct { + echo.Context +} diff --git a/internal/util/echo_logger.go b/internal/util/echo_logger.go new file mode 100644 index 0000000..6681a14 --- /dev/null +++ b/internal/util/echo_logger.go @@ -0,0 +1,201 @@ +package util + +import ( + "encoding/json" + "fmt" + "github.com/labstack/gommon/log" + "io" + "os" + sorbetLog "sorbet/pkg/log" +) + +type EchoLogger struct { + l sorbetLog.Logger +} + +func NewLogger() *EchoLogger { + return NewCustomLogger(sorbetLog.Default()) +} + +func NewCustomLogger(l sorbetLog.Logger) *EchoLogger { + return &EchoLogger{l} +} + +func (l *EchoLogger) Output() io.Writer { + return l.l.Writer() +} + +func (l *EchoLogger) SetOutput(w io.Writer) { + l.l.SetWriter(w) +} + +func (l *EchoLogger) Prefix() string { + return "" +} + +func (l *EchoLogger) SetPrefix(_ string) { + // FIXME(hupeh): 能否使用 WithGroup 或者 Attr 实现? + fmt.Println("cannot set a prefix into logging") +} + +func (l *EchoLogger) Level() log.Lvl { + switch v := l.l.Level(); v { + case sorbetLog.LevelDebug, sorbetLog.LevelTrace: + return log.DEBUG + case sorbetLog.LevelInfo: + return log.INFO + case sorbetLog.LevelWarn: + return log.WARN + case sorbetLog.LevelError: + return log.ERROR + case sorbetLog.LevelOff: + return log.OFF + case sorbetLog.LevelFatal: + return log.Lvl(7) + default: + if v < sorbetLog.LevelTrace { + return log.DEBUG + } else { + return log.Lvl(7) + } + } +} + +func (l *EchoLogger) SetLevel(v log.Lvl) { + switch v { + case log.DEBUG: + l.l.SetLevel(sorbetLog.LevelDebug) + case log.INFO: + l.l.SetLevel(sorbetLog.LevelInfo) + case log.WARN: + l.l.SetLevel(sorbetLog.LevelWarn) + case log.ERROR: + l.l.SetLevel(sorbetLog.LevelError) + case log.OFF: + l.l.SetLevel(sorbetLog.LevelOff) + } +} + +func (l *EchoLogger) SetHeader(_ string) { + fmt.Println("cannot set a header into logging") +} + +func (l *EchoLogger) Print(i ...interface{}) { + l.log(l.l.Level(), i...) +} + +func (l *EchoLogger) Printf(format string, args ...interface{}) { + l.logf(l.l.Level(), format, args...) +} + +func (l *EchoLogger) Printj(j log.JSON) { + l.logj(l.l.Level(), j) +} + +func (l *EchoLogger) Debug(i ...interface{}) { + l.log(sorbetLog.LevelDebug, i...) +} + +func (l *EchoLogger) Debugf(format string, args ...interface{}) { + l.logf(sorbetLog.LevelDebug, format, args...) +} + +func (l *EchoLogger) Debugj(j log.JSON) { + l.logj(sorbetLog.LevelDebug, j) +} + +func (l *EchoLogger) Info(i ...interface{}) { + l.log(sorbetLog.LevelInfo, i...) +} + +func (l *EchoLogger) Infof(format string, args ...interface{}) { + l.logf(sorbetLog.LevelInfo, format, args...) +} + +func (l *EchoLogger) Infoj(j log.JSON) { + l.logj(sorbetLog.LevelInfo, j) +} + +func (l *EchoLogger) Warn(i ...interface{}) { + l.log(sorbetLog.LevelWarn, i...) +} + +func (l *EchoLogger) Warnf(format string, args ...interface{}) { + l.logf(sorbetLog.LevelInfo, format, args...) +} + +func (l *EchoLogger) Warnj(j log.JSON) { + l.logj(sorbetLog.LevelWarn, j) +} + +func (l *EchoLogger) Error(i ...interface{}) { + l.log(sorbetLog.LevelError, i...) +} + +func (l *EchoLogger) Errorf(format string, args ...interface{}) { + l.logf(sorbetLog.LevelError, format, args...) +} + +func (l *EchoLogger) Errorj(j log.JSON) { + l.logj(sorbetLog.LevelError, j) +} + +func (l *EchoLogger) Fatal(i ...interface{}) { + l.log(sorbetLog.LevelFatal, i...) + os.Exit(1) +} + +func (l *EchoLogger) Fatalj(j log.JSON) { + l.logj(sorbetLog.LevelFatal, j) + os.Exit(1) +} + +func (l *EchoLogger) Fatalf(format string, args ...interface{}) { + l.logf(sorbetLog.LevelFatal, format, args...) + os.Exit(1) +} + +func (l *EchoLogger) Panic(i ...interface{}) { + l.log(sorbetLog.LevelPanic, i...) + panic(fmt.Sprint(i...)) +} + +func (l *EchoLogger) Panicj(j log.JSON) { + l.logj(sorbetLog.LevelPanic, j) + panic(j) +} + +func (l *EchoLogger) Panicf(format string, args ...interface{}) { + l.logf(sorbetLog.LevelPanic, format, args...) + panic(fmt.Sprintf(format, args...)) +} + +func (l *EchoLogger) log(level sorbetLog.Level, args ...any) { + var attrs []any + var formats []any + for _, arg := range args { + switch arg.(type) { + case sorbetLog.Attr: + attrs = append(attrs, attrs) + default: + formats = append(formats, args) + } + } + msg := "" + if len(formats) > 0 { + msg = fmt.Sprint(formats...) + } + l.l.Log(level, msg, attrs...) +} + +func (l *EchoLogger) logf(level sorbetLog.Level, format string, args ...any) { + l.l.Log(level, format, args...) +} + +func (l *EchoLogger) logj(level sorbetLog.Level, j log.JSON) { + b, err := json.Marshal(j) + if err != nil { + panic(err) + } + l.l.Log(level, string(b)) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..34396bf --- /dev/null +++ b/main.go @@ -0,0 +1,88 @@ +package main + +import ( + "github.com/labstack/echo/v4" + "github.com/swaggo/echo-swagger" + "gorm.io/gorm" + "net/http" + _ "sorbet/docs" // 开发文档 + "sorbet/internal" + "sorbet/internal/entities" + "sorbet/internal/middleware" + "sorbet/internal/repositories" + "sorbet/internal/util" + "sorbet/pkg/env" + "sorbet/pkg/ioc" + "sorbet/pkg/rsp" +) + +// @title 博客系统 +// @version 1.0 +// @description 基于 Echo 框架的基本库 +// +// @contact.name API Support +// @contact.url http://www.swagger.io/support +// @contact.email support@swagger.io +// +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html +// +// @accept json +func main() { + if err := env.Init(); err != nil { + panic(err) + } + + if err := internal.Init(); err != nil { + panic(err) + } + + repositories.Init() + + e := echo.New() + e.HideBanner = true + e.HidePort = true + e.HTTPErrorHandler = func(err error, c echo.Context) { + if !c.Response().Committed { + http.Error(c.Response(), err.Error(), 500) + } + } + e.Logger = util.NewLogger() + e.Use(middleware.Recover()) + e.Use(middleware.CORS()) + e.Use(middleware.Logger) + e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context()) + ci := ioc.Fork() + ci.Bind(db) + c.Set("db", db) + c.Set("ioc", ci) + return next(c) + } + }) + e.GET("/swagger/*", echoSwagger.WrapHandler) + e.GET("/", func(c echo.Context) error { + repo := repositories.NewCompanyRepository(c.Get("db").(*gorm.DB)) + //err := c.Get("ioc").(*ioc.Container).Resolve(&repo) + //if err != nil { + // return err + //} + //db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context()) + //ioc.Fork().Bind(db) + //repo := ioc.MustGet[repositories.CompanyRepository]() + repo.Create(&entities.Company{Name: "海苔一诺"}) + pager, err := repo.Paginate() + if err != nil { + return err + } + return rsp.Ok(c, pager) + }) + e.Logger.Fatal(e.Start(":1323")) +} + +func panicIf(e error) { + if e != nil { + panic(e) + } +} diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go new file mode 100644 index 0000000..afeea8d --- /dev/null +++ b/pkg/bus/bus.go @@ -0,0 +1,17 @@ +package bus + +import "context" + +var DefaultEmitter = New() + +func Emit(ctx context.Context, topic string, data any) { + DefaultEmitter.Emit(ctx, topic, data) +} + +func Listen(topic string, listener Listener, options ...ListenOption) { + DefaultEmitter.Listen(topic, listener, options...) +} + +func Cancel(topic string, listeners ...Listener) { + DefaultEmitter.Cancel(topic, listeners...) +} diff --git a/pkg/bus/emitter.go b/pkg/bus/emitter.go new file mode 100644 index 0000000..f5882d8 --- /dev/null +++ b/pkg/bus/emitter.go @@ -0,0 +1,141 @@ +package bus + +import ( + "context" + "github.com/rs/xid" + "reflect" + "sync" + "time" +) + +type Emitter struct { + mu sync.RWMutex + topics map[string][]ListenerObject +} + +func New() *Emitter { + return &Emitter{ + mu: sync.RWMutex{}, + topics: make(map[string][]ListenerObject), + } +} + +func (e *Emitter) Emit(ctx context.Context, topic string, data any) { + e.mu.RLock() + listeners, ok := e.topics[topic] + e.mu.RUnlock() + + if !ok { + return + } + + stop := make(chan struct{}) + txID, _ := ctx.Value("txid-key").(string) + source, _ := ctx.Value("source-key").(string) + ctx, cancel := context.WithCancel(ctx) + + if txID == "" { + txID = xid.New().String() + } + + event := Event{ + ID: xid.New().String(), + TxID: txID, + Topic: topic, + Source: source, + OccurredAt: time.Now(), + Data: data, + stopPropagation: func() { + select { + case <-stop: + default: + close(stop) + cancel() + } + }, + } + + for _, listener := range listeners { + select { + case <-stop: + return + default: + if listener.once != nil { + listener.once.Do(func() { + if listener.async { + go listener.do(ctx, event) + } else { + listener.do(ctx, event) + } + e.Cancel(topic, listener.do) + }) + } else if listener.async { + go listener.do(ctx, event) + } else { + listener.do(ctx, event) + } + } + } +} + +func (e *Emitter) Listen(topic string, listener Listener, options ...ListenOption) { + if listener == nil { + return + } + e.mu.Lock() + defer e.mu.Unlock() + if e.topics == nil { + e.topics = make(map[string][]ListenerObject) + } + listenerObject := ListenerObject{ + async: false, + once: nil, + do: listener, + ptr: reflect.ValueOf(listener).Pointer(), + } + for _, option := range options { + option(&listenerObject) + } + if listeners, has := e.topics[topic]; !has { + e.topics[topic] = []ListenerObject{listenerObject} + } else { + e.topics[topic] = append(listeners, listenerObject) + } +} + +func (e *Emitter) Cancel(topic string, listeners ...Listener) { + e.mu.Lock() + defer e.mu.Unlock() + if e.topics == nil { + return + } + ls, has := e.topics[topic] + if !has || len(ls) == 0 { + return + } + if len(listeners) == 0 { + delete(e.topics, topic) + return + } + for _, listener := range listeners { + for i, l := range ls { + if l.ptr == 0 { + l.ptr = reflect.ValueOf(l).Pointer() + } + if l.ptr == reflect.ValueOf(listener).Pointer() { + if i == 0 { + ls = ls[1:] + } else if i == len(ls)-1 { + ls = ls[:i] + } else { + ls = append(ls[:i], ls[i+1:]...) + } + } + } + } + if len(ls) == 0 { + delete(e.topics, topic) + } else { + e.topics[topic] = ls + } +} diff --git a/pkg/bus/event.go b/pkg/bus/event.go new file mode 100644 index 0000000..7db31f3 --- /dev/null +++ b/pkg/bus/event.go @@ -0,0 +1,17 @@ +package bus + +import "time" + +type Event struct { + ID string // identifier + TxID string // transaction identifier + Topic string // topic name + Source string // source of the event + OccurredAt time.Time // creation time in nanoseconds + Data any // actual event data + stopPropagation func() +} + +func (e *Event) StopPropagation() { + e.stopPropagation() +} diff --git a/pkg/bus/listener.go b/pkg/bus/listener.go new file mode 100644 index 0000000..e69277b --- /dev/null +++ b/pkg/bus/listener.go @@ -0,0 +1,33 @@ +package bus + +import ( + "context" + "sync" +) + +type Listener func(ctx context.Context, event Event) + +type ListenerObject struct { + async bool + once *sync.Once + do Listener + ptr uintptr +} + +type ListenOption func(*ListenerObject) + +func WithAsync(async bool) ListenOption { + return func(o *ListenerObject) { + o.async = async + } +} + +func WithOnce(once bool) ListenOption { + return func(o *ListenerObject) { + if !once && o.once != nil { + o.once = nil + } else if once && o.once == nil { + o.once = new(sync.Once) + } + } +} diff --git a/pkg/cast/cast.go b/pkg/cast/cast.go new file mode 100644 index 0000000..e4e49c6 --- /dev/null +++ b/pkg/cast/cast.go @@ -0,0 +1,174 @@ +// 将字符串转换成其它基本类型 + +package cast + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +const message = "cast: cannot cast `%v` to type `%v`" + +func Uint(value string) (uint, error) { + v, err := strconv.ParseUint(value, 0, strconv.IntSize) + if err != nil { + return 0, err + } + return uint(v), nil +} + +func Uint8(value string) (uint8, error) { + v, err := strconv.ParseUint(value, 0, 8) + if err != nil { + return 0, err + } + return uint8(v), nil +} + +func Uint16(value string) (uint16, error) { + v, err := strconv.ParseUint(value, 0, 16) + if err != nil { + return 0, err + } + return uint16(v), nil +} + +func Uint32(value string) (uint32, error) { + v, err := strconv.ParseUint(value, 0, 32) + if err != nil { + return 0, err + } + return uint32(v), nil +} + +func AsUint64(value string) (uint64, error) { + v, err := strconv.ParseUint(value, 0, 64) + if err != nil { + return 0, err + } + return v, nil +} + +func Int(value string) (int, error) { + v, err := strconv.ParseInt(value, 0, strconv.IntSize) + if err != nil { + return 0, fmt.Errorf(message, value, "int") + } + return int(v), nil +} + +func Int8(value string) (int8, error) { + v, err := strconv.ParseInt(value, 0, 8) + if err != nil { + return 0, err + } + return int8(v), nil +} + +func Int16(value string) (int16, error) { + v, err := strconv.ParseInt(value, 0, 16) + if err != nil { + return 0, err + } + return int16(v), nil +} + +func Int32(value string) (int32, error) { + v, err := strconv.ParseInt(value, 0, 32) + if err != nil { + return 0, err + } + return int32(v), nil +} + +func Int64(value string) (int64, error) { + v, err := strconv.ParseInt(value, 0, 64) + if err != nil { + return 0, err + } + return v, nil +} + +func Float32(value string) (float32, error) { + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, err + } + return float32(v), nil +} + +func Float64(value string) (float64, error) { + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, err + } + return v, nil +} + +func Bool(value string) (bool, error) { + v, err := strconv.ParseBool(value) + if err != nil { + return false, err + } + return v, nil +} + +// FromType casts a string value to the given reflected type. +func FromType(value string, targetType reflect.Type) (interface{}, error) { + var typeName = targetType.String() + + if strings.HasPrefix(typeName, "[]") { + itemType := typeName[2:] + array := reflect.New(targetType).Elem() + + for _, v := range strings.Split(value, ",") { + if item, err := FromString(strings.Trim(v, " \n\r"), itemType); err != nil { + return array.Interface(), err + } else { + array = reflect.Append(array, reflect.ValueOf(item)) + } + } + + return array.Interface(), nil + } + + return FromString(value, typeName) +} + +// FromString casts a string value to the given type name. +func FromString(value string, targetType string) (any, error) { + switch targetType { + case "int": + return Int(value) + case "int8": + return Int8(value) + case "int16": + return Int16(value) + case "int32": + return Int32(value) + case "int64": + return Int64(value) + case "uint": + return Uint(value) + case "uint8": + return Uint8(value) + case "uint16": + return Uint16(value) + case "uint32": + return Uint32(value) + case "uint64": + return AsUint64(value) + case "bool": + return Bool(value) + case "float32": + return Float32(value) + case "float64": + return Float64(value) + case "string": + return value, nil + } + + return nil, fmt.Errorf("cast: type %v is not supported", targetType) +} diff --git a/pkg/db/db.go b/pkg/db/db.go new file mode 100644 index 0000000..a1d4b4d --- /dev/null +++ b/pkg/db/db.go @@ -0,0 +1,311 @@ +package db + +import ( + "database/sql" + "errors" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/plugin/optimisticlock" + "sorbet/pkg/env" + "sync" + "time" +) + +var ( + // TODO(hupeh): 使用原子性操作 atomic.Value + db *gorm.DB + lock sync.RWMutex + + ErrNoCodeFirst = errors.New("no code first") + + // 使用东八区时间 + // https://cloud.tencent.com/developer/article/1805859 + cstZone = time.FixedZone("CST", 8*3600) +) + +type Version = optimisticlock.Version +type SessionConfig = gorm.Session + +type BaseConfig struct { + TimeLocation *time.Location + NamingStrategy schema.Namer + Logger logger.Interface + Plugins map[string]gorm.Plugin + TablePrefix string + SingularTable bool + NameReplacer schema.Replacer + IdentifierMaxLength int + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration +} + +type Config struct { + BaseConfig + Driver string + StoreEngine string + DSN string +} + +// DB 获取数据库操作实例 +func DB() *gorm.DB { + lock.RLock() + if db != nil { + lock.RUnlock() + return db + } + lock.RUnlock() + + lock.Lock() + db = New() + lock.Unlock() + + return db +} + +// SetDB 自定义操作引擎 +func SetDB(engine *gorm.DB) { + lock.Lock() + defer lock.Unlock() + db = engine +} + +// New 创建数据库操作引擎,初始化参数来自环境变量 +func New() *gorm.DB { + engine, err := NewWithConfig(&Config{ + BaseConfig: BaseConfig{ + TimeLocation: cstZone, + TablePrefix: env.String("DB_PREFIX"), + SingularTable: env.Bool("DB_SINGULAR_TABLE", false), + IdentifierMaxLength: env.Int("DB_IDENTIFIER_MAX_LENGTH", 0), + Logger: &dbLogger{200 * time.Millisecond}, + MaxIdleConns: env.Int("DB_MAX_IDLE_CONNS", 0), + MaxOpenConns: env.Int("DB_MAX_OPEN_CONNS", 0), + ConnMaxLifetime: env.Duration("DB_CONN_MAX_LIFETIME", 0), + }, + Driver: env.String("DB_DRIVER", "sqlite3"), + StoreEngine: env.String("DB_STORE_ENGINE", "InnoDB"), + DSN: env.String("DB_DSN", "./app.db"), + }) + if err != nil { + panic(err) + } + return engine +} + +// NewWithConfig 通过配置创建数据库操作引擎 +func NewWithConfig(config *Config) (*gorm.DB, error) { + var dialector gorm.Dialector + switch config.Driver { + case "mysql": + dialector = mysql.Open(config.DSN) + case "pgsql": + dialector = postgres.Open(config.DSN) + case "sqlite", "sqlite3": + dialector = sqlite.Open(config.DSN) + case "sqlserver": + dialector = sqlserver.Open(config.DSN) + default: + return nil, errors.New("不支持的数据库驱动:" + config.Driver) + } + + engine, err := NewWithDialector(dialector, &config.BaseConfig) + if err != nil { + return nil, err + } + + if config.Driver == "mysql" && config.StoreEngine != "" { + engine = engine.Set("gorm:table_options", "ENGINE="+config.StoreEngine) + } + + return engine, nil +} + +// NewWithDialector 通过指定的 dialector 创建数据库操作引擎 +func NewWithDialector(dialector gorm.Dialector, config *BaseConfig) (*gorm.DB, error) { + engine, err := gorm.Open(dialector, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: config.TablePrefix, + SingularTable: config.SingularTable, + NameReplacer: config.NameReplacer, + NoLowerCase: false, + IdentifierMaxLength: config.IdentifierMaxLength, + }, + Logger: config.Logger, + NowFunc: func() time.Time { + if config.TimeLocation == nil { + return time.Now() + } + return time.Now().In(config.TimeLocation) + }, + QueryFields: false, + }) + if err != nil { + return nil, err + } + + rawDB, err := engine.DB() + if err != nil { + return nil, err + } + if config.MaxIdleConns > 0 { + rawDB.SetMaxIdleConns(config.MaxIdleConns) + } + if config.MaxOpenConns > 0 { + rawDB.SetMaxOpenConns(config.MaxOpenConns) + } + if config.ConnMaxLifetime > 0 { + rawDB.SetConnMaxLifetime(config.ConnMaxLifetime) + } + + return engine, nil +} + +// Sync 同步数据库结构,属于代码优先模式。 +// +// 在使用该方法之前需要在环境变量中开启 "DB_CODE_FIRST" 选项。 +// +// 这是非常危险的操作,必须慎之又慎,因为函数将进行如下的同步操作: +// * 自动检测和创建表,这个检测是根据表的名字 +// * 自动检测和新增表中的字段,这个检测是根据字段名,同时对表中多余的字段给出警告信息 +// * 自动检测,创建和删除索引和唯一索引,这个检测是根据索引的一个或多个字段名,而不根据索引名称。因此这里需要注意,如果在一个有大量数据的表中引入新的索引,数据库可能需要一定的时间来建立索引。 +// * 自动转换varchar字段类型到text字段类型,自动警告其它字段类型在模型和数据库之间不一致的情况。 +// * 自动警告字段的默认值,是否为空信息在模型和数据库之间不匹配的情况 +// +// 以上这些警告信息需要将日志的显示级别调整为Warn级别才会显示。 +func Sync(beans ...any) error { + if env.Bool("DB_CODE_FIRST") { + return DB().AutoMigrate(beans...) + } + return ErrNoCodeFirst +} + +// Ping ping 一下数据库连接 +func Ping() error { + raw, err := DB().DB() + if err != nil { + return err + } + return raw.Ping() +} + +// Stats 返回数据库统计信息 +func Stats() (*sql.DBStats, error) { + raw, err := DB().DB() + if err != nil { + return nil, err + } + stats := raw.Stats() + return &stats, nil +} + +// Now 这是个工具函数,返回当前时间 +func Now() time.Time { + return DB().Config.NowFunc() +} + +// Session 会话模式 +// +// 在该模式下会创建并缓存预编译语句,从而提高后续的调用速度 +func Session(config *SessionConfig) *gorm.DB { + return DB().Session(config) +} + +// Model 通过模型进行下一步操作 +func Model(value any) *gorm.DB { + return DB().Model(value) +} + +// Table 通过数据表面进行下一步操作 +func Table(name string, args ...any) *gorm.DB { + return DB().Table(name, args...) +} + +// Create 通过模型创建记录 +// +// 使用模型创建一条记录: +// +// user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()} +// ok, err := db.Create(&user) // 通过数据的指针来创建 +// +// 我们还可以使用模型切边创建多项记录: +// +// users := []*User{ +// User{Name: "Jinzhu", Age: 18, Birthday: time.Now()}, +// User{Name: "Jackson", Age: 19, Birthday: time.Now()}, +// } +// ok, err := db.Create(users) // 通过 slice 创建多条记录 +func Create(value any) (bool, error) { + result := DB().Create(value) + if err := result.Error; err != nil { + return false, err + } + return result.RowsAffected > 0, nil +} + +// Save 保存模型数据,由以下两点需要注意: +// +// - 该函数会保存所有的字段,即使字段是零值。 +// - 如果模型中的主键值是零值,将会创建该数据。 +func Save(value any) (bool, error) { + result := DB().Save(value) + if err := result.Error; err != nil { + return false, err + } + return result.RowsAffected > 0, nil +} + +func Upsert(bean any, conflict clause.OnConflict) (bool, error) { + result := DB().Clauses(conflict).Create(bean) + if err := result.Error; err != nil { + return false, err + } + return result.RowsAffected > 0, nil +} + +// Transaction 自动事务管理 +// +// 如果在 fc 中开启了新的事务,必须确保这个内嵌的事务被提交或被回滚。 +func Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) error { + return DB().Transaction(fc, opts...) +} + +// Begin 开启事务 +// +// 使用示例 +// +// tx := db.Begin() // 开始事务 +// tx.Create() // 执行一些数据库操作 +// tx.Rollback() // 遇到错误时回滚事务 +// tx.Commit() // 否则,提交事务 +// +// 事务一旦开始,就应该使用返回的 tx 对象处理数据 +func Begin(opts ...*sql.TxOptions) (tx *gorm.DB) { + return DB().Begin(opts...) +} + +// Raw 执行 SQL 查询语句 +func Raw(sql string, values ...any) *gorm.DB { + return DB().Raw(sql, values...) +} + +// Exec 执行Insert, Update, Delete 等命令的 SQL 语句, +// 如果需要查询数据请使用 Query 函数 +func Exec(sql string, values ...any) *gorm.DB { + return DB().Exec(sql, values...) +} + +func Unscoped() *gorm.DB { + return DB().Unscoped() +} + +// Migrator 返回迁移接口 +func Migrator() gorm.Migrator { + return DB().Migrator() +} diff --git a/pkg/db/delete_builder.go b/pkg/db/delete_builder.go new file mode 100644 index 0000000..881dd8e --- /dev/null +++ b/pkg/db/delete_builder.go @@ -0,0 +1,18 @@ +package db + +import "gorm.io/gorm" + +type DeleteBuilder[T any] struct { + Expr + db *gorm.DB +} + +func NewDeleteBuilder[T any](db *gorm.DB) *DeleteBuilder[T] { + return &DeleteBuilder[T]{Expr{}, db} +} + +func (b *DeleteBuilder[T]) Commit() (int64, error) { + var t T + res := b.db.Scopes(b.Scopes).Delete(&t) + return res.RowsAffected, res.Error +} diff --git a/pkg/db/expr.go b/pkg/db/expr.go new file mode 100644 index 0000000..c886bf8 --- /dev/null +++ b/pkg/db/expr.go @@ -0,0 +1,174 @@ +package db + +import ( + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Expr struct { + clauses []clause.Expression +} + +func (e *Expr) add(expr clause.Expression) *Expr { + e.clauses = append(e.clauses, expr) + return e +} + +func (e *Expr) Eq(col string, val any) *Expr { + return e.add(clause.Eq{Column: col, Value: val}) +} + +func (e *Expr) Neq(col string, val any) *Expr { + return e.add(clause.Neq{Column: col, Value: val}) +} + +func (e *Expr) Lt(col string, val any) *Expr { + return e.add(clause.Lt{Column: col, Value: val}) +} + +func (e *Expr) Lte(col string, val any) *Expr { + return e.add(clause.Lte{Column: col, Value: val}) +} + +func (e *Expr) Gt(col string, val any) *Expr { + return e.add(clause.Gt{Column: col, Value: val}) +} + +func (e *Expr) Gte(col string, val any) *Expr { + return e.add(clause.Gte{Column: col, Value: val}) +} + +func (e *Expr) Between(col string, less, more any) *Expr { + return e.add(between{col, less, more}) +} + +func (e *Expr) NotBetween(col string, less, more any) *Expr { + return e.add(clause.Not(between{col, less, more})) +} + +func (e *Expr) IsNull(col string) *Expr { + return e.add(null{col}) +} + +func (e *Expr) NotNull(col string) *Expr { + return e.add(clause.Not(null{col})) +} + +func (e *Expr) Like(col, tpl string) *Expr { + return e.add(clause.Like{Column: col, Value: tpl}) +} + +func (e *Expr) NotLike(col, tpl string) *Expr { + return e.add(clause.Not(clause.Like{Column: col, Value: tpl})) +} + +func (e *Expr) In(col string, values ...any) *Expr { + return e.add(clause.IN{Column: col, Values: values}) +} + +func (e *Expr) NotIn(col string, values ...any) *Expr { + return e.add(clause.Not(clause.IN{Column: col, Values: values})) +} + +func (e *Expr) When(condition bool, then func(ex *Expr), elses ...func(ex *Expr)) *Expr { + if condition { + then(e) + } else { + for _, els := range elses { + els(e) + } + } + return e +} + +func (e *Expr) Or(or func(ex *Expr)) *Expr { + other := &Expr{} + or(other) + if len(other.clauses) == 0 { + return e + } + if len(e.clauses) == 0 { + e.clauses = other.clauses[:] + return e + } + e.clauses = []clause.Expression{ + clause.Or( + clause.And(e.clauses...), + clause.And(other.clauses...), + ), + } + return e +} + +func (e *Expr) And(and func(ex *Expr)) *Expr { + other := &Expr{} + and(other) + if len(other.clauses) == 0 { + return e + } + if len(e.clauses) == 0 { + e.clauses = other.clauses[:] + return e + } + e.clauses = []clause.Expression{ + clause.And( + clause.And(e.clauses...), + clause.And(other.clauses...), + ), + } + return e +} + +func (e *Expr) Not(not func(ex *Expr)) *Expr { + other := &Expr{} + not(other) + if len(other.clauses) == 0 { + return e + } + return e.add(clause.Not(other.clauses...)) +} + +func (e *Expr) Scopes(tx *gorm.DB) *gorm.DB { + if e.clauses != nil { + for _, express := range e.clauses { + tx = tx.Where(express) + } + } + return tx +} + +type null struct { + Column any +} + +func (n null) Build(builder clause.Builder) { + builder.WriteQuoted(n.Column) + builder.WriteString(" IS NULL") +} + +func (n null) NegationBuild(builder clause.Builder) { + builder.WriteQuoted(n.Column) + builder.WriteString(" IS NOT NULL") +} + +type between struct { + Column any + Less any + More any +} + +func (b between) Build(builder clause.Builder) { + b.build(builder, " BETWEEN ") +} + +func (b between) NegationBuild(builder clause.Builder) { + b.build(builder, " NOT BETWEEN ") +} + +func (b between) build(builder clause.Builder, op string) { + builder.WriteQuoted(b.Column) + builder.WriteString(op) + builder.AddVar(builder, b.Less) + builder.WriteString(" And ") + builder.AddVar(builder, b.More) +} diff --git a/pkg/db/log.go b/pkg/db/log.go new file mode 100644 index 0000000..ca9dc18 --- /dev/null +++ b/pkg/db/log.go @@ -0,0 +1,57 @@ +package db + +// +//import ( +// "context" +// "gorm.io/gorm/dbLogger" +// "io" +// "log" +// "os" +// "time" +//) +// +//type dbLogger struct { +// console dbLogger.Interface +// persist dbLogger.Interface +//} +// +//func NewLogger(persistWriter io.Writer) dbLogger.Interface { +// return &dbLogger{ +// console: dbLogger.New(log.New(os.Stdout, "", log.Ltime|log.Lmicroseconds), dbLogger.Config{ +// SlowThreshold: 200 * time.Millisecond, +// Colorful: true, +// LogLevel: dbLogger.Info, +// }), +// persist: dbLogger.New(log.New(persistWriter, "\r\n", log.LstdFlags), dbLogger.Config{ +// SlowThreshold: 200 * time.Millisecond, +// LogLevel: dbLogger.Info, +// }), +// } +//} +// +//func (l *dbLogger) LogMode(level dbLogger.LogLevel) dbLogger.Interface { +// c := *l +// c.console = c.console.LogMode(level) +// c.persist = c.persist.LogMode(level) +// return &c +//} +// +//func (l *dbLogger) Info(ctx context.Context, s string, i ...interface{}) { +// l.console.Info(ctx, s, i...) +// l.persist.Info(ctx, s, i...) +//} +// +//func (l *dbLogger) Warn(ctx context.Context, s string, i ...interface{}) { +// l.console.Warn(ctx, s, i...) +// l.persist.Warn(ctx, s, i...) +//} +// +//func (l *dbLogger) Error(ctx context.Context, s string, i ...interface{}) { +// l.console.Error(ctx, s, i...) +// l.persist.Error(ctx, s, i...) +//} +// +//func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { +// l.console.Trace(ctx, begin, fc, err) +// l.persist.Trace(ctx, begin, fc, err) +//} diff --git a/pkg/db/logger.go b/pkg/db/logger.go new file mode 100644 index 0000000..6d31b8b --- /dev/null +++ b/pkg/db/logger.go @@ -0,0 +1,67 @@ +package db + +import ( + "context" + "errors" + "fmt" + glog "gorm.io/gorm/logger" + "sorbet/pkg/log" + "time" +) + +type dbLogger struct { + SlowThreshold time.Duration +} + +// LogMode log mode +func (l *dbLogger) LogMode(level glog.LogLevel) glog.Interface { + return l +} + +// Info print info +func (l dbLogger) Info(ctx context.Context, msg string, data ...any) { + log.Info(msg, data...) +} + +// Warn print warn messages +func (l dbLogger) Warn(ctx context.Context, msg string, data ...any) { + log.Warn(msg, data...) +} + +// Error print error messages +func (l dbLogger) Error(ctx context.Context, msg string, data ...any) { + log.Error(msg, data...) +} + +// Trace print sql message +func (l dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + elapsed := time.Since(begin) + switch { + case err != nil && !errors.Is(err, glog.ErrRecordNotFound): + sql, rows := fc() + if rows == -1 { + log.Error("%s [rows:%v] %s [%.3fms]", err, "-", sql, float64(elapsed.Nanoseconds())/1e6) + } else { + log.Error("%s [rows:%v] %s [%.3fms]", err, rows, sql, float64(elapsed.Nanoseconds())/1e6) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + log.Warn("%s [rows:%v] %s [%.3fms]", slowLog, "-", sql, float64(elapsed.Nanoseconds())/1e6) + } else { + log.Warn("%s [rows:%v] %s [%.3fms]", slowLog, rows, sql, float64(elapsed.Nanoseconds())/1e6) + } + default: + sql, rows := fc() + if rows == -1 { + log.Trace("[rows:%v] %s [%.3fms]", "-", sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) + } else { + log.Trace("[rows:%v] %s [%.3fms]", rows, sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) + } + } +} + +func (l dbLogger) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) { + return sql, params +} diff --git a/pkg/db/pager.go b/pkg/db/pager.go new file mode 100644 index 0000000..3a49c63 --- /dev/null +++ b/pkg/db/pager.go @@ -0,0 +1 @@ +package db diff --git a/pkg/db/query_builder.go b/pkg/db/query_builder.go new file mode 100644 index 0000000..a1380d3 --- /dev/null +++ b/pkg/db/query_builder.go @@ -0,0 +1,191 @@ +package db + +import ( + "database/sql" + "gorm.io/gorm" + "math" +) + +// QueryBuilder 查询构造器 +// TODO(hupeh):实现 joins 和表别名 +type QueryBuilder[T any] struct { + Expr + db *gorm.DB + selects []string + omits []string + orders []string + limit int + offset int + distinct []any + preloads []preload +} + +func NewQueryBuilder[T any](db *gorm.DB) *QueryBuilder[T] { + return &QueryBuilder[T]{Expr: Expr{}, db: db} +} + +type preload struct { + query string + args []any +} + +type Pager[T any] struct { + Total int `json:"total" xml:"total"` // 数据总数 + Page int `json:"page" xml:"page"` // 当前页码 + Limit int `json:"limit" xml:"limit"` // 数据容量 + Items []*T `json:"items" xml:"items"` // 数据列表 +} + +func (q *QueryBuilder[T]) Select(columns ...string) *QueryBuilder[T] { + q.selects = append(q.selects, columns...) + return q +} + +func (q *QueryBuilder[T]) Omit(columns ...string) *QueryBuilder[T] { + q.omits = append(q.omits, columns...) + return q +} + +func (q *QueryBuilder[T]) DescentBy(columns ...string) *QueryBuilder[T] { + for _, col := range columns { + q.orders = append(q.orders, col+" DESC") + } + return q +} + +func (q *QueryBuilder[T]) AscentBy(columns ...string) *QueryBuilder[T] { + for _, col := range columns { + q.orders = append(q.orders, col) + } + return q +} + +func (q *QueryBuilder[T]) Limit(limit int) *QueryBuilder[T] { + q.limit = limit + return q +} + +func (q *QueryBuilder[T]) Offset(offset int) *QueryBuilder[T] { + q.offset = offset + return q +} + +func (q *QueryBuilder[T]) Distinct(columns ...any) *QueryBuilder[T] { + q.distinct = append(q.distinct, columns...) + return q +} + +func (q *QueryBuilder[T]) Preload(query string, args ...any) *QueryBuilder[T] { + q.preloads = append(q.preloads, preload{query, args}) + return q +} + +func (q *QueryBuilder[T]) Scopes(tx *gorm.DB) *gorm.DB { + tx = q.scopesWithoutEffect(tx) + if q.orders != nil { + for _, order := range q.orders { + tx = tx.Order(order) + } + } + if q.limit > 0 { + tx = tx.Limit(q.limit) + } + if q.offset > 0 { + tx = tx.Offset(q.offset) + } + if q.preloads != nil { + for _, pl := range q.preloads { + tx = tx.Preload(pl.query, pl.args...) + } + } + return tx +} + +func (q *QueryBuilder[T]) scopesWithoutEffect(tx *gorm.DB) *gorm.DB { + var entity T + tx = tx.Model(&entity) + if q.selects != nil { + tx = tx.Select(q.selects) + } + if q.omits != nil { + tx = tx.Omit(q.omits...) + } + if len(q.distinct) > 0 { + tx = tx.Distinct(q.distinct...) + } + return q.Expr.Scopes(tx) +} + +func (q *QueryBuilder[T]) Count() (int64, error) { + var count int64 + err := q.db.Scopes(q.scopesWithoutEffect).Count(&count).Error + return count, err +} + +func (q *QueryBuilder[T]) First(entity any) error { + return q.db.Scopes(q.Scopes).First(entity).Error +} + +func (q *QueryBuilder[T]) Take(entity any) error { + return q.db.Scopes(q.Scopes).Take(entity).Error +} + +func (q *QueryBuilder[T]) Last(entity any) error { + return q.db.Scopes(q.Scopes).Last(entity).Error +} + +func (q *QueryBuilder[T]) Find(entity any) error { + return q.db.Scopes(q.Scopes).Find(entity).Error +} + +func (q *QueryBuilder[T]) Paginate() (*Pager[T], error) { + if q.limit <= 0 { + q.limit = 30 + } + if q.offset < 0 { + q.offset = 0 + } + count, err := q.Count() + if err != nil { + return nil, err + } + var items []*T + err = q.Find(&items) + if err != nil { + return nil, err + } + return &Pager[T]{ + Total: int(count), + Page: int(math.Ceil(float64(q.offset)/float64(q.limit))) + 1, + Limit: q.limit, + Items: items, + }, nil +} + +// Rows 返回行数据迭代器 +// +// 使用示例: +// +// rows, err := q.Eq("name", "jack").Rows() +// if err != nil { +// panic(err) +// } +// defer rows.Close() +// for rows.Next() { +// var user User +// db.ScanRows(rows, &user) +// // do something +// } +func (q *QueryBuilder[T]) Rows() (*sql.Rows, error) { + return q.db.Scopes(q.Scopes).Rows() +} + +// Pluck 获取指定列的值 +// +// 示例: +// +// var names []string +// q.Pluck("name", &names) +func (q *QueryBuilder[T]) Pluck(column string, dest any) error { + return q.db.Scopes(q.Scopes).Pluck(column, dest).Error +} diff --git a/pkg/db/repo.go b/pkg/db/repo.go new file mode 100644 index 0000000..3a49c63 --- /dev/null +++ b/pkg/db/repo.go @@ -0,0 +1 @@ +package db diff --git a/pkg/db/repository.go b/pkg/db/repository.go new file mode 100644 index 0000000..458dda2 --- /dev/null +++ b/pkg/db/repository.go @@ -0,0 +1,98 @@ +package db + +import ( + "gorm.io/gorm" +) + +type Repository[T any] struct { + db *gorm.DB + pk string // 默认 id +} + +func NewRepository[T any](db ...*gorm.DB) *Repository[T] { + for _, d := range db { + return NewRepositoryWith[T](d) + } + return NewRepositoryWith[T](DB()) +} + +func NewRepositoryWith[T any](db *gorm.DB, pk ...string) *Repository[T] { + r := &Repository[T]{db: db, pk: "id"} + for _, s := range pk { + if s != "" { + r.pk = s + return r + } + } + return r +} + +// Create 创建数据 +func (r *Repository[T]) Create(entity *T) error { + return r.db.Model(&entity).Create(&entity).Error +} + +func (r *Repository[T]) Delete(expr *Expr) (int64, error) { + var entity T + res := r.db.Model(&entity).Scopes(expr.Scopes).Delete(&entity) + return res.RowsAffected, res.Error +} + +func (r *Repository[T]) DeleteByID(id any) error { + var entity T + return r.db.Delete(&entity, r.pk, id).Error +} + +func (r *Repository[T]) Update(expr *Expr, values map[string]any) (int64, error) { + res := r.db.Scopes(expr.Scopes).Updates(values) + return res.RowsAffected, res.Error +} + +func (r *Repository[T]) UpdateByID(id any, values map[string]any) error { + var entity T + return r.db.Model(&entity).Where(r.pk, id).Updates(values).Error +} + +func (r *Repository[T]) GetByID(id any) (*T, error) { + var entity T + err := r.db.Model(&entity).Where(r.pk, id).First(&entity).Error + if err != nil { + return nil, err + } + return &entity, nil +} + +func (r *Repository[T]) Find(expr ...*Expr) ([]*T, error) { + var entity T + var items []*T + err := r.db.Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB { + for _, e := range expr { + tx = e.Scopes(tx) + } + return tx + }).Find(&items).Error + if err != nil { + return nil, err + } + return items, nil +} + +func (r *Repository[T]) Paginate(expr ...*Expr) (*Pager[T], error) { + qb := NewQueryBuilder[T](r.db) + for _, e := range expr { + qb.Expr = *e + } + return qb.Paginate() +} + +func (r *Repository[T]) NewDeleteBuilder() *DeleteBuilder[T] { + return NewDeleteBuilder[T](r.db) +} + +func (r *Repository[T]) NewUpdateBuilder() *UpdateBuilder[T] { + return NewUpdateBuilder[T](r.db) +} + +func (r *Repository[T]) NewQueryBuilder() *QueryBuilder[T] { + return NewQueryBuilder[T](r.db) +} diff --git a/pkg/db/types.go b/pkg/db/types.go new file mode 100644 index 0000000..37cb90d --- /dev/null +++ b/pkg/db/types.go @@ -0,0 +1,67 @@ +package db + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "strconv" + "time" +) + +type NullTime struct { + sql.NullTime +} + +func (v NullTime) MarshalJSON() ([]byte, error) { + if v.Valid { + return json.Marshal(v.Time) + } else { + return json.Marshal(nil) + } +} + +func (v *NullTime) UnmarshalJSON(data []byte) error { + var s *time.Time + if err := json.Unmarshal(data, &s); err != nil { + return err + } + if s != nil { + v.Valid = true + v.Time = *s + } else { + v.Valid = false + } + return nil +} + +type NullInt64 sql.NullInt64 + +func (v *NullInt64) Scan(value interface{}) error { + return (*sql.NullInt64)(v).Scan(value) +} + +func (v NullInt64) Value() (driver.Value, error) { + if !v.Valid { + return nil, nil + } + return v.Int64, nil +} + +func (v *NullInt64) UnmarshalJSON(bytes []byte) error { + if string(bytes) == "null" { + v.Valid = false + return nil + } + err := json.Unmarshal(bytes, &v.Int64) + if err == nil { + v.Valid = true + } + return err +} + +func (v NullInt64) MarshalJSON() ([]byte, error) { + if v.Valid { + return strconv.AppendInt(nil, v.Int64, 10), nil + } + return []byte("null"), nil +} diff --git a/pkg/db/update_builder.go b/pkg/db/update_builder.go new file mode 100644 index 0000000..9d12364 --- /dev/null +++ b/pkg/db/update_builder.go @@ -0,0 +1,66 @@ +package db + +import ( + "errors" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type UpdateBuilder[T any] struct { + Expr + db *gorm.DB + selects []string + omits []string + onConflict *clause.OnConflict +} + +func NewUpdateBuilder[T any](db *gorm.DB) *UpdateBuilder[T] { + return &UpdateBuilder[T]{Expr: Expr{}, db: db} +} + +func (b *UpdateBuilder[T]) Select(columns ...string) *UpdateBuilder[T] { + b.selects = append(b.selects, columns...) + return b +} + +func (b *UpdateBuilder[T]) Omit(columns ...string) *UpdateBuilder[T] { + b.omits = append(b.omits, columns...) + return b +} + +func (b *UpdateBuilder[T]) OnConflict(conflict clause.OnConflict) *UpdateBuilder[T] { + if b.onConflict == nil { + b.onConflict = &conflict + } else { + b.onConflict.Columns = conflict.Columns + b.onConflict.Where = conflict.Where + b.onConflict.TargetWhere = conflict.TargetWhere + b.onConflict.OnConstraint = conflict.OnConstraint + b.onConflict.DoNothing = conflict.DoNothing + b.onConflict.DoUpdates = conflict.DoUpdates + b.onConflict.UpdateAll = conflict.UpdateAll + } + return b +} + +func (b *UpdateBuilder[T]) Scopes(tx *gorm.DB) *gorm.DB { + if b.selects != nil { + tx = tx.Select(b.selects) + } + if b.omits != nil { + tx = tx.Omit(b.omits...) + } + return b.Expr.Scopes(tx) +} + +func (b *UpdateBuilder[T]) Commit(values map[string]any) (int64, error) { + var entity T + res := b.db.Model(&entity).Scopes(b.Scopes).Updates(values) + if err := res.Error; err != nil { + return res.RowsAffected, err + } + if res.RowsAffected == 0 { + return 0, errors.New("no record updated") + } + return res.RowsAffected, nil +} diff --git a/pkg/env/env.go b/pkg/env/env.go new file mode 100644 index 0000000..5f43704 --- /dev/null +++ b/pkg/env/env.go @@ -0,0 +1,153 @@ +package env + +import ( + "errors" + "os" + "path/filepath" + "strings" + "time" +) + +// 缓存的环境变量 +var env = New() + +// Init 加载运行目录下的 .env 文件 +func Init() error { + if path, err := filepath.Abs("."); err != nil { + return err + } else { + return InitWithDir(path) + } +} + +// InitWithDir 加载指定录下的 .env 文件 +func InitWithDir(dir string) error { + // 重置缓存的环境变量 + clear(env) + + // 加载系统的环境变量 + for _, value := range os.Environ() { + parts := strings.SplitN(value, "=", 2) + key := strings.TrimSpace(parts[0]) + val := strings.TrimSpace(parts[1]) + env[key] = val + } + + // 将 windows 系统风格的目录分隔符替换为 Unix 风格, + // 并且排除尾部的目录分隔符。 + dir = strings.TrimSuffix(strings.ReplaceAll(dir, "\\", "/"), "/") + + // 加载 .env 文件 + if err := loadEnv(dir, ""); err != nil { + return err + } + + // 加载 .env.{APP_ENV} 文件 + if appEnv := String("APP_ENV", "prod"); len(appEnv) > 0 { + if err := loadEnv(dir, "."+strings.ToLower(appEnv)); err != nil { + return err + } + } + + return nil +} + +func loadEnv(dir, env string) error { + if err := Load(dir + "/.env" + env); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return err + } + } + + if err := Load(dir + "/.env" + env + ".local"); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return err + } + } + + return nil +} + +// Load 加载指定的环境变量文件 +func Load(filenames ...string) error { + return env.Load(filenames...) +} + +// Is 判断是不是期望的值 +func Is(name string, value any) bool { + switch value.(type) { + case bool: + return Bool(name) == value + case int: + return Int(name) == value + case time.Duration: + return Duration(name) == value + default: + return String(name) == value + } +} + +// IsEnv 判断应用环境是否与给出的一致 +func IsEnv(env string) bool { + return Is("APP_ENV", env) +} + +// Lookup 查看配置 +func Lookup(name string) (string, bool) { + return env.Lookup(name) +} + +// Exists 配置是否存在 +func Exists(name string) bool { + return env.Exists(name) +} + +// String 取字符串值 +func String(name string, value ...string) string { + return env.String(name, value...) +} + +// Bytes 取二进制值 +func Bytes(name string, value ...[]byte) []byte { + return env.Bytes(name, value...) +} + +// Int 取整型值 +func Int(name string, value ...int) int { + return env.Int(name, value...) +} + +func Duration(name string, value ...time.Duration) time.Duration { + return env.Duration(name, value...) +} + +func Bool(name string, value ...bool) bool { + return env.Bool(name, value...) +} + +// List 将值按 `,` 分割并返回 +func List(name string, fallback ...[]string) []string { + return env.List(name, fallback...) +} + +func Map(prefix string) map[string]string { + return env.Map(prefix) +} + +func Where(filter func(name string, value string) bool) map[string]string { + return env.Where(filter) +} + +// Fill 将环境变量填充到指定结构体 +func Fill(structure any) error { + return env.Fill(structure) +} + +// All 返回所有值 +func All() map[string]string { + clone := make(map[string]string) + for key, value := range env { + clone[key] = value + } + return clone +} diff --git a/pkg/env/environ.go b/pkg/env/environ.go new file mode 100644 index 0000000..2d852ca --- /dev/null +++ b/pkg/env/environ.go @@ -0,0 +1,186 @@ +package env + +import ( + "errors" + "fmt" + "github.com/joho/godotenv" + "reflect" + "sorbet/pkg/cast" + "strconv" + "strings" + "time" + "unsafe" +) + +type Environ map[string]string + +func New() Environ { + return make(Environ) +} + +// Load 加载环境变量文件 +func (e Environ) Load(filenames ...string) error { + data, err := godotenv.Read(filenames...) + if err != nil { + return err + } + for key, value := range data { + e[key] = value + } + return nil +} + +// Lookup 查看环境变量值,如果不存在或值为空,返回的第二个参数的值则为false。 +func (e Environ) Lookup(key string) (string, bool) { + val, exists := e[key] + if exists && len(val) == 0 { + exists = false + } + return val, exists +} + +// Exists 判断环境变量是否存在 +func (e Environ) Exists(key string) bool { + _, exists := e[key] + return exists +} + +// String 取字符串值 +func (e Environ) String(key string, fallback ...string) string { + if value, exists := e.Lookup(key); exists { + return value + } + for _, value := range fallback { + return value + } + return "" +} + +// Bytes 取二进制值 +func (e Environ) Bytes(key string, fallback ...[]byte) []byte { + if value, exists := e.Lookup(key); exists { + return []byte(value) + } + for _, bytes := range fallback { + return bytes + } + return []byte{} +} + +// Int 取整型值 +func (e Environ) Int(key string, fallback ...int) int { + if val, exists := e.Lookup(key); exists { + if n, err := strconv.Atoi(val); err == nil { + return n + } + } + for _, value := range fallback { + return value + } + return 0 +} + +func (e Environ) Duration(key string, fallback ...time.Duration) time.Duration { + if val, ok := e.Lookup(key); ok { + n, err := strconv.Atoi(val) + if err == nil { + return time.Duration(n) + } + if d, err := time.ParseDuration(val); err == nil { + return d + } + } + for _, value := range fallback { + return value + } + return time.Duration(0) +} + +func (e Environ) Bool(key string, fallback ...bool) bool { + if val, ok := e.Lookup(key); ok { + bl, err := strconv.ParseBool(val) + if err == nil { + return bl + } + } + for _, value := range fallback { + return value + } + return false +} + +// List 将值按 `,` 分割并返回 +func (e Environ) List(key string, fallback ...[]string) []string { + if value, ok := e.Lookup(key); ok { + parts := strings.Split(value, ",") + for i, part := range parts { + parts[i] = strings.Trim(part, " \n\r") + } + return parts + } + for _, value := range fallback { + return value + } + return []string{} +} + +// Map 获取指定前缀的所有值 +func (e Environ) Map(prefix string) map[string]string { + result := map[string]string{} + for k, v := range e { + if strings.HasPrefix(k, prefix) { + name := strings.TrimPrefix(k, prefix) + result[name] = strings.TrimSpace(v) + } + } + return result +} + +// Where 获取符合过滤器的所有值 +func (e Environ) Where(filter func(name, value string) bool) map[string]string { + result := map[string]string{} + for k, v := range e { + if filter(k, v) { + result[k] = v + } + } + return result +} + +// Fill 将环境变量填充到指定结构体 +func (e Environ) Fill(structure any) error { + inputType := reflect.TypeOf(structure) + + if inputType != nil && inputType.Kind() == reflect.Ptr && inputType.Elem().Kind() == reflect.Struct { + return e.fillStruct(reflect.ValueOf(structure).Elem()) + } + + return errors.New("env: invalid structure") +} + +func (e Environ) fillStruct(s reflect.Value) error { + for i := 0; i < s.NumField(); i++ { + if t, exist := s.Type().Field(i).Tag.Lookup("env"); exist { + if osv := e[t]; osv != "" { + v, err := cast.FromType(osv, s.Type().Field(i).Type) + if err != nil { + return fmt.Errorf("env: cannot set `%v` field; err: %v", s.Type().Field(i).Name, err) + } + ptr := reflect.NewAt(s.Field(i).Type(), unsafe.Pointer(s.Field(i).UnsafeAddr())).Elem() + ptr.Set(reflect.ValueOf(v)) + } + } else if s.Type().Field(i).Type.Kind() == reflect.Struct { + if err := e.fillStruct(s.Field(i)); err != nil { + return err + } + } else if s.Type().Field(i).Type.Kind() == reflect.Ptr { + if s.Field(i).IsZero() == false && s.Field(i).Elem().Type().Kind() == reflect.Struct { + if err := e.fillStruct(s.Field(i).Elem()); err != nil { + return err + } + } + } + } + + return nil +} diff --git a/pkg/ioc/container.go b/pkg/ioc/container.go new file mode 100644 index 0000000..dc0128c --- /dev/null +++ b/pkg/ioc/container.go @@ -0,0 +1,228 @@ +package ioc + +import ( + "errors" + "fmt" + "reflect" +) + +type binding struct { + name string + typ reflect.Type + resolver any + shared bool +} + +func (b *binding) make(c *Container) (reflect.Value, error) { + if v, exists := c.instances[b.typ][b.name]; exists { + return v, nil + } + val, err := c.Invoke(b.resolver) + if err != nil { + return reflect.Value{}, err + } + rv := val[0] + if len(val) == 2 { + err = val[1].Interface().(error) + if err != nil { + return reflect.Value{}, err + } + } + if b.shared { + if _, exists := c.instances[b.typ]; !exists { + c.instances[b.typ] = make(map[string]reflect.Value) + } + c.instances[b.typ][b.name] = rv + } + return rv, nil +} + +func (b *binding) mustMake(c *Container) reflect.Value { + val, err := b.make(c) + if err != nil { + panic(err) + } + return val +} + +type Container struct { + // 注册的工厂函数 + factories map[reflect.Type]map[string]*binding + // 注册的共享实例 + instances map[reflect.Type]map[string]reflect.Value + parent *Container +} + +func New() *Container { + return &Container{ + factories: make(map[reflect.Type]map[string]*binding), + instances: make(map[reflect.Type]map[string]reflect.Value), + parent: nil, + } +} + +// Fork 分支 +func (c *Container) Fork() *Container { + ioc := New() + ioc.parent = c + return ioc +} + +// Bind 绑定值到容器,有效类型: +// - 接口的具体实现值 +// - 结构体的实例 +// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体) +func (c *Container) Bind(instance any) { + c.NamedBind("", instance) +} + +// NamedBind 绑定具名值到容器 +func (c *Container) NamedBind(name string, instance any) { + //typ := InterfaceOf(instance) + typ := reflect.TypeOf(instance) + if _, ok := c.instances[typ]; !ok { + c.instances[typ] = make(map[string]reflect.Value) + } + c.instances[typ][name] = reflect.ValueOf(instance) +} + +// Factory 绑定工厂函数 +func (c *Container) Factory(factory any, shared ...bool) error { + return c.NamedFactory("", factory, shared...) +} + +// NamedFactory 绑定具名工厂函数 +func (c *Container) NamedFactory(name string, factory any, shared ...bool) error { + reflectedFactory := reflect.TypeOf(factory) + if reflectedFactory.Kind() != reflect.Func { + return errors.New("container: the factory must be a function") + } + if returnCount := reflectedFactory.NumOut(); returnCount == 0 || returnCount > 2 { + return errors.New("container: factory function signature is invalid - it must return abstract, or abstract and error") + } + // TODO(hupeh): 验证第二个参数必须是 error 接口 + concreteType := reflectedFactory.Out(0) + for i := 0; i < reflectedFactory.NumIn(); i++ { + // 循环依赖 + if reflectedFactory.In(i) == concreteType { + return fmt.Errorf("container: factory function signature is invalid - depends on abstract it returns") + } + } + if _, exists := c.factories[concreteType]; !exists { + c.factories[concreteType] = make(map[string]*binding) + } + bd := &binding{ + name: name, + typ: concreteType, + resolver: factory, + shared: false, + } + for _, b := range shared { + bd.shared = b + } + c.factories[concreteType][name] = bd + return nil +} + +// Resolve 完成的注入 +func (c *Container) Resolve(i any) error { + v := reflect.ValueOf(i) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return errors.New("must given a struct") + } + t := v.Type() + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + structField := t.Field(i) + inject, willInject := structField.Tag.Lookup("inject") + if !f.CanSet() { + if willInject { + return fmt.Errorf("container: cannot make %v field", t.Field(i).Name) + } + continue + } + ft := f.Type() + fv := c.NamedGet(inject, ft) + if !fv.IsValid() { + return fmt.Errorf("value not found for type %v", ft) + } + f.Set(fv) + } + return nil +} + +// Get 获取指定类型的值 +func (c *Container) Get(t reflect.Type) reflect.Value { + return c.NamedGet("", t) +} + +// NamedGet 通过注入的名称获取指定类型的值 +func (c *Container) NamedGet(name string, t reflect.Type) reflect.Value { + val, exists := c.instances[t][name] + + if exists && val.IsValid() { + return val + } + + if factory, exists := c.factories[t][name]; exists { + val = factory.mustMake(c) + } + + if val.IsValid() || t.Kind() != reflect.Interface { + goto RESULT + } + + // 使用共享值里面该接口的实现者 + for k, v := range c.instances { + if k.Implements(t) { + for _, value := range v { + if value.IsValid() { + val = value + goto RESULT + } + } + break + } + } + + // 使用工厂函数里面该接口的实现者 + for k, v := range c.factories { + if k.Implements(t) { + for _, bd := range v { + if x := bd.mustMake(c); x.IsValid() { + val = x + goto RESULT + } + } + break + } + } + +RESULT: + if !val.IsValid() && c.parent != nil { + val = c.parent.NamedGet(name, t) + } + return val +} + +// Invoke 执行函数 +func (c *Container) Invoke(f any) ([]reflect.Value, error) { + t := reflect.TypeOf(f) + if t.Kind() != reflect.Func { + return nil, errors.New("container: invalid function") + } + + var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func + for i := 0; i < t.NumIn(); i++ { + argType := t.In(i) + val := c.Get(argType) + if !val.IsValid() { + return nil, fmt.Errorf("value not found for type %v", argType) + } + in[i] = val + } + return reflect.ValueOf(f).Call(in), nil +} diff --git a/pkg/ioc/ioc.go b/pkg/ioc/ioc.go new file mode 100644 index 0000000..9320821 --- /dev/null +++ b/pkg/ioc/ioc.go @@ -0,0 +1,97 @@ +package ioc + +import ( + "context" + "errors" + "reflect" +) + +var ( + ErrValueNotFound = errors.New("ioc: value not found") +) + +var global = New() + +// Fork 分支 +func Fork() *Container { + return global.Fork() +} + +// Bind 绑定值到容器,有效类型: +// +// - 接口的具体实现值 +// - 结构体的实例 +// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体) +func Bind(instance any) { + global.Bind(instance) +} + +// NamedBind 绑定具名值到容器 +func NamedBind(name string, instance any) { + global.NamedBind(name, instance) +} + +// Factory 绑定工厂函数 +func Factory(factory any, shared ...bool) error { + return global.Factory(factory, shared...) +} + +func MustFactory(factory any, shared ...bool) { + err := Factory(factory, shared...) + if err != nil { + panic(err) + } +} + +// NamedFactory 绑定具名工厂函数 +func NamedFactory(name string, factory any, shared ...bool) error { + return global.NamedFactory(name, factory, shared...) +} + +func MustNamedFactory(name string, factory any, shared ...bool) { + err := NamedFactory(name, factory, shared...) + if err != nil { + panic(err) + } +} + +// Resolve 完成的注入 +func Resolve(i any) error { + return global.Resolve(i) +} + +// Get 获取指定类型的值 +func Get[T any](ctx context.Context) (*T, error) { + return NamedGet[T](ctx, "") +} + +func MustGet[T any](ctx context.Context) *T { + return MustNamedGet[T](ctx, "") +} + +// NamedGet 通过注入的名称获取指定类型的值 +func NamedGet[T any](ctx context.Context, name string) (*T, error) { + var abs T + t := reflect.TypeOf(&abs) + v := global.NamedGet(name, t) + if !v.IsValid() { + return nil, ErrValueNotFound + } + if x, ok := v.Interface().(*T); ok { + return x, nil + } + return nil, ErrValueNotFound +} + +func MustNamedGet[T any](ctx context.Context, name string) *T { + v, err := NamedGet[T](ctx, name) + if err != nil { + panic(err) + } + return v +} + +// Invoke 执行函数 +func Invoke(f any) ([]reflect.Value, error) { + return global.Invoke(f) +} diff --git a/pkg/ioc/util.go b/pkg/ioc/util.go new file mode 100644 index 0000000..28fb498 --- /dev/null +++ b/pkg/ioc/util.go @@ -0,0 +1,19 @@ +package ioc + +import "reflect" + +// InterfaceOf dereferences a pointer to an Interface type. +// It panics if value is not a pointer to an interface. +func InterfaceOf(value interface{}) reflect.Type { + t := reflect.TypeOf(value) + + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Interface { + panic("the value is not a pointer to an interface. (*MyInterface)(nil)") + } + + return t +} diff --git a/pkg/log/attr.go b/pkg/log/attr.go new file mode 100644 index 0000000..5ea492d --- /dev/null +++ b/pkg/log/attr.go @@ -0,0 +1,80 @@ +package log + +import ( + "log/slog" + "time" +) + +const ( + rawTimeKey = "@rawTime" + rawLevelKey = "@rawLevel" +) + +type Attr = slog.Attr + +func RawTime(value time.Time) Attr { + return Time(rawTimeKey, value) +} + +func RawLevel(value string) Attr { + return String(rawLevelKey, value) +} + +// String returns an Attr for a string value. +func String(key, value string) Attr { + return slog.String(key, value) +} + +// Int64 returns an Attr for an int64. +func Int64(key string, value int64) Attr { + return slog.Int64(key, value) +} + +// Int converts an int to an int64 and returns +// an Attr with that value. +func Int(key string, value int) Attr { + return slog.Int(key, value) +} + +// Uint64 returns an Attr for a uint64. +func Uint64(key string, v uint64) Attr { + return slog.Uint64(key, v) +} + +// Float64 returns an Attr for a floating-point number. +func Float64(key string, v float64) Attr { + return slog.Float64(key, v) +} + +// Bool returns an Attr for a bool. +func Bool(key string, v bool) Attr { + return slog.Bool(key, v) +} + +// Time returns an Attr for a time.Time. +// It discards the monotonic portion. +func Time(key string, v time.Time) Attr { + return slog.Time(key, v) +} + +// Duration returns an Attr for a time.Duration. +func Duration(key string, v time.Duration) Attr { + return slog.Duration(key, v) +} + +// Group returns an Attr for a Group Instance. +// The first argument is the key; the remaining arguments +// are converted to Attrs as in [Logger.Log]. +// +// Use Group to collect several key-value pairs under a single +// key on a log line, or as the result of LogValue +// in order to log a single value as multiple Attrs. +func Group(key string, args ...any) Attr { + return slog.Group(key, args...) +} + +// Any returns an Attr for the supplied value. +// See [AnyValue] for how values are treated. +func Any(key string, value any) Attr { + return slog.Any(key, value) +} diff --git a/pkg/log/colorer.go b/pkg/log/colorer.go new file mode 100644 index 0000000..d2122ce --- /dev/null +++ b/pkg/log/colorer.go @@ -0,0 +1,71 @@ +package log + +import ( + "github.com/mattn/go-isatty" + "os" +) + +type Colorer interface { + Black(s string) string + Red(s string) string + Green(s string) string + Yellow(s string) string + Blue(s string) string + Magenta(s string) string + Cyan(s string) string + White(s string) string + Grey(s string) string +} + +func NewColorer() Colorer { + return &colorer{} +} + +type clr string + +const ( + black clr = "\x1b[30m" + red clr = "\x1b[31m" + green clr = "\x1b[32m" + yellow clr = "\x1b[33m" + blue clr = "\x1b[34m" + magenta clr = "\x1b[35m" + cyan clr = "\x1b[36m" + white clr = "\x1b[37m" + grey clr = "\x1b[90m" +) + +var ( + // NoColor defines if the output is colorized or not. It's dynamically set to + // false or true based on the stdout's file descriptor referring to a terminal + // or not. It's also set to true if the NO_COLOR environment variable is + // set (regardless of its value). This is a global option and affects all + // colors. For more control over each Color block use the methods + // DisableColor() individually. + noColor = noColorIsSet() || os.Getenv("TERM") == "dumb" || + (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) +) + +// noColorIsSet returns true if the environment variable NO_COLOR is set to a non-empty string. +func noColorIsSet() bool { + return os.Getenv("NO_COLOR") != "" +} + +func (c clr) f(s string) string { + if noColorIsSet() || noColor { + return s + } + return string(c) + s + "\x1b[0m" +} + +type colorer struct{} + +func (c *colorer) Black(s string) string { return black.f(s) } +func (c *colorer) Red(s string) string { return red.f(s) } +func (c *colorer) Green(s string) string { return green.f(s) } +func (c *colorer) Yellow(s string) string { return yellow.f(s) } +func (c *colorer) Blue(s string) string { return blue.f(s) } +func (c *colorer) Magenta(s string) string { return magenta.f(s) } +func (c *colorer) Cyan(s string) string { return cyan.f(s) } +func (c *colorer) White(s string) string { return white.f(s) } +func (c *colorer) Grey(s string) string { return grey.f(s) } diff --git a/pkg/log/handler.go b/pkg/log/handler.go new file mode 100644 index 0000000..b95acff --- /dev/null +++ b/pkg/log/handler.go @@ -0,0 +1,164 @@ +package log + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log" + "log/slog" + "runtime" + "strings" + "sync/atomic" + "time" + "unsafe" +) + +type Handler struct { + ctx unsafe.Pointer // 结构体 logger + color Colorer + slog.Handler + attrs []Attr + l *log.Logger +} + +func (h *Handler) WithAttrs(attrs []Attr) slog.Handler { + return &Handler{ + ctx: h.ctx, + color: h.color, + Handler: h.Handler.WithAttrs(attrs), + attrs: append(h.attrs, attrs...), + l: h.l, + } +} + +func (h *Handler) WithGroup(name string) slog.Handler { + return &Handler{ + ctx: h.ctx, + color: h.color, + Handler: h.Handler.WithGroup(name), + l: h.l, + } +} + +func (h *Handler) Handle(ctx context.Context, r slog.Record) error { + // 参考 atomic.Pointer#Load 实现 + l := (*logger)(atomic.LoadPointer(&h.ctx)) + if atomic.LoadInt32(&l.discard) != 1 { + if len(h.attrs) > 0 { + x := slog.NewRecord(r.Time, r.Level, r.Message, r.PC) + r.Attrs(func(attr slog.Attr) bool { + x.AddAttrs(attr) + return true + }) + x.AddAttrs(h.attrs...) + if err := h.Print(l, x); err != nil { + return err + } + } else if err := h.Print(l, r); err != nil { + return err + } + } + if atomic.LoadInt32(&l.unpersist) != 1 { + return h.Handler.Handle(ctx, r) + } + return nil +} + +func (h *Handler) Print(l *logger, r slog.Record) error { + flags := l.Flags() + colorful := flags&Lcolor != 0 + level := parseSlogLevel(r.Level) + levelStr := level.String() + var fields map[string]any + if numAttrs := r.NumAttrs(); numAttrs > 0 { + fields = make(map[string]any, numAttrs) + r.Attrs(func(a Attr) bool { + switch a.Key { + case rawTimeKey: + r.Time = a.Value.Any().(time.Time) + case rawLevelKey: + levelStr = a.Value.Any().(string) + default: + fields[a.Key] = a.Value.Any() + } + return true + }) + } + var output string + var sep string + write := func(s string) { + if sep == "" { + sep = " " + } else { + output += sep + } + output += s + } + if flags&(Ldate|Ltime|Lmicroseconds) != 0 { + t := r.Time.In(l.Timezone()) // todo 原子操作 + if flags&Ldate != 0 { + write(t.Format("2006/01/02")) + } + if flags&(Ltime|Lmicroseconds) != 0 { + if flags&Lmicroseconds != 0 { + write(t.Format("15:04:05.000")) + } else { + write(t.Format("15:04:05")) + } + } + } + if colorful { + // TODO(hupeh): 重新设计不同颜色 + colorize := identify + switch level { + case LevelDebug: + colorize = h.color.Cyan + case LevelInfo: + colorize = h.color.Blue + case LevelWarn: + colorize = h.color.Yellow + case LevelError: + colorize = h.color.Red + case LevelFatal, LevelPanic: + colorize = h.color.Magenta + } + write(h.color.Grey("[") + colorize(levelStr) + h.color.Grey("]")) + write(colorize(r.Message)) + } else { + write("[" + levelStr + "]") + write(r.Message) + } + if flags&(Lshortfile|Llongfile) != 0 && r.PC > 0 { + var fileStr string + fs := runtime.CallersFrames([]uintptr{r.PC}) + f, _ := fs.Next() + file := f.File + if flags&Lshortfile != 0 { + i := strings.LastIndexAny(file, "\\/") + if i > -1 { + file = file[i+1:] + } + } + fileStr = fmt.Sprintf("%s:%s:%d", f.Function, file, f.Line) + if colorful { + fileStr = h.color.Green(fileStr) + } + write(fileStr) + } + if flags&Lfields != 0 && len(fields) > 0 { + b, err := json.Marshal(fields) + if err != nil { + return err + } + fieldsStr := string(bytes.TrimSpace(b)) + if fieldsStr != "" { + if colorful { + fieldsStr = h.color.White(fieldsStr) + } + write(fieldsStr) + } + } + h.l.Println(strings.TrimSpace(output)) + return nil +} diff --git a/pkg/log/level.go b/pkg/log/level.go new file mode 100644 index 0000000..4b3921c --- /dev/null +++ b/pkg/log/level.go @@ -0,0 +1,62 @@ +package log + +import ( + "log/slog" +) + +type Level int + +const ( + LevelTrace Level = iota + LevelDebug // 用于程序调试 + LevelInfo // 用于程序运行 + LevelWarn // 潜在错误或非预期结果 + LevelError // 发生错误,但不影响系统的继续运行 + LevelFatal + LevelPanic + LevelOff +) + +// 越界取近值 +func (l Level) real() Level { + return min(LevelOff, max(l, LevelTrace)) +} + +// Level 实现 slog.Leveler 接口 +func (l Level) Level() slog.Level { + return slog.Level(20 - int(LevelOff-l.real())*4) +} + +func (l Level) slog() slog.Leveler { + return l.Level() +} + +func (l Level) String() string { + switch l { + case LevelTrace: + return "TRACE" + case LevelDebug: + return "DEBUG" + case LevelInfo: + return "INFO" + case LevelWarn: + return "WARN" + case LevelError: + return "ERROR" + case LevelFatal: + return "FATAL" + case LevelPanic: + return "PANIC" + case LevelOff: + return "OFF" + } + if l < LevelTrace { + return "TRACE" + } else { + return "OFF" + } +} + +func parseSlogLevel(level slog.Level) Level { + return Level(level/4 + 2).real() +} diff --git a/pkg/log/log.go b/pkg/log/log.go new file mode 100644 index 0000000..e5fe588 --- /dev/null +++ b/pkg/log/log.go @@ -0,0 +1,97 @@ +package log + +import ( + "io" + "time" +) + +var std = New(&Options{Level: LevelTrace}) + +func Default() Logger { return std } + +func Flags() int { + return std.Flags() +} + +func SetFlags(flags int) { + std.SetFlags(flags) +} + +func LevelMode() Level { + return std.Level() +} + +func SetLevel(level Level) { + std.SetLevel(level) +} + +func Timezone() *time.Location { + return std.Timezone() +} + +func SetTimezone(loc *time.Location) { + std.SetTimezone(loc) +} + +func IgnorePC() bool { + return std.IgnorePC() +} + +func SetIgnorePC(ignore bool) { + std.SetIgnorePC(ignore) +} + +func Enabled(level Level) bool { + return std.Enabled(level) +} + +func SetWriter(w io.Writer) { + std.SetWriter(w) +} + +func SetPersistWriter(w io.Writer) { + std.SetPersistWriter(w) +} + +func With(attrs ...Attr) Logger { + return std.With(attrs...) +} + +func WithGroup(name string) Logger { + return std.WithGroup(name) +} + +// Log logs at level. +func Log(level Level, msg string, args ...any) { + std.Log(level, msg, args...) +} + +// Trace logs at LevelTrace. +func Trace(msg string, args ...any) { + std.Trace(msg, args...) +} + +// Debug logs at LevelDebug. +func Debug(msg string, args ...any) { + std.Debug(msg, args...) +} + +// Info logs at LevelInfo. +func Info(msg string, args ...any) { + std.Info(msg, args...) +} + +// Warn logs at LevelWarn. +func Warn(msg string, args ...any) { + std.Warn(msg, args...) +} + +// Error logs at LevelError. +func Error(msg string, args ...any) { + std.Error(msg, args...) +} + +// Fatal logs at LevelFatal. +func Fatal(msg string, args ...any) { + std.Fatal(msg, args...) +} diff --git a/pkg/log/logger.go b/pkg/log/logger.go new file mode 100644 index 0000000..5a40646 --- /dev/null +++ b/pkg/log/logger.go @@ -0,0 +1,273 @@ +package log + +import ( + "fmt" + "io" + "log" + "log/slog" + "os" + "runtime" + "sync" + "sync/atomic" + "time" + "unsafe" +) + +const ( + Ldate = 1 << iota + Ltime + Lmicroseconds + Llongfile + Lshortfile + Lfields + Lcolor + LstdFlags = Ltime | Lmicroseconds | Lfields | Lcolor +) + +type Options struct { + Flags int + Level Level + IgnorePC bool + Timezone *time.Location + Colorer Colorer + PersistWriter io.Writer + Writer io.Writer +} + +type Logger interface { + Flags() int + SetFlags(flags int) + Level() Level + SetLevel(Level) + Timezone() *time.Location + SetTimezone(loc *time.Location) + IgnorePC() bool + SetIgnorePC(ignore bool) + Enabled(level Level) bool + Writer() io.Writer + SetWriter(w io.Writer) + PersistWriter() io.Writer + SetPersistWriter(w io.Writer) + With(attrs ...Attr) Logger + WithGroup(name string) Logger + Log(level Level, msg string, args ...any) + Trace(msg string, args ...any) + Debug(msg string, args ...any) + Info(msg string, args ...any) + Warn(msg string, args ...any) + Error(msg string, args ...any) + Fatal(msg string, args ...any) +} + +var _ Logger = &logger{} + +type logger struct { + flags int32 + level int32 + ignorePC int32 + unpersist int32 // 忽略持久化写入 + discard int32 // 忽略控制台输出 + timezone unsafe.Pointer + mu *sync.Mutex + persistWriter io.Writer + writer io.Writer + handler slog.Handler +} + +func New(o *Options) Logger { + if o.Flags == 0 { + o.Flags = LstdFlags + } + if o.Timezone == nil { + o.Timezone = time.FixedZone("CST", 8*3600) // 使用东八区时间 + } + if o.Writer == nil { + o.Writer = os.Stderr + } + if o.PersistWriter == nil { + o.PersistWriter = io.Discard + } + if o.Colorer == nil { + o.Colorer = NewColorer() + } + + l := &logger{mu: &sync.Mutex{}} + + l.handler = &Handler{ + ctx: unsafe.Pointer(l), + color: o.Colorer, + Handler: slog.NewJSONHandler( + &mutablePersistWriter{l}, + &slog.HandlerOptions{ + AddSource: true, + Level: &mutableLevel{l}, + ReplaceAttr: createAttrReplacer(l), + }, + ), + l: log.New(&mutableWriter{l}, "", 0), + } + + l.SetFlags(o.Flags) + l.SetLevel(o.Level) + l.SetIgnorePC(o.IgnorePC) + l.SetTimezone(o.Timezone) + l.SetPersistWriter(o.PersistWriter) + l.SetWriter(o.Writer) + + return l +} + +func (l *logger) Flags() int { + return int(atomic.LoadInt32(&l.flags)) +} + +func (l *logger) SetFlags(flags int) { + atomic.StoreInt32(&l.flags, int32(flags)) +} + +func (l *logger) Level() Level { + return Level(int(atomic.LoadInt32(&l.level))) +} + +func (l *logger) SetLevel(level Level) { + atomic.StoreInt32(&l.level, int32(level)) +} + +func (l *logger) Timezone() *time.Location { + // 参考 atomic.Pointer#Load 实现 + return (*time.Location)(atomic.LoadPointer(&l.timezone)) +} + +func (l *logger) SetTimezone(loc *time.Location) { + // 参考 atomic.Pointer#Store 实现 + atomic.StorePointer(&l.timezone, unsafe.Pointer(loc)) +} + +func (l *logger) IgnorePC() bool { + return atomic.LoadInt32(&l.ignorePC) == 1 +} + +func (l *logger) SetIgnorePC(ignore bool) { + atomic.StoreInt32(&l.ignorePC, bool2int32(ignore)) +} + +func (l *logger) Enabled(level Level) bool { + return l.handler.Enabled(nil, level.slog().Level()) +} + +func (l *logger) SetWriter(w io.Writer) { + l.mu.Lock() + defer l.mu.Unlock() + l.writer = w + atomic.StoreInt32(&l.discard, bool2int32(w == io.Discard)) +} + +func (l *logger) Writer() io.Writer { + l.mu.Lock() + defer l.mu.Unlock() + return l.writer +} + +func (l *logger) SetPersistWriter(w io.Writer) { + l.mu.Lock() + defer l.mu.Unlock() + l.persistWriter = w + atomic.StoreInt32(&l.unpersist, bool2int32(w == io.Discard)) +} + +func (l *logger) PersistWriter() io.Writer { + l.mu.Lock() + defer l.mu.Unlock() + return l.persistWriter +} + +func (l *logger) With(attrs ...Attr) Logger { + if len(attrs) == 0 { + return l + } + c := l.clone() + c.handler = l.handler.WithAttrs(attrs).(*Handler) + return c +} + +func (l *logger) WithGroup(name string) Logger { + if name == "" { + return l + } + c := l.clone() + c.handler = l.handler.WithGroup(name).(*Handler) + return c +} + +func (l *logger) clone() *logger { + c := *l + // TODO(hupeh): 测试 clone 是否报错 + //c.writer = l.writer + //c.persistWriter = l.persistWriter + return &c +} + +// Log logs at level. +func (l *logger) Log(level Level, msg string, args ...any) { + l.log(level, msg, args...) +} + +// Trace logs at LevelTrace. +func (l *logger) Trace(msg string, args ...any) { + l.log(LevelTrace, msg, args...) +} + +// Debug logs at LevelDebug. +func (l *logger) Debug(msg string, args ...any) { + l.log(LevelDebug, msg, args...) +} + +// Info logs at LevelInfo. +func (l *logger) Info(msg string, args ...any) { + l.log(LevelInfo, msg, args...) +} + +// Warn logs at LevelWarn. +func (l *logger) Warn(msg string, args ...any) { + l.log(LevelWarn, msg, args...) +} + +// Error logs at LevelError. +func (l *logger) Error(msg string, args ...any) { + l.log(LevelError, msg, args...) +} + +// Fatal logs at LevelFatal. +func (l *logger) Fatal(msg string, args ...any) { + l.log(LevelFatal, msg, args...) +} + +func (l *logger) log(level Level, msg string, args ...any) { + if !l.Enabled(level) { + return + } + var pc uintptr + if atomic.LoadInt32(&l.ignorePC) != 1 { + var pcs [1]uintptr + // skip [runtime.Callers, this function, this function's caller] + runtime.Callers(3, pcs[:]) + pc = pcs[0] + } + r := slog.NewRecord(time.Now(), level.slog().Level(), msg, pc) + if len(args) > 0 { + var sprintfArgs []any + for _, arg := range args { + switch v := arg.(type) { + case Attr: + r.AddAttrs(v) + default: + sprintfArgs = append(sprintfArgs, arg) + } + } + if len(sprintfArgs) > 0 { + msg = fmt.Sprintf(msg, sprintfArgs...) + } + r.Message = msg + } + _ = l.handler.Handle(nil, r) +} diff --git a/pkg/log/util.go b/pkg/log/util.go new file mode 100644 index 0000000..4de92cb --- /dev/null +++ b/pkg/log/util.go @@ -0,0 +1,77 @@ +package log + +import ( + "log/slog" + "time" +) + +type mutableLevel struct { + l *logger +} + +func (l *mutableLevel) Level() slog.Level { + return l.l.Level().slog().Level() +} + +type mutablePersistWriter struct { + l *logger +} + +func (l *mutablePersistWriter) Write(b []byte) (int, error) { + l.l.mu.Lock() + defer l.l.mu.Unlock() + return l.l.persistWriter.Write(b) +} + +type mutableWriter struct { + l *logger +} + +func (l *mutableWriter) Write(b []byte) (int, error) { + l.l.mu.Lock() + defer l.l.mu.Unlock() + return l.l.writer.Write(b) +} + +func bool2int32(v bool) int32 { + if v { + return 1 + } else { + return 0 + } +} + +func createAttrReplacer(l *logger) func([]string, Attr) Attr { + return func(_ []string, a Attr) Attr { + if a.Key == slog.LevelKey { + level := a.Value.Any().(slog.Level) + levelLabel := parseSlogLevel(level).String() + a.Value = slog.StringValue(levelLabel) + } else if a.Key == slog.TimeKey { + t := a.Value.Any().(time.Time) + a.Value = slog.TimeValue(t.In(l.Timezone())) + } else if a.Key == slog.SourceKey { + s := a.Value.Any().(*slog.Source) + var as []Attr + if s.Function != "" { + as = append(as, String("func", s.Function)) + } + if s.File != "" { + as = append(as, String("file", s.File)) + } + if s.Line != 0 { + as = append(as, Int("line", s.Line)) + } + a.Value = slog.GroupValue(as...) + } else if a.Key == rawLevelKey || a.Key == rawTimeKey { + // TODO(hupeh): 在 JSONHandler 中替换这两个值 + a.Key = "" + a.Value = slog.AnyValue(nil) + } + return a + } +} + +func identify(s string) string { + return s +} diff --git a/pkg/log/writer.go b/pkg/log/writer.go new file mode 100644 index 0000000..ff4a646 --- /dev/null +++ b/pkg/log/writer.go @@ -0,0 +1,169 @@ +package log + +import ( + "errors" + "os" + "path" + "strings" + "sync" + "time" +) + +const ( + WhenSecond = iota + WhenMinute + WhenHour + WhenDay +) + +type RotateWriter struct { + filename string // should be set to the actual filename + written int + interval time.Duration + rotateSize int + rotateTo func(time.Time) string + rolloverAt chan struct{} + timer *time.Timer + mu sync.Mutex + fp *os.File + closed chan struct{} +} + +// Rotate Make a new RotateWriter. Return nil if error occurs during setup. +func Rotate(basename string, when int, rotateSize int) *RotateWriter { + if rotateSize <= 0 { + panic("invalid rotate size") + } + + var interval time.Duration + var suffix string + switch when { + case WhenSecond: + interval = time.Second + suffix = "20060102150405" + case WhenMinute: + interval = time.Minute + suffix = "200601021504" + case WhenHour: + interval = time.Hour + suffix = "2006010215" + case WhenDay: + fallthrough + default: + interval = time.Hour * 24 + suffix = "20060102" + } + + // 解决 Windows 电脑路径问题 + basename = strings.ReplaceAll(basename, "\\", "/") + + filenameWithSuffix := path.Base(basename) + fileSuffix := path.Ext(filenameWithSuffix) + filename := strings.TrimSuffix(filenameWithSuffix, fileSuffix) + fileDir := path.Dir(basename) + + // 创建日志文件目录 + if err := os.MkdirAll(fileDir, 0777); err != nil { + panic(err) + } + + w := &RotateWriter{ + filename: basename, + interval: interval, + rotateSize: rotateSize, + rotateTo: func(t time.Time) string { + return fileDir + "/" + filename + "." + t.Format(suffix) + fileSuffix + }, + rolloverAt: make(chan struct{}), + mu: sync.Mutex{}, + } + + err := w.Rotate() + if err != nil { + return nil + } + + return w +} + +func (w *RotateWriter) Write(b []byte) (int, error) { + select { + case <-w.closed: + return 0, errors.New("already closed") + case <-w.rolloverAt: + if err := w.Rotate(); err != nil { + return 0, err + } + return w.Write(b) + default: + w.mu.Lock() + defer w.mu.Unlock() + n, err := w.fp.Write(b) + w.written += n + if w.written >= w.rotateSize { + w.timer.Stop() + w.rolloverAt <- struct{}{} + } + return n, err + } +} + +func (w *RotateWriter) Close() error { + select { + case <-w.closed: + default: + w.mu.Lock() + defer w.mu.Unlock() + if w.fp != nil { + return w.fp.Close() + } + w.timer.Stop() + } + return nil +} + +func (w *RotateWriter) Rotate() error { + w.mu.Lock() + defer w.mu.Unlock() + + // Close existing file if open + if w.fp != nil { + err := w.fp.Close() + w.fp = nil + if err != nil { + return err + } + } + + // Rename dest file if it already exists + _, err := os.Stat(w.filename) + if err == nil { + err = os.Rename(w.filename, w.rotateTo(time.Now())) + if err != nil { + return err + } + } + + // Create a file. + w.fp, err = os.Create(w.filename) + if err == nil { + return err + } + + if w.timer != nil { + w.timer.Stop() + } + + w.timer = time.NewTimer(w.interval) + + go func() { + // todo 到底是 ok 还是 !ok + if _, ok := <-w.timer.C; !ok { + w.rolloverAt <- struct{}{} + } + }() + + w.written = 0 + + return err +} diff --git a/pkg/logs/attr.go b/pkg/logs/attr.go new file mode 100644 index 0000000..48e12d3 --- /dev/null +++ b/pkg/logs/attr.go @@ -0,0 +1,67 @@ +package logs + +import ( + "log/slog" + "time" +) + +type Attr = slog.Attr + +// String returns an Attr for a string value. +func String(key, value string) Attr { + return slog.String(key, value) +} + +// Int64 returns an Attr for an int64. +func Int64(key string, value int64) Attr { + return slog.Int64(key, value) +} + +// Int converts an int to an int64 and returns +// an Attr with that value. +func Int(key string, value int) Attr { + return slog.Int(key, value) +} + +// Uint64 returns an Attr for a uint64. +func Uint64(key string, v uint64) Attr { + return slog.Uint64(key, v) +} + +// Float64 returns an Attr for a floating-point number. +func Float64(key string, v float64) Attr { + return slog.Float64(key, v) +} + +// Bool returns an Attr for a bool. +func Bool(key string, v bool) Attr { + return slog.Bool(key, v) +} + +// Time returns an Attr for a time.Time. +// It discards the monotonic portion. +func Time(key string, v time.Time) Attr { + return slog.Time(key, v) +} + +// Duration returns an Attr for a time.Duration. +func Duration(key string, v time.Duration) Attr { + return slog.Duration(key, v) +} + +// Group returns an Attr for a Group Instance. +// The first argument is the key; the remaining arguments +// are converted to Attrs as in [Logger.Log]. +// +// Use Group to collect several key-value pairs under a single +// key on a log line, or as the result of LogValue +// in order to log a single value as multiple Attrs. +func Group(key string, args ...any) Attr { + return slog.Group(key, args...) +} + +// Any returns an Attr for the supplied value. +// See [AnyValue] for how values are treated. +func Any(key string, value any) Attr { + return slog.Any(key, value) +} diff --git a/pkg/logs/color.go b/pkg/logs/color.go new file mode 100644 index 0000000..74dbce4 --- /dev/null +++ b/pkg/logs/color.go @@ -0,0 +1,44 @@ +package logs + +import ( + "github.com/mattn/go-isatty" + "os" +) + +type Color string + +var ( + //fgBlack Color = "\x1b[30m" + //fgWhiteItalic Color = "\x1b[37;3m" + + FgRed Color = "\x1b[31m" + FgGreen Color = "\x1b[32m" + FgYellow Color = "\x1b[33m" + FgBlue Color = "\x1b[34m" + FgMagenta Color = "\x1b[35m" + FgCyan Color = "\x1b[36m" + FgWhite Color = "\x1b[37m" + FgHiBlack Color = "\x1b[90m" + fgGreenItalic Color = "\x1b[32;3m" + + // NoColor defines if the output is colorized or not. It's dynamically set to + // false or true based on the stdout's file descriptor referring to a terminal + // or not. It's also set to true if the NO_COLOR environment variable is + // set (regardless of its value). This is a global option and affects all + // colors. For more control over each Color block use the methods + // DisableColor() individually. + noColor = noColorIsSet() || os.Getenv("TERM") == "dumb" || + (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) +) + +// noColorIsSet returns true if the environment variable NO_COLOR is set to a non-empty string. +func noColorIsSet() bool { + return os.Getenv("NO_COLOR") != "" +} + +func (c Color) Wrap(msg string) string { + if noColorIsSet() || noColor { + return msg + } + return string(c) + msg + "\x1b[0m" +} diff --git a/pkg/logs/handler.go b/pkg/logs/handler.go new file mode 100644 index 0000000..80aeb27 --- /dev/null +++ b/pkg/logs/handler.go @@ -0,0 +1 @@ +package logs diff --git a/pkg/logs/level.go b/pkg/logs/level.go new file mode 100644 index 0000000..d7a6e90 --- /dev/null +++ b/pkg/logs/level.go @@ -0,0 +1,59 @@ +package logs + +import ( + "log/slog" +) + +type Level int + +const ( + LevelTrace Level = iota + LevelDebug // 用于程序调试 + LevelInfo // 用于程序运行 + LevelWarn // 潜在错误或非预期结果 + LevelError // 发生错误,但不影响系统的继续运行 + LevelFatal + LevelSilent +) + +// 越界取近值 +func (l Level) real() Level { + return min(LevelSilent, max(l, LevelTrace)) +} + +// Level 实现 slog.Leveler 接口 +func (l Level) Level() slog.Level { + return slog.Level(16 - int(LevelSilent-l.real())*4) +} + +func (l Level) slog() slog.Leveler { + return l.Level() +} + +func (l Level) String() string { + switch l { + case LevelTrace: + return "TRACE" + case LevelDebug: + return "DEBUG" + case LevelInfo: + return "INFO" + case LevelWarn: + return "WARN" + case LevelError: + return "ERROR" + case LevelFatal: + return "FATAL" + case LevelSilent: + return "OFF" + } + if l < LevelTrace { + return "TRACE" + } else { + return "OFF" + } +} + +func parseSlogLevel(level slog.Level) Level { + return Level(level/4 + 2).real() +} diff --git a/pkg/logs/log.go b/pkg/logs/log.go new file mode 100644 index 0000000..90c2fe0 --- /dev/null +++ b/pkg/logs/log.go @@ -0,0 +1,38 @@ +package logs + +import ( + "io" + "time" +) + +var std = New(&Options{Level: LevelInfo}) + +func Default() Logger { return std } + +func SetFlags(flags int) { Default().SetFlags(flags) } +func Flags() int { return Default().Flags() } +func SetTimezone(loc *time.Location) { Default().SetTimezone(loc) } +func Timezone() *time.Location { return Default().Timezone() } +func SetLevel(level Level) { Default().SetLevel(level) } +func GetLevel() Level { return Default().Level() } +func SetPersistWriter(w io.Writer) { Default().SetPersistWriter(w) } +func SetWriter(w io.Writer) { Default().SetWriter(w) } +func With(args ...Attr) Logger { return Default().With(args...) } +func WithGroup(name string) Logger { return Default().WithGroup(name) } +func Enabled(level Level) bool { return Default().Enabled(level) } +func Log(level Level, msg string, args ...any) { Default().Log(level, msg, args...) } +func ForkLevel(level Level, msg string, args ...any) ChildLogger { + return Default().ForkLevel(level, msg, args...) +} +func Trace(msg string, args ...any) { Default().Trace(msg, args...) } +func ForkTrace(msg string, args ...any) ChildLogger { return Default().ForkTrace(msg, args...) } +func Debug(msg string, args ...any) { Default().Debug(msg, args...) } +func ForkDebug(msg string, args ...any) ChildLogger { return Default().ForkDebug(msg, args...) } +func Info(msg string, args ...any) { Default().Info(msg) } +func ForkInfo(msg string, args ...any) ChildLogger { return Default().ForkInfo(msg, args...) } +func Warn(msg string, args ...any) { Default().Warn(msg, args...) } +func ForkWarn(msg string, args ...any) ChildLogger { return Default().ForkWarn(msg, args...) } +func Error(msg string, args ...any) { Default().Error(msg, args...) } +func ForkError(msg string, args ...any) ChildLogger { return Default().ForkError(msg, args...) } +func Fatal(msg string, args ...any) { Default().Fatal(msg, args...) } +func ForkFatal(msg string, args ...any) ChildLogger { return Default().ForkFatal(msg, args...) } diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go new file mode 100644 index 0000000..b1ac0e6 --- /dev/null +++ b/pkg/logs/logger.go @@ -0,0 +1,512 @@ +package logs + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "log/slog" + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" +) + +var ( + // 使用东八区时间 + // https://cloud.tencent.com/developer/article/1805859 + cstZone = time.FixedZone("CST", 8*3600) + childLoggerKey = "sorbet/log:ChildLogger" +) + +const ( + Ldate = 1 << iota + Ltime + Lmicroseconds + Llongfile + Lshortfile + Lfields + Lcolor + LstdFlags = Ltime | Lmicroseconds | Lfields | Lcolor +) + +type Logger interface { + SetFlags(flags int) + Flags() int + SetTimezone(loc *time.Location) + Timezone() *time.Location + SetLevel(level Level) + Level() Level + SetPersistWriter(w io.Writer) + SetWriter(w io.Writer) + With(args ...Attr) Logger + WithGroup(name string) Logger + Enabled(level Level) bool + Log(level Level, msg string, args ...any) + ForkLevel(level Level, msg string, args ...any) ChildLogger + Trace(msg string, args ...any) + ForkTrace(msg string, args ...any) ChildLogger + Debug(msg string, args ...any) + ForkDebug(msg string, args ...any) ChildLogger + Info(msg string, args ...any) + ForkInfo(msg string, args ...any) ChildLogger + Warn(msg string, args ...any) + ForkWarn(msg string, args ...any) ChildLogger + Error(msg string, args ...any) + ForkError(msg string, args ...any) ChildLogger + Fatal(msg string, args ...any) + ForkFatal(msg string, args ...any) ChildLogger +} + +type ChildLogger interface { + Print(msg string, args ...any) + Finish() +} + +type Options struct { + Flags int + Level Level + Timezone *time.Location + PersistWriter io.Writer + Writer io.Writer +} + +type logger struct { + parent *logger + isChild int32 + indent int32 + + level int32 + flags int32 + timezone unsafe.Pointer + + outMu *sync.Mutex + isPersistDiscard int32 + isDiscard int32 + persistWriter io.Writer + writer io.Writer + + handler slog.Handler + l *log.Logger +} + +func New(opts *Options) Logger { + if opts.Flags == 0 { + opts.Flags = LstdFlags + } + if opts.Timezone == nil { + opts.Timezone = cstZone + } + if opts.PersistWriter == nil { + opts.PersistWriter = io.Discard + } + if opts.Writer == nil { + opts.Writer = os.Stderr + } + + var l *logger + l = &logger{ + outMu: &sync.Mutex{}, + persistWriter: opts.PersistWriter, + writer: opts.Writer, + l: log.New(opts.Writer, "", 0), + } + + l.SetLevel(opts.Level) + l.SetFlags(opts.Flags) + l.SetTimezone(opts.Timezone) + l.SetPersistWriter(opts.PersistWriter) + l.SetWriter(opts.Writer) + + l.handler = slog.NewJSONHandler(opts.PersistWriter, &slog.HandlerOptions{ + AddSource: true, + Level: opts.Level, + ReplaceAttr: l.onAttr, + }) + + return l +} + +func (l *logger) SetFlags(flags int) { + atomic.StoreInt32(&l.flags, int32(flags)) +} + +func (l *logger) Flags() int { + return int(atomic.LoadInt32(&l.flags)) +} + +func (l *logger) SetTimezone(loc *time.Location) { + // FIXME(hupeh): 如何原子化储存结构体实例 + atomic.StorePointer(&l.timezone, unsafe.Pointer(loc)) +} + +func (l *logger) Timezone() *time.Location { + return (*time.Location)(atomic.LoadPointer(&l.timezone)) +} + +func (l *logger) SetLevel(level Level) { + atomic.StoreInt32(&l.level, int32(level)) +} + +func (l *logger) Level() Level { + return Level(int(atomic.LoadInt32(&l.level))) +} + +func (l *logger) SetPersistWriter(w io.Writer) { + l.outMu.Lock() + defer l.outMu.Unlock() + l.persistWriter = w + atomic.StoreInt32(&l.isPersistDiscard, discard(w)) +} + +func (l *logger) SetWriter(w io.Writer) { + l.outMu.Lock() + defer l.outMu.Unlock() + l.writer = w + atomic.StoreInt32(&l.isDiscard, discard(w)) +} + +func discard(w io.Writer) int32 { + if w == io.Discard { + return 1 + } + return 0 +} + +func (l *logger) onAttr(_ []string, a slog.Attr) slog.Attr { + switch a.Key { + case slog.LevelKey: + level := a.Value.Any().(slog.Level) + levelLabel := parseSlogLevel(level).String() + a.Value = slog.StringValue(levelLabel) + case slog.TimeKey: + t := a.Value.Any().(time.Time) + a.Value = slog.TimeValue(t.In(l.Timezone())) + case slog.SourceKey: + s := a.Value.Any().(*slog.Source) + var as []Attr + if s.Function != "" { + as = append(as, String("func", s.Function)) + } + if s.File != "" { + as = append(as, String("file", s.File)) + } + if s.Line != 0 { + as = append(as, Int("line", s.Line)) + } + a.Value = slog.GroupValue(as...) + } + return a +} + +func (l *logger) Handle(ctx context.Context, r slog.Record) error { + if atomic.LoadInt32(&l.isDiscard) == 0 { + child, ok := ctx.Value(childLoggerKey).(*childLogger) + indent, err := l.println(child, r) + if err != nil { + return err + } + if ok && indent > 0 { + atomic.StoreInt32(&child.indent, indent) + } + } + if atomic.LoadInt32(&l.isPersistDiscard) == 0 { + return l.handler.Handle(ctx, r) + } + return nil +} + +func (l *logger) println(child *childLogger, r slog.Record) (int32, error) { + var output string + var sep string + var indent int32 + + write := func(s string) { + if sep == "" { + sep = " " + } else { + output += sep + } + output += s + } + + flags := l.Flags() + colorful := flags&Lcolor != 0 + msg := r.Message + level := parseSlogLevel(r.Level) + levelStr := level.String() + withChild := child != nil + + if withChild { + indent = atomic.LoadInt32(&child.indent) + withChild = indent > 0 + } + + if withChild { + write(strings.Repeat(" ", int(indent))) + indent = 0 + } else { + if flags&(Ldate|Ltime|Lmicroseconds) != 0 { + t := r.Time.In(l.Timezone()) + if flags&Ldate != 0 { + write(t.Format("2006/01/02")) + } + if flags&(Ltime|Lmicroseconds) != 0 { + if flags&Lmicroseconds != 0 { + write(t.Format("15:04:05.000")) + } else { + write(t.Format("15:04:05")) + } + } + } + + // 保存缩进 + indent += int32(len(output) + len(levelStr) + 3) + + if colorful { + switch level { + case LevelDebug: + levelStr = FgCyan.Wrap(levelStr) + msg = FgCyan.Wrap(msg) + case LevelInfo: + levelStr = FgBlue.Wrap(levelStr) + " " + msg = FgBlue.Wrap(msg) + indent += 1 + case LevelWarn: + levelStr = FgYellow.Wrap(levelStr) + " " + msg = FgYellow.Wrap(msg) + indent += 1 + case LevelError: + levelStr = FgRed.Wrap(levelStr) + msg = FgRed.Wrap(msg) + case LevelFatal: + levelStr = FgMagenta.Wrap(levelStr) + msg = FgMagenta.Wrap(msg) + } + levelStr = FgHiBlack.Wrap("[") + levelStr + FgHiBlack.Wrap("]") + } else { + levelStr = "[" + r.Level.String() + "]" + } + + write(levelStr) + } + + write(msg) + + if flags&(Lshortfile|Llongfile) != 0 && r.PC > 0 { + var fileStr string + fs := runtime.CallersFrames([]uintptr{r.PC}) + f, _ := fs.Next() + file := f.File + if flags&Lshortfile != 0 { + i := strings.LastIndexAny(file, "\\/") + if i > -1 { + file = file[i+1:] + } + } + fileStr = fmt.Sprintf("%s:%s:%d", f.Function, file, f.Line) + if colorful { + fileStr = fgGreenItalic.Wrap(fileStr) + } + write(fileStr) + } + + if numAttrs := r.NumAttrs(); flags&Lfields != 0 && numAttrs > 0 { + fields := make(map[string]any, numAttrs) + r.Attrs(func(a Attr) bool { + fields[a.Key] = a.Value.Any() + return true + }) + b, err := json.Marshal(fields) + if err != nil { + return 0, err + } + fieldsStr := string(bytes.TrimSpace(b)) + if fieldsStr != "" { + if colorful { + fieldsStr = FgHiBlack.Wrap(fieldsStr) + } + write(fieldsStr) + } + } + + l.l.Println(output) + + return indent, nil +} + +func (l *logger) clone() *logger { + c := *l + return &c +} + +func (l *logger) With(args ...Attr) Logger { + if len(args) == 0 { + return l + } + c := l.clone() + c.handler = c.handler.WithAttrs(args) + return c +} + +func (l *logger) WithGroup(name string) Logger { + if name == "" { + return l + } + c := l.clone() + c.handler = c.handler.WithGroup(name) + return c +} + +func (l *logger) Enabled(level Level) bool { + return l.handler.Enabled(nil, level.slog().Level()) +} + +// Log logs at level. +func (l *logger) Log(level Level, msg string, args ...any) { + l.log(nil, level, msg, args...) +} + +func (l *logger) ForkLevel(level Level, msg string, args ...any) ChildLogger { + c := &childLogger{ + parent: l, + level: level, + indent: 0, + records: make([]slog.Record, 0), + closed: make(chan struct{}), + } + c.Print(msg, args...) + return c +} + +// Trace logs at LevelTrace. +func (l *logger) Trace(msg string, args ...any) { + l.log(nil, LevelTrace, msg, args...) +} + +func (l *logger) ForkTrace(msg string, args ...any) ChildLogger { + return l.ForkLevel(LevelTrace, msg, args...) +} + +// Debug logs at LevelDebug. +func (l *logger) Debug(msg string, args ...any) { + l.log(nil, LevelDebug, msg, args...) +} + +func (l *logger) ForkDebug(msg string, args ...any) ChildLogger { + return l.ForkLevel(LevelDebug, msg, args...) +} + +// Info logs at LevelInfo. +func (l *logger) Info(msg string, args ...any) { + l.log(nil, LevelInfo, msg, args...) +} + +func (l *logger) ForkInfo(msg string, args ...any) ChildLogger { + return l.ForkLevel(LevelInfo, msg, args...) +} + +// Warn logs at LevelWarn. +func (l *logger) Warn(msg string, args ...any) { + l.log(nil, LevelWarn, msg, args...) +} + +func (l *logger) ForkWarn(msg string, args ...any) ChildLogger { + return l.ForkLevel(LevelWarn, msg, args...) +} + +// Error logs at LevelError. +func (l *logger) Error(msg string, args ...any) { + l.log(nil, LevelError, msg, args...) +} + +func (l *logger) ForkError(msg string, args ...any) ChildLogger { + return l.ForkLevel(LevelError, msg, args...) +} + +// Fatal logs at LevelFatal. +func (l *logger) Fatal(msg string, args ...any) { + l.log(nil, LevelFatal, msg, args...) +} + +func (l *logger) ForkFatal(msg string, args ...any) ChildLogger { + return l.ForkLevel(LevelFatal, msg, args...) +} + +func (l *logger) log(ctx context.Context, level Level, msg string, args ...any) { + if !l.Enabled(level) { + return + } + if ctx == nil { + ctx = context.Background() + } + _ = l.Handle(ctx, newRecord(level, msg, args)) +} + +func newRecord(level Level, msg string, args []any) slog.Record { + //var pc uintptr + //if !internal.IgnorePC { + // var pcs [1]uintptr + // // skip [runtime.Callers, this function, this function's caller] + // runtime.Callers(3, pcs[:]) + // pc = pcs[0] + //} + //r := slog.NewRecord(time.Now(), level.slog().Level(), msg, pc) + r := slog.NewRecord(time.Now(), level.slog().Level(), msg, 0) + if len(args) > 0 { + var sprintfArgs []any + for _, arg := range args { + switch v := arg.(type) { + case Attr: + r.AddAttrs(v) + default: + sprintfArgs = append(sprintfArgs, arg) + } + } + if len(sprintfArgs) > 0 { + r.Message = fmt.Sprintf(msg, sprintfArgs...) + } + } + return r +} + +type childLogger struct { + parent *logger + level Level + indent int32 + begin slog.Record + finish slog.Record + records []slog.Record + closed chan struct{} +} + +func (c *childLogger) Print(msg string, args ...any) { + select { + case <-c.closed: + default: + c.records = append( + c.records, + newRecord(c.level, msg, args), + ) + } +} + +func (c *childLogger) Finish() { + select { + case <-c.closed: + return + default: + close(c.closed) + } + + ctx := context.Background() + ctx = context.WithValue(ctx, childLoggerKey, c) + for _, record := range c.records { + _ = c.parent.Handle(ctx, record) + } +} diff --git a/pkg/misc/clamp.go b/pkg/misc/clamp.go new file mode 100644 index 0000000..69d02ce --- /dev/null +++ b/pkg/misc/clamp.go @@ -0,0 +1,13 @@ +package misc + +import "cmp" + +func Clamp[T cmp.Ordered](val, min, max T) T { + if val > max { + val = max + } + if val < min { + val = min + } + return val +} diff --git a/pkg/misc/fallback.go b/pkg/misc/fallback.go new file mode 100644 index 0000000..fd71913 --- /dev/null +++ b/pkg/misc/fallback.go @@ -0,0 +1,14 @@ +package misc + +// Fallback 当值为空时返回默认值 +func Fallback[T any](v T, fs ...T) T { + if !IsZero(v) { + return v + } + for _, f := range fs { + if !IsZero(f) { + return f + } + } + return v +} diff --git a/pkg/misc/key.go b/pkg/misc/key.go new file mode 100644 index 0000000..3cb484f --- /dev/null +++ b/pkg/misc/key.go @@ -0,0 +1,26 @@ +package misc + +import ( + "context" +) + +type Key[T any] struct { + Name string +} + +func (k *Key[T]) Wrap(ctx context.Context, value *T) context.Context { + return context.WithValue(ctx, k, value) +} + +func (k *Key[T]) Lookup(ctx context.Context) (*T, bool) { + t, ok := ctx.Value(k).(*T) + return t, ok +} + +func (k *Key[T]) Value(ctx context.Context) *T { + t, ok := ctx.Value(k).(*T) + if !ok { + panic("not value found") + } + return t +} diff --git a/pkg/misc/result.go b/pkg/misc/result.go new file mode 100644 index 0000000..e64b245 --- /dev/null +++ b/pkg/misc/result.go @@ -0,0 +1,45 @@ +package misc + +// Result 使用泛型模仿 Rust 实现一个简单的 Result 类型 +// https://juejin.cn/post/7161342717190996005 +type Result[T any] struct { + value T + err error +} + +func (r *Result[T]) Ok() bool { + return r.err != nil +} + +func (r *Result[T]) Err() error { + return r.err +} + +func (r *Result[T]) Unwrap() T { + if r.err != nil { + panic(r.err) + } + return r.value +} + +func (r *Result[T]) Expect(err error) T { + if r.err != nil { + panic(err) + } + return r.value +} + +func NewResult[T any](v T, err error) *Result[T] { + return &Result[T]{ + value: v, + err: err, + } +} + +func Match[T any](r *Result[T], okF func(T), errF func(error)) { + if r.err != nil { + errF(r.err) + } else { + okF(r.value) + } +} diff --git a/pkg/misc/zero.go b/pkg/misc/zero.go new file mode 100644 index 0000000..b66ee91 --- /dev/null +++ b/pkg/misc/zero.go @@ -0,0 +1,19 @@ +package misc + +import "reflect" + +// IsZero 泛型零值判断 +// https://stackoverflow.com/questions/74000242/in-golang-how-to-compare-interface-as-generics-type-to-nil +func IsZero[T any](v T) bool { + return isZero(reflect.ValueOf(v)) +} + +func isZero(ref reflect.Value) bool { + if !ref.IsValid() { + return true + } + if ref.Type().Kind() == reflect.Ptr { + return isZero(ref.Elem()) + } + return ref.IsZero() +} diff --git a/pkg/rsp/accept.go b/pkg/rsp/accept.go new file mode 100644 index 0000000..1cb5719 --- /dev/null +++ b/pkg/rsp/accept.go @@ -0,0 +1,171 @@ +package rsp + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +var errInvalidTypeSubtype = "accept: Invalid type '%s'." + +// headerAccept represents a parsed headerAccept(-Charset|-Encoding|-Language) header. +type headerAccept struct { + Type, Subtype string + Q float64 + Extensions map[string]string +} + +// AcceptSlice is a slice of headerAccept. +type acceptSlice []headerAccept + +func Accepts(header, expect string) bool { + _, typeSubtype, err := parseMediaRange(expect) + if err != nil { + return false + } + return accepts(parse(header), typeSubtype[0], typeSubtype[1]) +} + +func accepts(slice acceptSlice, typ, sub string) bool { + for _, a := range slice { + if a.Type != typ { + continue + } + if a.Subtype == "*" || a.Subtype == sub { + return true + } + } + return false +} + +// parses a HTTP headerAccept(-Charset|-Encoding|-Language) header and returns +// AcceptSlice, sorted in decreasing order of preference. If the header lists +// multiple types that have the same level of preference (same specificity of +// type and subtype, same qvalue, and same number of extensions), the type +// that was listed in the header first comes first in the returned value. +// +// See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14 for more information. +func parse(header string) acceptSlice { + mediaRanges := strings.Split(header, ",") + accepted := make(acceptSlice, 0, len(mediaRanges)) + for _, mediaRange := range mediaRanges { + rangeParams, typeSubtype, err := parseMediaRange(mediaRange) + if err != nil { + continue + } + + accept := headerAccept{ + Type: typeSubtype[0], + Subtype: typeSubtype[1], + Q: 1.0, + Extensions: make(map[string]string), + } + + // If there is only one rangeParams, we can stop here. + if len(rangeParams) == 1 { + accepted = append(accepted, accept) + continue + } + + // Validate the rangeParams. + validParams := true + for _, v := range rangeParams[1:] { + nameVal := strings.SplitN(v, "=", 2) + if len(nameVal) != 2 { + validParams = false + break + } + nameVal[1] = strings.TrimSpace(nameVal[1]) + if name := strings.TrimSpace(nameVal[0]); name == "q" { + qval, err := strconv.ParseFloat(nameVal[1], 64) + if err != nil || qval < 0 { + validParams = false + break + } + if qval > 1.0 { + qval = 1.0 + } + accept.Q = qval + } else { + accept.Extensions[name] = nameVal[1] + } + } + + if validParams { + accepted = append(accepted, accept) + } + } + + sort.Sort(accepted) + return accepted +} + +// Len implements the Len() method of the Sort interface. +func (a acceptSlice) Len() int { + return len(a) +} + +// Less implements the Less() method of the Sort interface. Elements are +// sorted in order of decreasing preference. +func (a acceptSlice) Less(i, j int) bool { + // Higher qvalues come first. + if a[i].Q > a[j].Q { + return true + } else if a[i].Q < a[j].Q { + return false + } + + // Specific types come before wildcard types. + if a[i].Type != "*" && a[j].Type == "*" { + return true + } else if a[i].Type == "*" && a[j].Type != "*" { + return false + } + + // Specific subtypes come before wildcard subtypes. + if a[i].Subtype != "*" && a[j].Subtype == "*" { + return true + } else if a[i].Subtype == "*" && a[j].Subtype != "*" { + return false + } + + // A lot of extensions comes before not a lot of extensions. + if len(a[i].Extensions) > len(a[j].Extensions) { + return true + } + + return false +} + +// Swap implements the Swap() method of the Sort interface. +func (a acceptSlice) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +// parseMediaRange parses the provided media range, and on success returns the +// parsed range params and type/subtype pair. +func parseMediaRange(mediaRange string) (rangeParams, typeSubtype []string, err error) { + rangeParams = strings.Split(mediaRange, ";") + typeSubtype = strings.Split(rangeParams[0], "/") + + // typeSubtype should have a length of exactly two. + if len(typeSubtype) > 2 { + err = fmt.Errorf(errInvalidTypeSubtype, rangeParams[0]) + return + } else { + typeSubtype = append(typeSubtype, "*") + } + + // Sanitize typeSubtype. + typeSubtype[0] = strings.TrimSpace(typeSubtype[0]) + typeSubtype[1] = strings.TrimSpace(typeSubtype[1]) + if typeSubtype[0] == "" { + typeSubtype[0] = "*" + } + if typeSubtype[1] == "" { + typeSubtype[1] = "*" + } + + return +} diff --git a/pkg/rsp/error.go b/pkg/rsp/error.go new file mode 100644 index 0000000..154ac56 --- /dev/null +++ b/pkg/rsp/error.go @@ -0,0 +1,102 @@ +package rsp + +import ( + "fmt" + "github.com/labstack/echo/v4" + "strings" +) + +var ( + // ErrOK 表示没有任何错误。 + // 对应 HTTP 响应状态码为 500。 + ErrOK = NewError(0, "") + + // ErrInternal 客户端请求有效,但服务器处理时发生了意外。 + // 对应 HTTP 响应状态码为 500。 + ErrInternal = NewError(-100, "internal error") + + // ErrServiceUnavailable 服务器无法处理请求,一般用于网站维护状态。 + // 对应 HTTP 响应状态码为 503。 + ErrServiceUnavailable = NewError(-101, "Service Unavailable") + + // ErrUnauthorized 用户未提供身份验证凭据,或者没有通过身份验证。 + // 响应的 HTTP 状态码为 401。 + ErrUnauthorized = NewError(-102, "unauthorized") + + // ErrForbidden 用户通过了身份验证,但是不具有访问资源所需的权限。 + // 响应的 HTTP 状态码为 403。 + ErrForbidden = NewError(-103, "Forbidden") + + // ErrGone 所请求的资源已从这个地址转移,不再可用。 + // 响应的 HTTP 状态码为 410。 + ErrGone = NewError(-104, "Gone") + + // ErrUnsupportedMediaType 客户端要求的返回格式不支持。 + // 比如,API 只能返回 JSON 格式,但是客户端要求返回 XML 格式。 + // 响应的 HTTP 状态码为 415。 + ErrUnsupportedMediaType = NewError(-105, "Unsupported Media Type") + + // ErrUnprocessableEntity 无法处理客户端上传的附件,导致请求失败。 + // 响应的 HTTP 状态码为 422。 + ErrUnprocessableEntity = NewError(-106, "Unprocessable Entity") + + // ErrTooManyRequests 客户端的请求次数超过限额。 + // 响应的 HTTP 状态码为 422。 + ErrTooManyRequests = NewError(-107, "Too Many Requests") + + // ErrSeeOther 表示需要参考另一个 URL 才能完成接收的请求操作, + // 当请求方式使用 POST、PUT 和 DELETE 时,对应的 HTTP 状态码为 303, + // 其它的请求方式在大多数情况下应该使用 400 状态码。 + ErrSeeOther = NewError(-108, "see other") + + // ErrBadRequest 服务器不理解客户端的请求。 + // 对应 HTTP 状态码为 404。 + ErrBadRequest = NewError(-109, "bad request") + + // ErrBadParams 客户端提交的参数不符合要求 + // 对应 HTTP 状态码为 400。 + ErrBadParams = NewError(-110, "bad parameters") + + // ErrRecordNotFound 访问的数据不存在 + // 对应 HTTP 状态码为 404。 + ErrRecordNotFound = NewError(-111, "record not found") +) + +type Error struct { + code int + text string +} + +func NewError(code int, text string) *Error { + return &Error{code, text} +} + +func (e *Error) Code() int { + return e.code +} + +func (e *Error) Text() string { + return e.text +} + +func (e *Error) WithText(text ...string) *Error { + for _, s := range text { + if s != "" { + return NewError(e.code, s) + } + } + return e +} + +func (e *Error) AsHttpError(code int) *echo.HTTPError { + he := echo.NewHTTPError(code, e.text) + return he.WithInternal(e) +} + +func (e *Error) String() string { + return strings.TrimSpace(fmt.Sprintf("%d %s", e.code, e.text)) +} + +func (e *Error) Error() string { + return e.String() +} diff --git a/pkg/rsp/rsp.go b/pkg/rsp/rsp.go new file mode 100644 index 0000000..2356538 --- /dev/null +++ b/pkg/rsp/rsp.go @@ -0,0 +1,255 @@ +package rsp + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/labstack/echo/v4" + "net/http" +) + +var ( + TextMarshaller func(map[string]any) (string, error) + HtmlMarshaller func(map[string]any) (string, error) + JsonpCallbacks []string + DefaultJsonpCallback string +) + +func init() { + TextMarshaller = toText + HtmlMarshaller = toText + JsonpCallbacks = []string{"callback", "cb", "jsonp"} + DefaultJsonpCallback = "callback" +} + +func toText(m map[string]any) (string, error) { + buf := &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(true) + if err := enc.Encode(m); err != nil { + return "", err + } else { + return buf.String(), nil + } +} + +type response struct { + code int + headers map[string]string + cookies []*http.Cookie + err error + message string + data any +} + +type Option func(o *response) + +func StatusCode(code int) Option { + return func(o *response) { + o.code = code + } +} + +func Header(key, value string) Option { + return func(o *response) { + if o.headers == nil { + o.headers = make(map[string]string) + } + o.headers[key] = value + } +} + +func Cookie(cookie *http.Cookie) Option { + return func(o *response) { + if o.cookies != nil { + for i, h := range o.cookies { + if h.Name == cookie.Name { + o.cookies[i] = cookie + return + } + } + } + o.cookies = append(o.cookies, cookie) + } +} + +func Message(msg string) Option { + return func(o *response) { + o.message = msg + } +} + +func Data(data any) Option { + return func(o *response) { + o.data = data + } +} + +func respond(c echo.Context, o *response) error { + defer func() { + if o.err != nil { + c.Logger().Error(o.err) + } + }() + + m := map[string]any{ + "code": nil, + "success": false, + "message": o.message, + } + var err *Error + if errors.As(o.err, &err) { + m["code"] = err.code + m["message"] = err.text + m["success"] = errors.Is(err, ErrOK) + } else if o.err != nil { + m["code"] = ErrInternal.code + m["message"] = o.err.Error() + } else { + m["code"] = 0 + m["success"] = true + if o.data != nil { + m["data"] = o.data + } + } + if m["message"] == "" { + m["message"] = http.StatusText(o.code) + } + + if c.Response().Committed { + return nil + } + + if o.headers != nil { + header := c.Response().Header() + for key, value := range o.headers { + header.Set(key, value) + } + } + + if o.cookies != nil { + for _, cookie := range o.cookies { + c.SetCookie(cookie) + } + } + + r := c.Request() + if r.Method == http.MethodHead { + return c.NoContent(o.code) + } + + slice := parse(r.Header.Get("Accept")) + for _, a := range slice { + switch x := a.Type + "/" + a.Subtype; x { + case echo.MIMEApplicationJavaScript: + qs := c.Request().URL.Query() + for _, name := range JsonpCallbacks { + if cb := qs.Get(name); cb != "" { + return c.JSONP(o.code, cb, m) + } + } + return c.JSONP(o.code, DefaultJsonpCallback, m) + case echo.MIMEApplicationJSON: + return c.JSON(o.code, m) + case echo.MIMEApplicationXML, echo.MIMETextXML: + return c.XML(o.code, m) + case echo.MIMETextHTML: + if html, err := HtmlMarshaller(m); err != nil { + return err + } else { + return c.HTML(o.code, html) + } + case echo.MIMETextPlain: + if text, err := TextMarshaller(m); err != nil { + return err + } else { + return c.String(o.code, text) + } + } + } + return c.JSON(o.code, m) +} + +func Respond(c echo.Context, opts ...Option) error { + o := response{code: http.StatusOK} + for _, option := range opts { + option(&o) + } + return respond(c, &o) +} + +func Ok(c echo.Context, data any) error { + return Respond(c, Data(data)) +} + +func Created(c echo.Context, data any) error { + return Respond(c, Data(data), StatusCode(http.StatusCreated)) +} + +// Fail 响应一个错误 +func Fail(c echo.Context, err error, opts ...Option) error { + o := response{code: http.StatusInternalServerError} + for _, option := range opts { + option(&o) + } + o.err = err + var he *echo.HTTPError + if errors.As(err, &he) { + o.code = he.Code + } + return respond(c, &o) +} + +// InternalError 响应一个服务器内部错误 +func InternalError(c echo.Context, message ...string) error { + return Fail(c, ErrInternal.WithText(message...)) +} + +// ServiceUnavailable 响应一个服务暂不可用的错误 +func ServiceUnavailable(c echo.Context, message ...string) error { + return Fail(c, + ErrServiceUnavailable.WithText(message...), + StatusCode(http.StatusServiceUnavailable)) +} + +// Unauthorized 需要一个身份验证凭据异常的错误 +func Unauthorized(c echo.Context, message ...string) error { + return Fail(c, + ErrUnauthorized.WithText(message...), + StatusCode(http.StatusUnauthorized)) +} + +// Forbidden 响应一个不具有访问资源所需权限的错误(用户通过了身份验证) +func Forbidden(c echo.Context, message ...string) error { + return Fail(c, + ErrForbidden.WithText(message...), + StatusCode(http.StatusForbidden)) +} + +// UnprocessableEntity 响应一个处理客户端上传失败的错误 +func UnprocessableEntity(c echo.Context, message ...string) error { + return Fail(c, + ErrUnprocessableEntity.WithText(message...), + StatusCode(http.StatusUnprocessableEntity)) +} + +// BadRequest 响应一个服务器不理解客户端请求的错误 +func BadRequest(c echo.Context, message ...string) error { + return Fail(c, + ErrBadRequest.WithText(message...), + StatusCode(http.StatusBadRequest)) +} + +// BadParams 响应一个客户端提交的参数不符合要求的错误 +func BadParams(c echo.Context, message ...string) error { + return Fail(c, + ErrBadParams.WithText(message...), + StatusCode(http.StatusBadRequest)) +} + +// RecordNotFound 响应一个数据不存在的错误 +func RecordNotFound(c echo.Context, message ...string) error { + return Fail(c, + ErrRecordNotFound.WithText(message...), + StatusCode(http.StatusNotFound)) +} diff --git a/pkg/rsp/sse.go b/pkg/rsp/sse.go new file mode 100644 index 0000000..16df3f6 --- /dev/null +++ b/pkg/rsp/sse.go @@ -0,0 +1,72 @@ +package rsp + +import ( + "errors" + "fmt" + "github.com/labstack/echo/v4" + "net/http" + "strings" + "time" +) + +type SseOptions struct { + Id string + Event string + Data string + Retry int +} + +func Echo(c echo.Context, e <-chan *SseOptions) error { + w := c.Response().Writer + f, ok := w.(http.Flusher) + if !ok { + return errors.New("unable to get http.Flusher interface; this is probably due " + + "to nginx buffering the response") + } + if want, have := "text/event-stream", c.Request().Header.Get("Accept"); want != have { + return fmt.Errorf("accept header: want %q, have %q; seems like the browser doesn't "+ + "support server-side events", want, have) + } + // Instruct nginx to NOT buffer the response + w.Header().Set("X-Accel-Buffering", "no") + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + // Send heartbeats to ensure the connection stays up + heartbeat := time.NewTicker(30 * time.Second) + defer heartbeat.Stop() + for { + select { + case <-c.Request().Context().Done(): // When the browser closes the connection + f.Flush() + return nil + case <-heartbeat.C: + sendEvent(w, "", "heartbeat", "{}", 2000) + case opts := <-e: + sendEvent(w, cleanNewline(opts.Id), cleanNewline(opts.Event), cleanNewline(opts.Data), opts.Retry) + } + } +} + +func cleanNewline(str string) string { + return strings.ReplaceAll(str, "\n", "") +} + +func sendEvent(w http.ResponseWriter, id, event, data string, retry int) { + f := w.(http.Flusher) + if id != "" { + fmt.Fprintf(w, "id: %s\n", id) + } + if event != "" { + fmt.Fprintf(w, "event: %s\n", event) + } + if data != "" { + fmt.Fprintf(w, "data: %s\n", data) + } + if retry > 0 { + fmt.Fprintf(w, "retry: %d\n", retry) + } + fmt.Fprint(w, "\n") + f.Flush() +} diff --git a/pkg/zinc/zinc.go b/pkg/zinc/zinc.go new file mode 100644 index 0000000..4b3e9df --- /dev/null +++ b/pkg/zinc/zinc.go @@ -0,0 +1,209 @@ +package zinc + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/go-resty/resty/v2" + "net/http" + "time" +) + +type Client struct { + Host string + User string + Password string +} + +type Index struct { + Name string `json:"name"` + StorageType string `json:"storage_type"` + Mappings *IndexMappings `json:"mappings"` +} + +type IndexMappings struct { + Properties *IndexProperty `json:"properties"` +} + +type IndexProperty map[string]*IndexPropertyT + +type IndexPropertyT struct { + Type string `json:"type"` + Index bool `json:"index"` + Store bool `json:"store"` + Sortable bool `json:"sortable"` + Aggregatable bool `json:"aggregatable"` + Highlightable bool `json:"highlightable"` + Analyzer string `json:"analyzer"` + SearchAnalyzer string `json:"search_analyzer"` + Format string `json:"format"` +} + +type QueryResultT struct { + Took int `json:"took"` + TimedOut bool `json:"timed_out"` + Hits *HitsResultT `json:"hits"` +} + +type HitsResultT struct { + Total *HitsResultTotalT `json:"total"` + MaxScore float64 `json:"max_score"` + Hits []*HitItem `json:"hits"` +} + +type HitsResultTotalT struct { + Value int64 `json:"value"` +} + +type HitItem struct { + Index string `json:"_index"` + Type string `json:"_type"` + ID string `json:"_id"` + Score float64 `json:"_score"` + Timestamp time.Time `json:"@timestamp"` + Source any `json:"_source"` +} + +// NewClient 获取ZincClient新实例 +func NewClient(host, user, passwd string) *Client { + return &Client{ + Host: host, + User: user, + Password: passwd, + } +} + +// CreateIndex 创建索引 +func (c *Client) CreateIndex(name string, p *IndexProperty) bool { + data := &Index{ + Name: name, + StorageType: "disk", + Mappings: &IndexMappings{ + Properties: p, + }, + } + resp, err := c.request().SetBody(data).Put("/api/index") + + if err != nil || resp.StatusCode() != http.StatusOK { + return false + } + + return true +} + +// ExistIndex 检查索引是否存在 +func (c *Client) ExistIndex(name string) bool { + resp, err := c.request().Get("/api/index") + + if err != nil || resp.StatusCode() != http.StatusOK { + return false + } + + retData := &map[string]any{} + err = json.Unmarshal([]byte(resp.String()), retData) + if err != nil { + return false + } + + if _, ok := (*retData)[name]; ok { + return true + } + + return false +} + +// PutDoc 新增/更新文档 +func (c *Client) PutDoc(name string, id int64, doc any) (bool, error) { + resp, err := c.request().SetBody(doc).Put(fmt.Sprintf("/api/%s/_doc/%d", name, id)) + + if err != nil { + return false, err + } + + if resp.StatusCode() != http.StatusOK { + return false, errors.New(resp.Status()) + } + + return true, nil +} + +// BulkPushDoc 批量新增文档 +func (c *Client) BulkPushDoc(docs []map[string]any) (bool, error) { + dataStr := "" + for _, doc := range docs { + str, err := json.Marshal(doc) + if err == nil { + dataStr = dataStr + string(str) + "\n" + } + } + + resp, err := c.request().SetBody(dataStr).Post("/api/_bulk") + if err != nil { + return false, err + } + + if resp.StatusCode() != http.StatusOK { + return false, errors.New(resp.Status()) + } + + return true, nil +} + +func (c *Client) EsQuery(indexName string, q any) (*QueryResultT, error) { + resp, err := c.request().SetBody(q).Post(fmt.Sprintf("/es/%s/_search", indexName)) + if err != nil { + return nil, err + } + + if resp.StatusCode() != http.StatusOK { + return nil, errors.New(resp.Status()) + } + + result := &QueryResultT{} + err = json.Unmarshal(resp.Body(), result) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *Client) ApiQuery(indexName string, q any) (*QueryResultT, error) { + resp, err := c.request().SetBody(q).Post(fmt.Sprintf("/api/%s/_search", indexName)) + if err != nil { + return nil, err + } + + if resp.StatusCode() != http.StatusOK { + return nil, errors.New(resp.Status()) + } + + result := &QueryResultT{} + err = json.Unmarshal(resp.Body(), result) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *Client) DelDoc(indexName, id string) error { + resp, err := c.request().Delete(fmt.Sprintf("/api/%s/_doc/%s", indexName, id)) + if err != nil { + return err + } + + if resp.StatusCode() != http.StatusOK { + return errors.New(resp.Status()) + } + + return nil +} + +func (c *Client) request() *resty.Request { + client := resty.New() + client.DisableWarn = true + client.SetBaseURL(c.Host) + client.SetBasicAuth(c.User, c.Password) + return client.R() +} diff --git a/scripts/di.go b/scripts/di.go new file mode 100644 index 0000000..60af06d --- /dev/null +++ b/scripts/di.go @@ -0,0 +1,22 @@ +package main + +import ( + "fmt" + "gorm.io/gorm" + "reflect" +) + +type Person struct { + name string + db *gorm.DB +} + +func main() { + var person Person + t1 := reflect.TypeOf(&person).Elem() + t2 := reflect.TypeOf(&person).Elem() + b1 := t1 == t2 + fmt.Println(t1) + fmt.Println(t2) + fmt.Println(b1) +} diff --git a/scripts/gen/gen.go b/scripts/gen/gen.go new file mode 100644 index 0000000..c75521d --- /dev/null +++ b/scripts/gen/gen.go @@ -0,0 +1,116 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "strings" +) + +var repoStub = `package repositories + +import ( + "gorm.io/gorm" + "sorbet/internal/entities" + "sorbet/pkg/db" +) + +type {{pascal}}Repository struct { + *db.Repository[entities.{{pascal}}] +} + +// New{{pascal}}Repository 创建{{label}}仓库 +func New{{pascal}}Repository(orm *gorm.DB) *{{pascal}}Repository { + return &{{pascal}}Repository{ + db.NewRepositoryWith[entities.{{pascal}}](orm, "id"), + } +} + +` + +var labels = map[string]string{ + "Company": "公司", + "CompanyDepartment": "公司部门", + "CompanyStaff": "公司员工", + "Config": "配置", + "ConfigGroup": "配置组", + "Feature": "栏目", + "FeatureCategory": "栏目分类", + "FeatureConfig": "栏目配置", + "FeatureContent": "栏目内容", + "FeatureContentChapter": "栏目内容章回", + "FeatureContentDetail": "栏目内容详情", + "Resource": "资源", + "ResourceCategory": "资源分类", + "SystemLog": "系统日志", + "SystemMenu": "系统菜单", + "SystemPermission": "系统权限", + "SystemRole": "系统用户角色", + "SystemRolePower": "角色授权", + "SystemUser": "系统用户", +} + +func main() { + dirs, err := os.ReadDir("../internal/entities") + if err != nil { + panic(err) + } + os.MkdirAll("../internal/repositories", os.ModePerm) + var camels []string + var pascals []string + var iocStr string + var createdCount int + for _, dir := range dirs { + name := dir.Name() + if !strings.HasSuffix(name, ".go") { + continue + } + name = strings.TrimSuffix(name, ".go") + parts := strings.Split(name, "_") + for i, part := range parts { + parts[i] = string(append(bytes.ToUpper([]byte{part[0]}), part[1:]...)) + } + pascal := strings.Join(parts, "") + camel := string(bytes.ToLower([]byte{pascal[0]})) + pascal[1:] + pascals = append(pascals, pascal) + camels = append(camels, camel) + label, ok := labels[pascal] + if !ok { + fmt.Println("跳过 " + dir.Name()) + continue + } + iocStr += "\tioc.MustSingleton(New" + pascal + "Repository)\n" + repository := "../internal/repositories/" + dir.Name() + if PathExists(repository) { + fmt.Println("跳过 " + dir.Name()) + } else { + code := strings.ReplaceAll(repoStub, "{{pascal}}", pascal) + code = strings.ReplaceAll(code, "{{camel}}", camel) + code = strings.ReplaceAll(code, "{{label}}", label) + os.WriteFile("../internal/repositories/"+dir.Name(), []byte(code), os.ModePerm) + fmt.Println("创建 " + dir.Name()) + createdCount++ + } + } + + if createdCount > 0 { + os.WriteFile("../internal/repositories/ioc.go", []byte(`package repositories + +import "sorbet/pkg/ioc" + +func init() { +`+iocStr+`} +`), os.ModePerm) + } +} + +func PathExists(path string) bool { + _, err := os.Stat(path) + if err == nil { + return true + } + if os.IsNotExist(err) { + return false + } + panic(err) +}