广告平台(站长使用)
 
 
 
 
 

137 lines
2.5 KiB

  1. package mw
import (
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"io"

	"github.com/dchest/uniuri"
	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
)
  11. // csrf,xsrf检查
  12. const (
  13. csrfSecret = "csrfSecret"
  14. csrfSalt = "csrfSalt"
  15. csrfToken = "csrfToken"
  16. )
  17. var defaultIgnoreMethods = []string{"GET", "HEAD", "OPTIONS"}
  18. var defaultErrorFunc = func(c *gin.Context) {
  19. panic(errors.New("CSRF token mismatch"))
  20. }
  21. var defaultTokenGetter = func(c *gin.Context) string {
  22. r := c.Request
  23. if t := r.FormValue("_csrf"); len(t) > 0 {
  24. return t
  25. } else if t := r.URL.Query().Get("_csrf"); len(t) > 0 {
  26. return t
  27. } else if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 {
  28. return t
  29. } else if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 {
  30. return t
  31. }
  32. return ""
  33. }
  34. // Options stores configurations for a CSRF middleware.
  35. type Options struct {
  36. Secret string
  37. IgnoreMethods []string
  38. ErrorFunc gin.HandlerFunc
  39. TokenGetter func(c *gin.Context) string
  40. }
  41. func tokenize(secret, salt string) string {
  42. h := sha1.New()
  43. io.WriteString(h, salt+"-"+secret)
  44. hash := base64.URLEncoding.EncodeToString(h.Sum(nil))
  45. return hash
  46. }
  47. func inArray(arr []string, value string) bool {
  48. inarr := false
  49. for _, v := range arr {
  50. if v == value {
  51. inarr = true
  52. break
  53. }
  54. }
  55. return inarr
  56. }
  57. // Middleware validates CSRF token.
  58. func Middleware(options Options) gin.HandlerFunc {
  59. ignoreMethods := options.IgnoreMethods
  60. errorFunc := options.ErrorFunc
  61. tokenGetter := options.TokenGetter
  62. if ignoreMethods == nil {
  63. ignoreMethods = defaultIgnoreMethods
  64. }
  65. if errorFunc == nil {
  66. errorFunc = defaultErrorFunc
  67. }
  68. if tokenGetter == nil {
  69. tokenGetter = defaultTokenGetter
  70. }
  71. return func(c *gin.Context) {
  72. session := sessions.Default(c)
  73. c.Set(csrfSecret, options.Secret)
  74. if inArray(ignoreMethods, c.Request.Method) {
  75. c.Next()
  76. return
  77. }
  78. salt, ok := session.Get(csrfSalt).(string)
  79. if !ok || len(salt) == 0 {
  80. errorFunc(c)
  81. return
  82. }
  83. token := tokenGetter(c)
  84. if tokenize(options.Secret, salt) != token {
  85. errorFunc(c)
  86. return
  87. }
  88. c.Next()
  89. }
  90. }
  91. // GetToken returns a CSRF token.
  92. func GetToken(c *gin.Context) string {
  93. session := sessions.Default(c)
  94. secret := c.MustGet(csrfSecret).(string)
  95. if t, ok := c.Get(csrfToken); ok {
  96. return t.(string)
  97. }
  98. salt, ok := session.Get(csrfSalt).(string)
  99. if !ok {
  100. salt = uniuri.New()
  101. session.Set(csrfSalt, salt)
  102. session.Save()
  103. }
  104. token := tokenize(secret, salt)
  105. c.Set(csrfToken, token)
  106. return token
  107. }