diff --git a/go.mod b/go.mod index 4048e98..be12bb4 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,42 @@ module gitlab.batiao8.com/open/gosdk go 1.18 require ( + github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 + github.com/gin-gonic/gin v1.9.1 + github.com/gomodule/redigo v1.8.9 github.com/influxdata/influxdb1-client v0.0.0-20220302092344-a9ab5670611c + github.com/sirupsen/logrus v1.9.3 + github.com/smbrave/goutil v0.0.0-20240105134047-64fe0dfafba2 github.com/spf13/cast v1.5.0 + github.com/wechatpay-apiv3/wechatpay-go v0.2.18 + golang.org/x/crypto v0.9.0 + gorm.io/gorm v1.25.5 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/qyweixin/app.go b/qyweixin/app.go new file mode 100644 index 0000000..40fe37b --- /dev/null +++ b/qyweixin/app.go @@ -0,0 +1,206 @@ +package qyweixin + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/spf13/cast" + "gitlab.batiao8.com/open/gosdk/util" + "gitlab.batiao8.com/open/gosdk/wechat" + "gitlab.batiao8.com/open/gosdk/wechat/cache" + "gitlab.batiao8.com/open/gosdk/wechat/message" + wutil "gitlab.batiao8.com/open/gosdk/wechat/util" + "net/http" + "strings" + "time" +) + +var ( + wechatCache cache.Cache = cache.NewMemory() +) + +type AppConfig struct { + Corpid string + Secret string + Agent string + Token string + AesKey string + Replay func(message.MixMessage) *message.Reply +} + +type App struct { + tokenExpire int64 + token string + config *AppConfig +} + +type UserInfo struct { + UserId string + RealName string +} + +func NewApp(corpId, secret, agent string) *App { + return &App{ + config: &AppConfig{ + Corpid: corpId, + Secret: secret, + Agent: agent, + }, + } +} + +func (q *App) GetToken() string { + if time.Now().Unix() <= q.tokenExpire-600 { + return q.token + } + q.refreshToken() + return q.token +} + +func (q *App) GetResult(rspBody []byte) (map[string]interface{}, error) { + result := make(map[string]interface{}) + if err := json.Unmarshal(rspBody, &result); err != nil { + log.Errorf("result[%s] error :%s", string(rspBody), err.Error()) + return nil, err + } + if cast.ToInt(result["errcode"]) != 0 { + log.Errorf("result[%s] error ", string(rspBody)) + return nil, fmt.Errorf("%d:%s", cast.ToInt(result["errcode"]), cast.ToString(result["errmsg"])) + } + return result, nil +} + +func (q *App) GetOpenid(userid string) (string, error) { + if err := q.refreshToken(); err != nil { + return "", err + } + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/user/convert_to_openid?access_token=%s", q.GetToken()) + rspBody, err := util.HttpPostJson(reqUrl, nil, []byte(fmt.Sprintf(`{"userid" : "%s"}`, userid))) + if err != nil { + log.Errorf("httpPost url[%s] error :%s", reqUrl, err.Error()) + return "", err + } + result, err := q.GetResult(rspBody) + if err != nil { + return "", err + } + + return cast.ToString(result["openid"]), nil +} + +func (q *App) GetUserInfo(userid string) (*UserInfo, error) { + if err := q.refreshToken(); err != nil { + return nil, err + } + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/user/get?access_token=%s&userid=%s", q.GetToken(), userid) + rspBody, err := util.HttpGet(reqUrl, nil) + if err != nil { + log.Errorf("httpPost url[%s] error :%s", reqUrl, err.Error()) + return nil, err + } + result, err := q.GetResult(rspBody) + if err != nil { + return nil, err + } + + userInfo := new(UserInfo) + userInfo.UserId = userid + userInfo.RealName = cast.ToString(result["name"]) + + return userInfo, nil +} + +func (q *App) GetDepartmentUserid(departmentId int) ([]string, error) { + if err := q.refreshToken(); err != nil { + return nil, err + } + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/user/list?access_token=%s&department_id=%d", q.GetToken(), departmentId) + rspBody, err := util.HttpGet(reqUrl, nil) + if err != nil { + log.Errorf("httpPost url[%s] error :%s", reqUrl, err.Error()) + return nil, err + } + result, err := q.GetResult(rspBody) + if err != nil { + return nil, err + } + + userids := make([]string, 0) + userlist := cast.ToSlice(result["userlist"]) + for _, u := range userlist { + user := cast.ToStringMap(u) + userids = append(userids, cast.ToString(user["userid"])) + } + return userids, nil +} + +func (q *App) Callback(ctx *gin.Context) { + + //配置微信参数 + wechatConfig := &wechat.Config{ + AppID: q.config.Corpid, + AppSecret: q.config.Secret, + Token: q.config.Token, + EncodingAESKey: q.config.AesKey, + Cache: wechatCache, + } + + // 首次配置 + if strings.ToUpper(ctx.Request.Method) == http.MethodGet { + sign := wutil.Signature(ctx.Query("timestamp"), ctx.Query("echostr"), + ctx.Query("nonce"), wechatConfig.Token) + if sign != ctx.Query("msg_signature") { + log.Errorf("sign error forcheck config") + return + } + + _, resp, err := wutil.DecryptMsg(wechatConfig.AppID, ctx.Query("echostr"), wechatConfig.EncodingAESKey) + if err != nil { + log.Errorf("DecryptMsg failed! error:%s", err.Error()) + return + } + ctx.Data(http.StatusOK, "Content-type: text/plain", resp) + return + } + + // 2.响应消息 + wc := wechat.NewWechat(wechatConfig) + ctx.Request.URL.RawQuery += "&encrypt_type=aes" + server := wc.GetServer(ctx.Request, ctx.Writer) + + server.SetMessageHandler(q.config.Replay) + + server.SetDebug(true) + err := server.Serve() + if err != nil { + log.Errorf("qiye weixin Service err:%s", err.Error()) + return + } + err = server.Send() + if err != nil { + log.Errorf("qiye weixin Send err:%s", err.Error()) + return + } +} + +func (q *App) refreshToken() error { + if time.Now().Unix() <= q.tokenExpire-600 { + return nil + } + + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", q.config.Corpid, q.config.Secret) + rspBody, err := util.HttpGet(reqUrl, nil) + if err != nil { + + return err + } + result, err := q.GetResult(rspBody) + if err != nil { + return err + } + + q.token = cast.ToString(result["access_token"]) + q.tokenExpire = time.Now().Unix() + cast.ToInt64(result["expires_in"]) + return nil +} diff --git a/qyweixin/app_approve.go b/qyweixin/app_approve.go new file mode 100644 index 0000000..e835138 --- /dev/null +++ b/qyweixin/app_approve.go @@ -0,0 +1,190 @@ +package qyweixin + +import ( + "encoding/json" + "fmt" + log "github.com/sirupsen/logrus" + "github.com/smbrave/goutil" + "github.com/spf13/cast" + "gitlab.batiao8.com/open/gosdk/util" +) + +type Applyer struct { + Userid string `json:"userid"` + Partyid string `json:"partyid"` +} + +type Option struct { + Key string `json:"key"` + Value []struct { + Text string `json:"text"` + Lang string `json:"lang"` + } `json:"value"` +} + +type Selector struct { + Type string `json:"type"` + Options []*Option `json:"options"` +} + +type Vacation struct { + Selector *Selector `json:"selector"` + Attendance struct { + DateRange struct { + NewBegin int64 `json:"new_begin"` + NewEnd int64 `json:"new_end"` + NewDuration int64 `json:"new_duration"` + Type string `json:"type"` + } `json:"date_range"` + } `json:"attendance"` +} + +type ApplyValue struct { + Text string `json:"text"` + Selector *Selector `json:"selector"` + Children []interface{} `json:"children"` + Date struct { + Type string `json:"type"` + Timestamp string `json:"s_timestamp"` + } `json:"date"` + NewMoney string `json:"new_Money"` + Files []struct { + FileId string `json:"file_id"` + } `json:"files"` + Vacation *Vacation `json:"vacation"` + PunchCorrection interface{} `json:"punch_correction"` +} + +type ApplyContent struct { + Control string `json:"control"` + Title []struct { + Text string `json:"text"` + Lang string `json:"lang"` + } `json:"title"` + Value *ApplyValue `json:"value"` +} + +type ApproveDetail struct { + SpNo string `json:"sp_no"` + SpName string `json:"sp_name"` + SpStatus int `json:"sp_status"` + TemplateID string `json:"template_id"` + ApplyTime int64 `json:"apply_time"` + Applyer *Applyer `json:"applyer"` + ApplyData struct { + Contents []*ApplyContent `json:"contents"` + } `json:"apply_data"` +} + +type ApproveDetailRsp struct { + Errcode int `json:"errcode"` + Errmsg string `json:"errmsg"` + Info *ApproveDetail `json:"info"` +} + +type AppApprove struct { + App +} + +func (d *ApproveDetail) GetValue(title string) string { + + for _, content := range d.ApplyData.Contents { + key := content.Title[0].Text + if key != title { + continue + } + + var value string + if content.Control == "Selector" { + value = content.Value.Selector.Options[0].Value[0].Text + } else if content.Control == "Text" || content.Control == "Textarea" { + value = content.Value.Text + } else if content.Control == "Date" { + value = content.Value.Date.Timestamp + } else if content.Control == "Money" { + value = content.Value.NewMoney + } else if content.Control == "File" { + value = content.Value.Files[0].FileId + } else if content.Control == "Vacation" { //请假 : 请假类型,开始时间,结束时间,请假时长 + tp := content.Value.Vacation.Selector.Options[0].Value[0].Text + value = tp + "," + cast.ToString(content.Value.Vacation.Attendance.DateRange.NewBegin) + + "," + cast.ToString(content.Value.Vacation.Attendance.DateRange.NewEnd) + + "," + cast.ToString(content.Value.Vacation.Attendance.DateRange.NewDuration) + } else if content.Control == "PunchCorrection" { //补卡:日期,时间,状态 + mp := cast.ToStringMap(content.Value.PunchCorrection) + ddate := cast.ToString(mp["daymonthyear"]) + dtime := cast.ToString(mp["time"]) + if ddate == "" { + ddate = dtime + } + value = ddate + "," + dtime + "," + cast.ToString(mp["state"]) + } + return value + } + return "" +} + +func (d *ApproveDetail) String() string { + return goutil.EncodeJSONIndent(d) +} + +func (d *ApproveDetail) GetUserid() string { + return d.Applyer.Userid +} + +func NewAppApprove(corpId, secret, agent string) *AppApprove { + return &AppApprove{ + App: *NewApp(corpId, secret, agent), + } +} + +func (q *AppApprove) GetDetail(spNo string) (*ApproveDetail, error) { + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/oa/getapprovaldetail?access_token=%s", q.GetToken()) + reqParam := fmt.Sprintf(`{"sp_no":"%s"}`, spNo) + rspBody, err := util.HttpPostJson(reqUrl, nil, []byte(reqParam)) + if err != nil { + return nil, err + } + var rsp ApproveDetailRsp + + mp := make(map[string]interface{}) + json.Unmarshal(rspBody, &mp) + //fmt.Println(goutil.EncodeJSONIndent(mp)) + if err := json.Unmarshal(rspBody, &rsp); err != nil { + log.Errorf("get body[%s] json error :%s", string(rspBody), err.Error()) + return nil, err + } + return rsp.Info, nil +} + +func (q *AppApprove) GetList(start, end int64, templateId string) ([]string, error) { + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/oa/getapprovalinfo?access_token=%s", q.GetToken()) + reqParam := make(map[string]interface{}) + reqParam["starttime"] = cast.ToString(start) + reqParam["endtime"] = cast.ToString(end) + reqParam["new_cursor"] = "" + reqParam["size"] = 100 + filters := make([]interface{}, 0) + if templateId != "" { + filters = append(filters, map[string]interface{}{ + "key": "template_id", + "value": templateId, + }) + } + + filters = append(filters, map[string]interface{}{ + "key": "sp_status", + "value": "2", + }) + reqParam["filters"] = filters + rspBody, err := util.HttpPostJson(reqUrl, nil, []byte(goutil.EncodeJSON(reqParam))) + if err != nil { + log.Errorf("httpPost error :%s", err.Error()) + return nil, err + } + result, err := q.GetResult(rspBody) + if err != nil { + return nil, err + } + return cast.ToStringSlice(result["sp_no_list"]), nil +} diff --git a/qyweixin/app_checkin.go b/qyweixin/app_checkin.go new file mode 100644 index 0000000..503f562 --- /dev/null +++ b/qyweixin/app_checkin.go @@ -0,0 +1,150 @@ +package qyweixin + +import ( + "encoding/json" + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "github.com/smbrave/goutil" + "github.com/spf13/cast" + "gitlab.batiao8.com/open/gosdk/util" + "gorm.io/gorm/utils" + "time" +) + +type UserCheckIn struct { + Day string + Month string + UserId string + Exception string + Rawdata string + StartTime int64 + EndTime int64 +} + +func (u *UserCheckIn) String() string { + return fmt.Sprintf("[%s][%s][%s~%s][%s]", u.UserId, u.Day, + goutil.TimeToDateTime(u.StartTime), goutil.TimeToDateTime(u.EndTime), u.Exception) +} + +type AppCheckin struct { + App +} + +func NewAppCheckin(corpId, secret, agent string) *AppCheckin { + return &AppCheckin{ + App: *NewApp(corpId, secret, agent), + } +} + +func (q *AppCheckin) GetCheckinEmployee(groupIds []string) ([]string, error) { + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/checkin/getcorpcheckinoption?access_token=%s", q.GetToken()) + rspBody, err := util.HttpPostJson(reqUrl, nil, []byte("{}")) + if err != nil { + return nil, err + } + result, err := q.GetResult(rspBody) + if err != nil { + log.Errorf("q.GetResult error :%s ", err.Error()) + return nil, errors.New(string(rspBody)) + } + + resultUser := make([]string, 0) + groups := cast.ToSlice(result["group"]) + for _, group := range groups { + g := cast.ToStringMap(group) + if len(groupIds) != 0 && !utils.Contains(groupIds, cast.ToString(g["groupid"])) { + continue + } + ranges := cast.ToStringMap(g["range"]) + userid := cast.ToStringSlice(ranges["userid"]) + + //包含部门获取部门下的员工 + departmentIds := cast.ToIntSlice(ranges["party_id"]) + if len(departmentIds) != 0 { + for _, did := range departmentIds { + uids, err := q.GetDepartmentUserid(did) + if err != nil { + log.Errorf(" q.GetDepartmentUserid did[%d] error :%s", did, err.Error()) + continue + } + resultUser = append(resultUser, uids...) + } + } + resultUser = append(resultUser, userid...) + } + return resultUser, nil +} + +func (q *AppCheckin) GetCheckinData(startDay, endDay string, userIds []string) ([]*UserCheckIn, error) { + + dayTime, _ := time.ParseInLocation("2006-01-02", startDay, time.Local) + endTime, _ := time.ParseInLocation("2006-01-02", endDay, time.Local) + + reqData := make(map[string]interface{}) + reqData["opencheckindatatype"] = 1 + reqData["starttime"] = dayTime.Unix() + reqData["endtime"] = endTime.Unix() + 86400 + reqData["useridlist"] = userIds + reqUrl := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/checkin/getcheckindata?access_token=%s", q.GetToken()) + rspBody, err := util.HttpPostJson(reqUrl, nil, []byte(goutil.EncodeJSON(reqData))) + if err != nil { + return nil, err + } + result := make(map[string]interface{}) + if err := json.Unmarshal(rspBody, &result); err != nil { + log.Errorf("http url[%s] result[%s] error :%s", reqUrl, string(rspBody), err.Error()) + return nil, err + } + if cast.ToInt(result["errcode"]) != 0 { + log.Errorf("http url[%s] result[%s] error ", reqUrl, string(rspBody)) + return nil, errors.New(string(rspBody)) + } + + datas := cast.ToSlice(result["checkindata"]) + checkData := make(map[string]*UserCheckIn) + for _, dat := range datas { + c := cast.ToStringMap(dat) + exceptionType := cast.ToString(c["exception_type"]) + checkinType := cast.ToString(c["checkin_type"]) + userid := cast.ToString(c["userid"]) + checkinTime := cast.ToInt64(c["checkin_time"]) + schCheckinTime := cast.ToInt64(c["sch_checkin_time"]) + if schCheckinTime == 0 { + schCheckinTime = checkinTime + } + checkDay := time.Unix(schCheckinTime, 0).Format("2006-01-02") + checkMonth := time.Unix(schCheckinTime, 0).Format("200601") + key := fmt.Sprintf("%s_%s", userid, checkDay) + var userData *UserCheckIn = nil + var ok bool + if userData, ok = checkData[key]; !ok { + userData = new(UserCheckIn) + userData.UserId = userid + userData.Day = checkDay + userData.Month = checkMonth + checkData[key] = userData + } + if exceptionType != "" { + userData.Exception += goutil.If(userData.Exception != "", ",", "") + userData.Exception += checkinType + ":" + exceptionType + } + userData.Rawdata = goutil.If(userData.Rawdata == "", "", "\n") + goutil.EncodeJSON(dat) + if checkinType == "上班打卡" { + userData.StartTime = goutil.If(userData.StartTime == 0 || checkinTime < userData.StartTime, checkinTime, userData.StartTime) + } else if checkinType == "下班打卡" { + userData.EndTime = goutil.If(checkinTime > userData.EndTime, checkinTime, userData.EndTime) + } else { + log.Errorf("不支持的打卡类型:%s %s", checkinType, goutil.EncodeJSON(dat)) + } + } + + userDatas := make([]*UserCheckIn, 0) + for _, v := range checkData { + if v.StartTime > v.EndTime { + log.Errorf("user[%s][%s] checkin time error[%s~%s]", v.UserId, v.Day, goutil.TimeToDateTime(v.StartTime), goutil.TimeToDateTime(v.EndTime)) + } + userDatas = append(userDatas, v) + } + return userDatas, nil +} diff --git a/qyweixin/app_hr.go b/qyweixin/app_hr.go new file mode 100644 index 0000000..a7c9839 --- /dev/null +++ b/qyweixin/app_hr.go @@ -0,0 +1,115 @@ +package qyweixin + +import ( + "fmt" + log "github.com/sirupsen/logrus" + "github.com/smbrave/goutil" + "github.com/spf13/cast" + "gitlab.batiao8.com/open/gosdk/util" + "time" +) + +var ( + urlQyWeixinHrGetAllField = "https://qyapi.weixin.qq.com/cgi-bin/hr/get_fields" + urlQyWeixinHrGetStaffInfo = "https://qyapi.weixin.qq.com/cgi-bin/hr/get_staff_info" +) + +type AppHr struct { + App + config *AppConfig +} + +type StaffInfo struct { + UserName string + RealName string + Phone string + StaffType string + Idno string + Salary float64 + Stock float64 + EntryDate string + BirthDate string + OfficialDate string + BankName string + BankCard string +} + +func NewAppHr(corpId, secret, agent string) *AppHr { + return &AppHr{ + App: *NewApp(corpId, secret, secret), + } +} + +func (h *AppHr) GetStaffInfo(userId string) (*StaffInfo, error) { + reqUrl := fmt.Sprintf("%s?access_token=%s", urlQyWeixinHrGetStaffInfo, h.GetToken()) + reqBody := make(map[string]interface{}) + reqBody["userid"] = userId + reqBody["get_all"] = true + rspBody, err := util.HttpPostJson(reqUrl, nil, []byte(goutil.EncodeJSON(reqBody))) + if err != nil { + return nil, err + } + staff := new(StaffInfo) + result, err := h.GetResult(rspBody) + if err != nil { + return nil, err + } + + fieldMap := make(map[string]map[string]interface{}) + for _, fieldInfo := range cast.ToSlice(result["field_info"]) { + fi := cast.ToStringMap(fieldInfo) + fieldMap[cast.ToString(fi["fieldid"])] = fi + } + + userInfo, err := h.GetUserInfo(userId) + if err != nil { + log.Errorf("GetUserInfo error:%s", err.Error()) + return nil, err + } + staff.UserName = userId + staff.RealName = userInfo.RealName + staff.Salary = cast.ToFloat64(h.getFieldValue(fieldMap["20001"])) + staff.Stock = cast.ToFloat64(h.getFieldValue(fieldMap["20002"])) + staff.Phone = cast.ToString(h.getFieldValue(fieldMap["17003"])) + staff.StaffType = cast.ToString(h.getFieldValue(fieldMap["12003"])) + staff.Idno = cast.ToString(h.getFieldValue(fieldMap["11015"])) + staff.BankName = cast.ToString(h.getFieldValue(fieldMap["13001"])) + staff.BankCard = cast.ToString(h.getFieldValue(fieldMap["13002"])) + staff.EntryDate = time.Unix(cast.ToInt64(h.getFieldValue(fieldMap["12018"])), 0).Format("2006-01-02") + staff.BirthDate = time.Unix(cast.ToInt64(h.getFieldValue(fieldMap["11005"])), 0).Format("2006-01-02") + staff.OfficialDate = time.Unix(cast.ToInt64(h.getFieldValue(fieldMap["12023"])), 0).Format("2006-01-02") + + //fmt.Println(goutil.EncodeJSON(staff)) + return staff, nil +} + +func (h *AppHr) getFieldValue(fieldInfo map[string]interface{}) string { + valueType := cast.ToInt(fieldInfo["value_type"]) + if valueType == 1 { + return cast.ToString(fieldInfo["value_string"]) + } else if valueType == 2 { + return cast.ToString(fieldInfo["value_uint64"]) + } else if valueType == 3 { + return cast.ToString(fieldInfo["value_uint32"]) + } else if valueType == 4 { + return cast.ToString(fieldInfo["value_int64"]) + } else if valueType == 5 { + moble := cast.ToStringMap(fieldInfo["value_mobile"]) + return cast.ToString(moble["value_mobile"]) + } + return "" +} + +func (h *AppHr) GetAllField() ([]byte, error) { + reqUrl := fmt.Sprintf("%s?access_token=%s", urlQyWeixinHrGetAllField, h.GetToken()) + rspBody, err := util.HttpGet(reqUrl, nil) + if err != nil { + return nil, err + } + result, err := h.GetResult(rspBody) + if err != nil { + return nil, err + } + fmt.Println(goutil.EncodeJSONIndent(result)) + return rspBody, err +} diff --git a/qyweixin/app_pay.go b/qyweixin/app_pay.go new file mode 100644 index 0000000..52157e8 --- /dev/null +++ b/qyweixin/app_pay.go @@ -0,0 +1,173 @@ +package qyweixin + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "github.com/smbrave/goutil" + "github.com/spf13/cast" + "github.com/wechatpay-apiv3/wechatpay-go/core" + "github.com/wechatpay-apiv3/wechatpay-go/core/option" + "github.com/wechatpay-apiv3/wechatpay-go/utils" + butil "gitlab.batiao8.com/open/gosdk/util" + "io/ioutil" + "net/http" + "os" + "time" +) + +type PayReq struct { + TotalAmount int64 + Title string + Userid string + BillNo string +} + +type PayConfig struct { + Corpid string + Secret string + Agent string + CertPem string + KeyPem string + MchId string + ApiKey string + SerialNumber string +} + +type AppPay struct { + App + cli *core.Client + tlsClient *http.Client + stdClient *http.Client + config *PayConfig +} + +func NewAppPay(cfg *PayConfig) *AppPay { + payCertPem, _ := os.ReadFile(cfg.CertPem) + payKeyPem, _ := os.ReadFile(cfg.KeyPem) + + pay := &AppPay{} + client, err := pay.withCert(payCertPem, payKeyPem) + if err != nil { + log.Errorf("with cert error :%s", err.Error()) + return nil + } + pay.tlsClient = client + + pay.App = *NewApp(cfg.Corpid, cfg.Secret, cfg.Agent) + pay.config = cfg + return pay +} + +// 附着商户证书 +func (c *AppPay) withCert(cert, key []byte) (*http.Client, error) { + tlsCert, err := tls.X509KeyPair(cert, key) + if err != nil { + return nil, err + } + + conf := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + } + trans := &http.Transport{ + TLSClientConfig: conf, + } + + return &http.Client{ + Transport: trans, + }, nil + +} + +// 发送请求 +func (c *AppPay) post(url string, params butil.Params, tls bool) ([]byte, error) { + var httpc *http.Client + if tls { + if c.tlsClient == nil { + return nil, errors.New("tls wepay is not initialized") + } + httpc = c.tlsClient + } else { + httpc = c.stdClient + } + buf := bytes.NewBuffer(params.Encode()) + resp, err := httpc.Post(url, "application/xml; charset=utf-8", buf) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return body, nil +} + +func (w *AppPay) initPay() error { + // 使用 utils 提供的函数从本地文件中加载商户私钥,商户私钥会用来生成请求的签名 + mchPrivateKey, err := utils.LoadPrivateKeyWithPath(w.config.KeyPem) + if err != nil { + return err + } + ctx := context.Background() + // 使用商户私钥等初始化 client,并使它具有自动定时获取微信支付平台证书的能力 + opts := []core.ClientOption{ + option.WithWechatPayAutoAuthCipher(w.config.MchId, w.config.SerialNumber, mchPrivateKey, w.config.ApiKey), + } + client, err := core.NewClient(ctx, opts...) + if err != nil { + return err + } + w.cli = client + return nil +} + +func (p *AppPay) PayMoney(req *PayReq) error { + if p.cli == nil { + if err := p.initPay(); err != nil { + return err + } + } + + param := butil.NewParams() + userOpenid, err := p.GetOpenid(req.Userid) + if err != nil { + log.Errorf("get openid error :%s", err.Error()) + return err + } + if req.BillNo == "" { + req.BillNo = fmt.Sprintf("QY%s%s", time.Now().Format("20060102150405"), butil.CutTail(req.Userid, 12)) + } + param.Set("nonce_str", goutil.RandomStr(32)) + param.Set("mch_billno", req.BillNo) + param.Set("mch_id", p.config.MchId) + param.Set("wxappid", p.config.Corpid) + param.Set("agentid", p.config.Agent) + param.Set("re_openid", userOpenid) + param.Set("total_amount", cast.ToString(req.TotalAmount)) + param.Set("wishing", req.Title) + param.Set("act_name", "企业红包") + param.Set("remark", "企业红包") + param.Set("workwx_sign", param.QySignMd5(p.config.Secret)) + param.Set("sign", param.SignMd5(p.config.ApiKey)) + + reqUrl := "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendworkwxredpack" + rspBody, err := p.post(reqUrl, param, true) + if err != nil { + log.Errorf("Post [%s] [%s] error :%s", reqUrl, string(param.Encode()), err.Error()) + return err + } + + respParam := butil.NewParams() + respParam.Decode(rspBody) + + returnCode := respParam.GetString("return_code") + resultCoce := respParam.GetString("result_code") + if resultCoce == "SUCCESS" && returnCode == "SUCCESS" { + return nil + } + return errors.New(string(respParam.Encode())) +} diff --git a/util/params.go b/util/params.go new file mode 100644 index 0000000..a3c0cee --- /dev/null +++ b/util/params.go @@ -0,0 +1,169 @@ +package util + +import ( + "bytes" + "crypto/hmac" + "crypto/md5" + "crypto/sha256" + "encoding/hex" + "encoding/xml" + "fmt" + "io" + "sort" + "strconv" + "strings" +) + +type Params map[string]string + +func NewParams() Params { + return make(Params) +} + +func (p Params) Set(k string, v interface{}) { + p[k] = fmt.Sprintf("%v", v) +} + +func (p Params) GetString(k string) string { + s, _ := p[k] + return s +} + +func (p Params) GetUint64(k string) uint64 { + s, _ := strconv.ParseUint(p[k], 10, 64) + return s +} + +func (p Params) GetInt64(k string) int64 { + i, _ := strconv.ParseInt(p.GetString(k), 10, 64) + return i +} + +func (p Params) GetInt(k string) int64 { + i, _ := strconv.ParseInt(p.GetString(k), 10, 64) + return i +} + +func (p Params) GetFloat64(k string) float64 { + f, _ := strconv.ParseFloat(p.GetString(k), 64) + return f +} +func (p Params) GetBool(k string) bool { + b, _ := strconv.ParseBool(p.GetString(k)) + return b +} + +// XML解码 +func (p Params) Decode(body []byte) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + } + }() + + var ( + d *xml.Decoder + start *xml.StartElement + ) + buf := bytes.NewBuffer(body) + d = xml.NewDecoder(buf) + for { + tok, err := d.Token() + if err != nil { + break + } + switch t := tok.(type) { + case xml.StartElement: + start = &t + case xml.CharData: + if t = bytes.TrimSpace(t); len(t) > 0 { + p.Set(start.Name.Local, string(t)) + } + } + } + return nil +} + +// XML编码 +func (p Params) Encode() []byte { + var buf bytes.Buffer + buf.WriteString(``) + for k, v := range p { + buf.WriteString(`<`) + buf.WriteString(k) + buf.WriteString(`>`) + } + buf.WriteString(``) + return buf.Bytes() +} + +// 验证签名 +func (p Params) CheckSign(key string) bool { + return p.GetString("sign") == p.SignMd5(key) +} + +// 生成签名 +func (p Params) getSignStr(key string) string { + var keys = make([]string, 0, len(p)) + for k, _ := range p { + if k != "sign" { + keys = append(keys, k) + } + } + sort.Strings(keys) + + var buf bytes.Buffer + for _, k := range keys { + if len(p.GetString(k)) > 0 { + buf.WriteString(k) + buf.WriteString(`=`) + buf.WriteString(p.GetString(k)) + buf.WriteString(`&`) + } + } + buf.WriteString(`key=`) + buf.WriteString(key) + return buf.String() +} + +func (p Params) getQySignStr(secret string) string { + keys := []string{"act_name", "mch_billno", "mch_id", "nonce_str", "re_openid", "total_amount", "wxappid"} + sort.Strings(keys) + + var buf bytes.Buffer + for _, k := range keys { + buf.WriteString(k) + buf.WriteString(`=`) + buf.WriteString(p.GetString(k)) + buf.WriteString(`&`) + } + buf.WriteString(`secret=`) + buf.WriteString(secret) + return buf.String() +} + +// 生成签名 +func (p Params) QySignMd5(secret string) string { + sum := md5.Sum([]byte(p.getQySignStr(secret))) + str := hex.EncodeToString(sum[:]) + return strings.ToUpper(str) +} + +// 生成签名 +func (p Params) SignMd5(key string) string { + sum := md5.Sum([]byte(p.getSignStr(key))) + str := hex.EncodeToString(sum[:]) + return strings.ToUpper(str) +} + +// 生成签名 +func (p Params) SignHmacSha256(key string) string { + h := hmac.New(sha256.New, []byte(key)) + io.WriteString(h, p.getSignStr(key)) + str := hex.EncodeToString(h.Sum(nil)) + return strings.ToUpper(str) +} diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..700d0c3 --- /dev/null +++ b/util/util.go @@ -0,0 +1,18 @@ +package util + +import ( + "fmt" + "github.com/spf13/cast" +) + +func CutTail(str string, length int) string { + if len(str) <= length { + return str + } + + return str[0:length] +} + +func FloatCut(f float64) float64 { + return cast.ToFloat64(fmt.Sprintf("%.2f", f)) +} diff --git a/wechat/cache/cache.go b/wechat/cache/cache.go new file mode 100644 index 0000000..f3feb84 --- /dev/null +++ b/wechat/cache/cache.go @@ -0,0 +1,11 @@ +package cache + +import "time" + +// Cache interface +type Cache interface { + Get(key string) interface{} + Set(key string, val interface{}, timeout time.Duration) error + IsExist(key string) bool + Delete(key string) error +} diff --git a/wechat/cache/memcache.go b/wechat/cache/memcache.go new file mode 100644 index 0000000..8271a9f --- /dev/null +++ b/wechat/cache/memcache.go @@ -0,0 +1,57 @@ +package cache + +import ( + "encoding/json" + "time" + + "github.com/bradfitz/gomemcache/memcache" +) + +// Memcache struct contains *memcache.Client +type Memcache struct { + conn *memcache.Client +} + +// NewMemcache create new memcache +func NewMemcache(server ...string) *Memcache { + mc := memcache.New(server...) + return &Memcache{mc} +} + +// Get return cached value +func (mem *Memcache) Get(key string) interface{} { + var err error + var item *memcache.Item + if item, err = mem.conn.Get(key); err != nil { + return nil + } + var result interface{} + if err = json.Unmarshal(item.Value, &result); err != nil { + return nil + } + return result +} + +// IsExist check value exists in memcache. +func (mem *Memcache) IsExist(key string) bool { + if _, err := mem.conn.Get(key); err != nil { + return false + } + return true +} + +// Set cached value with key and expire time. +func (mem *Memcache) Set(key string, val interface{}, timeout time.Duration) (err error) { + var data []byte + if data, err = json.Marshal(val); err != nil { + return err + } + + item := &memcache.Item{Key: key, Value: data, Expiration: int32(timeout / time.Second)} + return mem.conn.Set(item) +} + +// Delete delete value in memcache. +func (mem *Memcache) Delete(key string) error { + return mem.conn.Delete(key) +} diff --git a/wechat/cache/memcache_test.go b/wechat/cache/memcache_test.go new file mode 100644 index 0000000..6b4ea02 --- /dev/null +++ b/wechat/cache/memcache_test.go @@ -0,0 +1,28 @@ +package cache + +import ( + "testing" + "time" +) + +func TestMemcache(t *testing.T) { + mem := NewMemcache("127.0.0.1:11211") + var err error + timeoutDuration := 10 * time.Second + if err = mem.Set("username", "silenceper", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if !mem.IsExist("username") { + t.Error("IsExist Error") + } + + name := mem.Get("username").(string) + if name != "silenceper" { + t.Error("get Error") + } + + if err = mem.Delete("username"); err != nil { + t.Errorf("delete Error , err=%v", err) + } +} diff --git a/wechat/cache/memory.go b/wechat/cache/memory.go new file mode 100644 index 0000000..dad34eb --- /dev/null +++ b/wechat/cache/memory.go @@ -0,0 +1,74 @@ +package cache + +import ( + "sync" + "time" +) + +// Memory struct contains *memcache.Client +type Memory struct { + sync.Mutex + + data map[string]*data +} + +type data struct { + Data interface{} + Expired time.Time +} + +// NewMemory create new memcache +func NewMemory() *Memory { + return &Memory{ + data: map[string]*data{}, + } +} + +// Get return cached value +func (mem *Memory) Get(key string) interface{} { + if ret, ok := mem.data[key]; ok { + if ret.Expired.Before(time.Now()) { + mem.deleteKey(key) + return nil + } + return ret.Data + } + return nil +} + +// IsExist check value exists in memcache. +func (mem *Memory) IsExist(key string) bool { + if ret, ok := mem.data[key]; ok { + if ret.Expired.Before(time.Now()) { + mem.deleteKey(key) + return false + } + return true + } + return false +} + +// Set cached value with key and expire time. +func (mem *Memory) Set(key string, val interface{}, timeout time.Duration) (err error) { + mem.Lock() + defer mem.Unlock() + + mem.data[key] = &data{ + Data: val, + Expired: time.Now().Add(timeout), + } + return nil +} + +// Delete delete value in memcache. +func (mem *Memory) Delete(key string) error { + return mem.deleteKey(key) +} + +// deleteKey +func (mem *Memory) deleteKey(key string) error { + mem.Lock() + defer mem.Unlock() + delete(mem.data, key) + return nil +} diff --git a/wechat/cache/redis.go b/wechat/cache/redis.go new file mode 100644 index 0000000..dc4f778 --- /dev/null +++ b/wechat/cache/redis.go @@ -0,0 +1,109 @@ +package cache + +import ( + "encoding/json" + "time" + + "github.com/gomodule/redigo/redis" +) + +// Redis redis cache +type Redis struct { + conn *redis.Pool +} + +// RedisOpts redis 连接属性 +type RedisOpts struct { + Host string `yml:"host" json:"host"` + Password string `yml:"password" json:"password"` + Database int `yml:"database" json:"database"` + MaxIdle int `yml:"max_idle" json:"max_idle"` + MaxActive int `yml:"max_active" json:"max_active"` + IdleTimeout int32 `yml:"idle_timeout" json:"idle_timeout"` //second +} + +// NewRedis 实例化 +func NewRedis(opts *RedisOpts) *Redis { + pool := &redis.Pool{ + MaxActive: opts.MaxActive, + MaxIdle: opts.MaxIdle, + IdleTimeout: time.Second * time.Duration(opts.IdleTimeout), + Dial: func() (redis.Conn, error) { + return redis.Dial("tcp", opts.Host, + redis.DialDatabase(opts.Database), + redis.DialPassword(opts.Password), + ) + }, + TestOnBorrow: func(conn redis.Conn, t time.Time) error { + if time.Since(t) < time.Minute { + return nil + } + _, err := conn.Do("PING") + return err + }, + } + return &Redis{pool} +} + +// SetConn 设置conn +func (r *Redis) SetConn(conn *redis.Pool) { + r.conn = conn +} + +// Get 获取一个值 +func (r *Redis) Get(key string) interface{} { + conn := r.conn.Get() + defer conn.Close() + + var data []byte + var err error + if data, err = redis.Bytes(conn.Do("GET", key)); err != nil { + return nil + } + var reply interface{} + if err = json.Unmarshal(data, &reply); err != nil { + return nil + } + + return reply +} + +// Set 设置一个值 +func (r *Redis) Set(key string, val interface{}, timeout time.Duration) (err error) { + conn := r.conn.Get() + defer conn.Close() + + var data []byte + if data, err = json.Marshal(val); err != nil { + return + } + + _, err = conn.Do("SETEX", key, int64(timeout/time.Second), data) + + return +} + +// IsExist 判断key是否存在 +func (r *Redis) IsExist(key string) bool { + conn := r.conn.Get() + defer conn.Close() + + a, _ := conn.Do("EXISTS", key) + i := a.(int64) + if i > 0 { + return true + } + return false +} + +// Delete 删除 +func (r *Redis) Delete(key string) error { + conn := r.conn.Get() + defer conn.Close() + + if _, err := conn.Do("DEL", key); err != nil { + return err + } + + return nil +} diff --git a/wechat/cache/redis_test.go b/wechat/cache/redis_test.go new file mode 100644 index 0000000..3ced1d3 --- /dev/null +++ b/wechat/cache/redis_test.go @@ -0,0 +1,32 @@ +package cache + +import ( + "testing" + "time" +) + +func TestRedis(t *testing.T) { + opts := &RedisOpts{ + Host: "127.0.0.1:6379", + } + redis := NewRedis(opts) + var err error + timeoutDuration := 1 * time.Second + + if err = redis.Set("username", "silenceper", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if !redis.IsExist("username") { + t.Error("IsExist Error") + } + + name := redis.Get("username").(string) + if name != "silenceper" { + t.Error("get Error") + } + + if err = redis.Delete("username"); err != nil { + t.Errorf("delete Error , err=%v", err) + } +} diff --git a/wechat/context/access_token.go b/wechat/context/access_token.go new file mode 100644 index 0000000..308edc4 --- /dev/null +++ b/wechat/context/access_token.go @@ -0,0 +1,85 @@ +package context + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "gitlab.batiao8.com/open/gosdk/wechat/util" +) + +const ( + //AccessTokenURL 获取access_token的接口 + AccessTokenURL = "https://api.weixin.qq.com/cgi-bin/token" +) + +// ResAccessToken struct +type ResAccessToken struct { + util.CommonError + + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +// GetAccessTokenFunc 获取 access token 的函数签名 +type GetAccessTokenFunc func(ctx *Context) (accessToken string, err error) + +// SetAccessTokenLock 设置读写锁(一个appID一个读写锁) +func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) { + ctx.accessTokenLock = l +} + +// SetGetAccessTokenFunc 设置自定义获取accessToken的方式, 需要自己实现缓存 +func (ctx *Context) SetGetAccessTokenFunc(f GetAccessTokenFunc) { + ctx.accessTokenFunc = f +} + +// GetAccessToken 获取access_token +func (ctx *Context) GetAccessToken() (accessToken string, err error) { + ctx.accessTokenLock.Lock() + defer ctx.accessTokenLock.Unlock() + + if ctx.accessTokenFunc != nil { + return ctx.accessTokenFunc(ctx) + } + accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID) + val := ctx.Cache.Get(accessTokenCacheKey) + if val != nil { + accessToken = val.(string) + return + } + + //从微信服务器获取 + var resAccessToken ResAccessToken + resAccessToken, err = ctx.GetAccessTokenFromServer() + if err != nil { + return + } + + accessToken = resAccessToken.AccessToken + return +} + +// GetAccessTokenFromServer 强制从微信服务器获取token +func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, err error) { + url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", AccessTokenURL, ctx.AppID, ctx.AppSecret) + var body []byte + body, err = util.HTTPGet(url) + if err != nil { + return + } + err = json.Unmarshal(body, &resAccessToken) + if err != nil { + return + } + if resAccessToken.ErrMsg != "" { + err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg) + return + } + + accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID) + expires := resAccessToken.ExpiresIn - 1500 + err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second) + return +} diff --git a/wechat/context/access_token_test.go b/wechat/context/access_token_test.go new file mode 100644 index 0000000..fdae218 --- /dev/null +++ b/wechat/context/access_token_test.go @@ -0,0 +1,30 @@ +package context + +import ( + "sync" + "testing" +) + +func TestContext_SetCustomAccessTokenFunc(t *testing.T) { + ctx := Context{ + accessTokenLock: new(sync.RWMutex), + } + f := func(ctx *Context) (accessToken string, err error) { + return "fake token", nil + } + ctx.SetGetAccessTokenFunc(f) + res, err := ctx.GetAccessToken() + if res != "fake token" || err != nil { + t.Error("expect fake token but error") + } +} + +func TestContext_NoSetCustomAccessTokenFunc(t *testing.T) { + ctx := Context{ + accessTokenLock: new(sync.RWMutex), + } + + if ctx.accessTokenFunc != nil { + t.Error("error accessTokenFunc") + } +} diff --git a/wechat/context/component_access_token.go b/wechat/context/component_access_token.go new file mode 100644 index 0000000..1ac131d --- /dev/null +++ b/wechat/context/component_access_token.go @@ -0,0 +1,221 @@ +package context + +import ( + "encoding/json" + "fmt" + "time" + + "gitlab.batiao8.com/open/gosdk/wechat/util" +) + +const ( + componentAccessTokenURL = "https://api.weixin.qq.com/cgi-bin/component/api_component_token" + getPreCodeURL = "https://api.weixin.qq.com/cgi-bin/component/api_create_preauthcode?component_access_token=%s" + queryAuthURL = "https://api.weixin.qq.com/cgi-bin/component/api_query_auth?component_access_token=%s" + refreshTokenURL = "https://api.weixin.qq.com/cgi-bin/component/api_authorizer_token?component_access_token=%s" + getComponentInfoURL = "https://api.weixin.qq.com/cgi-bin/component/api_get_authorizer_info?component_access_token=%s" + getComponentConfigURL = "https://api.weixin.qq.com/cgi-bin/component/api_get_authorizer_option?component_access_token=%s" +) + +// ComponentAccessToken 第三方平台 +type ComponentAccessToken struct { + AccessToken string `json:"component_access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +// GetComponentAccessToken 获取 ComponentAccessToken +func (ctx *Context) GetComponentAccessToken() (string, error) { + accessTokenCacheKey := fmt.Sprintf("component_access_token_%s", ctx.AppID) + val := ctx.Cache.Get(accessTokenCacheKey) + if val == nil { + return "", fmt.Errorf("cann't get component access token") + } + return val.(string), nil +} + +// SetComponentAccessToken 通过component_verify_ticket 获取 ComponentAccessToken +func (ctx *Context) SetComponentAccessToken(verifyTicket string) (*ComponentAccessToken, error) { + body := map[string]string{ + "component_appid": ctx.AppID, + "component_appsecret": ctx.AppSecret, + "component_verify_ticket": verifyTicket, + } + respBody, err := util.PostJSON(componentAccessTokenURL, body) + if err != nil { + return nil, err + } + + at := &ComponentAccessToken{} + if err := json.Unmarshal(respBody, at); err != nil { + return nil, err + } + + accessTokenCacheKey := fmt.Sprintf("component_access_token_%s", ctx.AppID) + expires := at.ExpiresIn - 1500 + ctx.Cache.Set(accessTokenCacheKey, at.AccessToken, time.Duration(expires)*time.Second) + return at, nil +} + +// GetPreCode 获取预授权码 +func (ctx *Context) GetPreCode() (string, error) { + cat, err := ctx.GetComponentAccessToken() + if err != nil { + return "", err + } + req := map[string]string{ + "component_appid": ctx.AppID, + } + uri := fmt.Sprintf(getPreCodeURL, cat) + body, err := util.PostJSON(uri, req) + if err != nil { + return "", err + } + + var ret struct { + PreCode string `json:"pre_auth_code"` + } + if err := json.Unmarshal(body, &ret); err != nil { + return "", err + } + + return ret.PreCode, nil +} + +// ID 微信返回接口中各种类型字段 +type ID struct { + ID int `json:"id"` +} + +// AuthBaseInfo 授权的基本信息 +type AuthBaseInfo struct { + AuthrAccessToken + FuncInfo []AuthFuncInfo `json:"func_info"` +} + +// AuthFuncInfo 授权的接口内容 +type AuthFuncInfo struct { + FuncscopeCategory ID `json:"funcscope_category"` +} + +// AuthrAccessToken 授权方AccessToken +type AuthrAccessToken struct { + Appid string `json:"authorizer_appid"` + AccessToken string `json:"authorizer_access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"authorizer_refresh_token"` +} + +// QueryAuthCode 使用授权码换取公众号或小程序的接口调用凭据和授权信息 +func (ctx *Context) QueryAuthCode(authCode string) (*AuthBaseInfo, error) { + cat, err := ctx.GetComponentAccessToken() + if err != nil { + return nil, err + } + + req := map[string]string{ + "component_appid": ctx.AppID, + "authorization_code": authCode, + } + uri := fmt.Sprintf(queryAuthURL, cat) + body, err := util.PostJSON(uri, req) + if err != nil { + return nil, err + } + + var ret struct { + Info *AuthBaseInfo `json:"authorization_info"` + } + + if err := json.Unmarshal(body, &ret); err != nil { + return nil, err + } + + return ret.Info, nil +} + +// RefreshAuthrToken 获取(刷新)授权公众号或小程序的接口调用凭据(令牌) +func (ctx *Context) RefreshAuthrToken(appid, refreshToken string) (*AuthrAccessToken, error) { + cat, err := ctx.GetComponentAccessToken() + if err != nil { + return nil, err + } + + req := map[string]string{ + "component_appid": ctx.AppID, + "authorizer_appid": appid, + "authorizer_refresh_token": refreshToken, + } + uri := fmt.Sprintf(refreshTokenURL, cat) + body, err := util.PostJSON(uri, req) + if err != nil { + return nil, err + } + + ret := &AuthrAccessToken{} + if err := json.Unmarshal(body, ret); err != nil { + return nil, err + } + + authrTokenKey := "authorizer_access_token_" + appid + ctx.Cache.Set(authrTokenKey, ret.AccessToken, time.Minute*80) + + return ret, nil +} + +// GetAuthrAccessToken 获取授权方AccessToken +func (ctx *Context) GetAuthrAccessToken(appid string) (string, error) { + authrTokenKey := "authorizer_access_token_" + appid + val := ctx.Cache.Get(authrTokenKey) + if val == nil { + return "", fmt.Errorf("cannot get authorizer %s access token", appid) + } + return val.(string), nil +} + +// AuthorizerInfo 授权方详细信息 +type AuthorizerInfo struct { + NickName string `json:"nick_name"` + HeadImg string `json:"head_img"` + ServiceTypeInfo ID `json:"service_type_info"` + VerifyTypeInfo ID `json:"verify_type_info"` + UserName string `json:"user_name"` + PrincipalName string `json:"principal_name"` + BusinessInfo struct { + OpenStore string `json:"open_store"` + OpenScan string `json:"open_scan"` + OpenPay string `json:"open_pay"` + OpenCard string `json:"open_card"` + OpenShake string `json:"open_shake"` + } + Alias string `json:"alias"` + QrcodeURL string `json:"qrcode_url"` +} + +// GetAuthrInfo 获取授权方的帐号基本信息 +func (ctx *Context) GetAuthrInfo(appid string) (*AuthorizerInfo, *AuthBaseInfo, error) { + cat, err := ctx.GetComponentAccessToken() + if err != nil { + return nil, nil, err + } + + req := map[string]string{ + "component_appid": ctx.AppID, + "authorizer_appid": appid, + } + + uri := fmt.Sprintf(getComponentInfoURL, cat) + body, err := util.PostJSON(uri, req) + if err != nil { + return nil, nil, err + } + + var ret struct { + AuthorizerInfo *AuthorizerInfo `json:"authorizer_info"` + AuthorizationInfo *AuthBaseInfo `json:"authorization_info"` + } + if err := json.Unmarshal(body, &ret); err != nil { + return nil, nil, err + } + + return ret.AuthorizerInfo, ret.AuthorizationInfo, nil +} diff --git a/wechat/context/component_test.go b/wechat/context/component_test.go new file mode 100644 index 0000000..5bf087d --- /dev/null +++ b/wechat/context/component_test.go @@ -0,0 +1,19 @@ +package context + +import ( + "encoding/json" + "testing" +) + +var testdata = `{"authorizer_info":{"nick_name":"就爱浪","head_img":"http:\/\/wx.qlogo.cn\/mmopen\/xPKCxELaaj6hiaTZGv19oQPBJibb7hBoKmNOjQibCNOUycE8iaBhiaHOA6eC8hadQSAUZTuHUJl4qCIbCQGjSWialicfzWh4mdxuejY\/0","service_type_info":{"id":1},"verify_type_info":{"id":-1},"user_name":"gh_dcdbaa6f1687","alias":"ckeyer","qrcode_url":"http:\/\/mmbiz.qpic.cn\/mmbiz_jpg\/FribWCoIzQbAX7R1PQ8iaxGonqKp0doYD2ibhC0uhx11LrRcblASiazsbQJTJ4icQnMzfH7G0SUPuKbibTA8Cs4uk5WQ\/0","business_info":{"open_pay":0,"open_shake":0,"open_scan":0,"open_card":0,"open_store":0},"idc":1,"principal_name":"个人","signature":"不折腾会死。"},"authorization_info":{"authorizer_appid":"yyyyy","authorizer_refresh_token":"xxxx","func_info":[{"funcscope_category":{"id":1}},{"funcscope_category":{"id":15}},{"funcscope_category":{"id":4}},{"funcscope_category":{"id":7}},{"funcscope_category":{"id":2}},{"funcscope_category":{"id":3}},{"funcscope_category":{"id":11}},{"funcscope_category":{"id":6}},{"funcscope_category":{"id":5}},{"funcscope_category":{"id":8}},{"funcscope_category":{"id":13}},{"funcscope_category":{"id":9}},{"funcscope_category":{"id":12}},{"funcscope_category":{"id":22}},{"funcscope_category":{"id":23}},{"funcscope_category":{"id":24},"confirm_info":{"need_confirm":0,"already_confirm":0,"can_confirm":0}},{"funcscope_category":{"id":26}},{"funcscope_category":{"id":27},"confirm_info":{"need_confirm":0,"already_confirm":0,"can_confirm":0}},{"funcscope_category":{"id":33},"confirm_info":{"need_confirm":0,"already_confirm":0,"can_confirm":0}},{"funcscope_category":{"id":35}}]}}` + +// TestDecode +func TestDecode(t *testing.T) { + var ret struct { + AuthorizerInfo *AuthorizerInfo `json:"authorizer_info"` + AuthorizationInfo *AuthBaseInfo `json:"authorization_info"` + } + json.Unmarshal([]byte(testdata), &ret) + t.Logf("%+v", ret.AuthorizerInfo) + t.Logf("%+v", ret.AuthorizationInfo) +} diff --git a/wechat/context/context.go b/wechat/context/context.go new file mode 100644 index 0000000..5571153 --- /dev/null +++ b/wechat/context/context.go @@ -0,0 +1,58 @@ +package context + +import ( + "net/http" + "sync" + + "gitlab.batiao8.com/open/gosdk/wechat/cache" +) + +// Context struct +type Context struct { + AppID string + AppSecret string + Token string + EncodingAESKey string + PayMchID string + PayNotifyURL string + PayKey string + + Cache cache.Cache + + Writer http.ResponseWriter + Request *http.Request + + //accessTokenLock 读写锁 同一个AppID一个 + accessTokenLock *sync.RWMutex + + //jsAPITicket 读写锁 同一个AppID一个 + jsAPITicketLock *sync.RWMutex + + //accessTokenFunc 自定义获取 access token 的方法 + accessTokenFunc GetAccessTokenFunc +} + +// Query returns the keyed url query value if it exists +func (ctx *Context) Query(key string) string { + value, _ := ctx.GetQuery(key) + return value +} + +// GetQuery is like Query(), it returns the keyed url query value +func (ctx *Context) GetQuery(key string) (string, bool) { + req := ctx.Request + if values, ok := req.URL.Query()[key]; ok && len(values) > 0 { + return values[0], true + } + return "", false +} + +// SetJsAPITicketLock 设置jsAPITicket的lock +func (ctx *Context) SetJsAPITicketLock(lock *sync.RWMutex) { + ctx.jsAPITicketLock = lock +} + +// GetJsAPITicketLock 获取jsAPITicket 的lock +func (ctx *Context) GetJsAPITicketLock() *sync.RWMutex { + return ctx.jsAPITicketLock +} diff --git a/wechat/context/qy_access_token.go b/wechat/context/qy_access_token.go new file mode 100644 index 0000000..cfe730a --- /dev/null +++ b/wechat/context/qy_access_token.go @@ -0,0 +1,76 @@ +package context + +import ( + "encoding/json" + "fmt" + "log" + "sync" + "time" + + "gitlab.batiao8.com/open/gosdk/wechat/util" +) + +const ( + //qyAccessTokenURL 获取access_token的接口 + qyAccessTokenURL = "https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s" +) + +// ResQyAccessToken struct +type ResQyAccessToken struct { + util.CommonError + + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +// SetQyAccessTokenLock 设置读写锁(一个appID一个读写锁) +func (ctx *Context) SetQyAccessTokenLock(l *sync.RWMutex) { + ctx.accessTokenLock = l +} + +// GetQyAccessToken 获取access_token +func (ctx *Context) GetQyAccessToken() (accessToken string, err error) { + ctx.accessTokenLock.Lock() + defer ctx.accessTokenLock.Unlock() + + accessTokenCacheKey := fmt.Sprintf("qy_access_token_%s", ctx.AppID) + val := ctx.Cache.Get(accessTokenCacheKey) + if val != nil { + accessToken = val.(string) + return + } + + //从微信服务器获取 + var resQyAccessToken ResQyAccessToken + resQyAccessToken, err = ctx.GetQyAccessTokenFromServer() + if err != nil { + return + } + + accessToken = resQyAccessToken.AccessToken + return +} + +// GetQyAccessTokenFromServer 强制从微信服务器获取token +func (ctx *Context) GetQyAccessTokenFromServer() (resQyAccessToken ResQyAccessToken, err error) { + log.Printf("GetQyAccessTokenFromServer") + url := fmt.Sprintf(qyAccessTokenURL, ctx.AppID, ctx.AppSecret) + var body []byte + body, err = util.HTTPGet(url) + if err != nil { + return + } + err = json.Unmarshal(body, &resQyAccessToken) + if err != nil { + return + } + if resQyAccessToken.ErrCode != 0 { + err = fmt.Errorf("get qy_access_token error : errcode=%v , errormsg=%v", resQyAccessToken.ErrCode, resQyAccessToken.ErrMsg) + return + } + + qyAccessTokenCacheKey := fmt.Sprintf("qy_access_token_%s", ctx.AppID) + expires := resQyAccessToken.ExpiresIn - 1500 + err = ctx.Cache.Set(qyAccessTokenCacheKey, resQyAccessToken.AccessToken, time.Duration(expires)*time.Second) + return +} diff --git a/wechat/context/render.go b/wechat/context/render.go new file mode 100644 index 0000000..85d1d0d --- /dev/null +++ b/wechat/context/render.go @@ -0,0 +1,43 @@ +package context + +import ( + "encoding/xml" + "net/http" +) + +var xmlContentType = []string{"application/xml; charset=utf-8"} +var plainContentType = []string{"text/plain; charset=utf-8"} + +// Render render from bytes +func (ctx *Context) Render(bytes []byte) { + //debug + //fmt.Println("response msg = ", string(bytes)) + ctx.Writer.WriteHeader(200) + _, err := ctx.Writer.Write(bytes) + if err != nil { + panic(err) + } +} + +// String render from string +func (ctx *Context) String(str string) { + writeContextType(ctx.Writer, plainContentType) + ctx.Render([]byte(str)) +} + +// XML render to xml +func (ctx *Context) XML(obj interface{}) { + writeContextType(ctx.Writer, xmlContentType) + bytes, err := xml.Marshal(obj) + if err != nil { + panic(err) + } + ctx.Render(bytes) +} + +func writeContextType(w http.ResponseWriter, value []string) { + header := w.Header() + if val := header["Content-Type"]; len(val) == 0 { + header["Content-Type"] = value + } +} diff --git a/wechat/message/customer_message.go b/wechat/message/customer_message.go new file mode 100644 index 0000000..68f80c4 --- /dev/null +++ b/wechat/message/customer_message.go @@ -0,0 +1,160 @@ +package message + +import ( + "encoding/json" + "fmt" + "gitlab.batiao8.com/open/gosdk/wechat/context" + "gitlab.batiao8.com/open/gosdk/wechat/util" +) + +const ( + customerSendMessage = "https://api.weixin.qq.com/cgi-bin/message/custom/send" +) + +// Manager 消息管理者,可以发送消息 +type Manager struct { + *context.Context +} + +// NewMessageManager 实例化消息管理者 +func NewMessageManager(context *context.Context) *Manager { + return &Manager{ + context, + } +} + +// CustomerMessage 客服消息 +type CustomerMessage struct { + ToUser string `json:"touser"` //接受者OpenID + Msgtype MsgType `json:"msgtype"` //客服消息类型 + Text *MediaText `json:"text,omitempty"` //可选 + Image *MediaResource `json:"image,omitempty"` //可选 + Voice *MediaResource `json:"voice,omitempty"` //可选 + Video *MediaVideo `json:"video,omitempty"` //可选 + Music *MediaMusic `json:"music,omitempty"` //可选 + News *MediaNews `json:"news,omitempty"` //可选 + Mpnews *MediaResource `json:"mpnews,omitempty"` //可选 + Wxcard *MediaWxcard `json:"wxcard,omitempty"` //可选 + Msgmenu *MediaMsgmenu `json:"msgmenu,omitempty"` //可选 + Miniprogrampage *MediaMiniprogrampage `json:"miniprogrampage,omitempty"` //可选 +} + +// NewCustomerTextMessage 文本消息结构体构造方法 +func NewCustomerTextMessage(toUser, text string) *CustomerMessage { + return &CustomerMessage{ + ToUser: toUser, + Msgtype: MsgTypeText, + Text: &MediaText{ + text, + }, + } +} + +// NewCustomerImgMessage 图片消息的构造方法 +func NewCustomerImgMessage(toUser, mediaID string) *CustomerMessage { + return &CustomerMessage{ + ToUser: toUser, + Msgtype: MsgTypeImage, + Image: &MediaResource{ + mediaID, + }, + } +} + +// NewCustomerVoiceMessage 语音消息的构造方法 +func NewCustomerVoiceMessage(toUser, mediaID string) *CustomerMessage { + return &CustomerMessage{ + ToUser: toUser, + Msgtype: MsgTypeVoice, + Voice: &MediaResource{ + mediaID, + }, + } +} + +// MediaText 文本消息的文字 +type MediaText struct { + Content string `json:"content"` +} + +// MediaResource 消息使用的永久素材id +type MediaResource struct { + MediaID string `json:"media_id"` +} + +// MediaVideo 视频消息包含的内容 +type MediaVideo struct { + MediaID string `json:"media_id"` + ThumbMediaID string `json:"thumb_media_id"` + Title string `json:"title"` + Description string `json:"description"` +} + +// MediaMusic 音乐消息包括的内容 +type MediaMusic struct { + Title string `json:"title"` + Description string `json:"description"` + Musicurl string `json:"musicurl"` + Hqmusicurl string `json:"hqmusicurl"` + ThumbMediaID string `json:"thumb_media_id"` +} + +// MediaNews 图文消息的内容 +type MediaNews struct { + Articles []MediaArticles `json:"articles"` +} + +// MediaArticles 图文消息的内容的文章列表中的单独一条 +type MediaArticles struct { + Title string `json:"title"` + Description string `json:"description"` + URL string `json:"url"` + Picurl string `json:"picurl"` +} + +// MediaMsgmenu 菜单消息的内容 +type MediaMsgmenu struct { + HeadContent string `json:"head_content"` + List []MsgmenuItem `json:"list"` + TailContent string `json:"tail_content"` +} + +// MsgmenuItem 菜单消息的菜单按钮 +type MsgmenuItem struct { + ID string `json:"id"` + Content string `json:"content"` +} + +// MediaWxcard 卡券的id +type MediaWxcard struct { + CardID string `json:"card_id"` +} + +// MediaMiniprogrampage 小程序消息 +type MediaMiniprogrampage struct { + Title string `json:"title"` + Appid string `json:"appid"` + Pagepath string `json:"pagepath"` + ThumbMediaID string `json:"thumb_media_id"` +} + +// Send 发送客服消息 +func (manager *Manager) Send(msg *CustomerMessage) error { + accessToken, err := manager.Context.GetAccessToken() + if err != nil { + return err + } + uri := fmt.Sprintf("%s?access_token=%s", customerSendMessage, accessToken) + response, err := util.PostJSON(uri, msg) + var result util.CommonError + err = json.Unmarshal(response, &result) + if err != nil { + return err + } + if result.ErrCode != 0 { + err = fmt.Errorf("customer msg send error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return err + } + + return nil +} diff --git a/wechat/message/image.go b/wechat/message/image.go new file mode 100644 index 0000000..b79f9a4 --- /dev/null +++ b/wechat/message/image.go @@ -0,0 +1,17 @@ +package message + +// Image 图片消息 +type Image struct { + CommonToken + + Image struct { + MediaID string `xml:"MediaId"` + } `xml:"Image"` +} + +// NewImage 回复图片消息 +func NewImage(mediaID string) *Image { + image := new(Image) + image.Image.MediaID = mediaID + return image +} diff --git a/wechat/message/message.go b/wechat/message/message.go new file mode 100644 index 0000000..6097f24 --- /dev/null +++ b/wechat/message/message.go @@ -0,0 +1,228 @@ +package message + +import ( + "encoding/xml" +) + +// MsgType 基本消息类型 +type MsgType string + +// EventType 事件类型 +type EventType string + +// InfoType 第三方平台授权事件类型 +type InfoType string + +const ( + //MsgTypeText 表示文本消息 + MsgTypeText MsgType = "text" + //MsgTypeImage 表示图片消息 + MsgTypeImage = "image" + //MsgTypeVoice 表示语音消息 + MsgTypeVoice = "voice" + //MsgTypeVideo 表示视频消息 + MsgTypeVideo = "video" + //MsgTypeShortVideo 表示短视频消息[限接收] + MsgTypeShortVideo = "shortvideo" + //MsgTypeLocation 表示坐标消息[限接收] + MsgTypeLocation = "location" + //MsgTypeLink 表示链接消息[限接收] + MsgTypeLink = "link" + //MsgTypeMusic 表示音乐消息[限回复] + MsgTypeMusic = "music" + //MsgTypeNews 表示图文消息[限回复] + MsgTypeNews = "news" + //MsgTypeTransfer 表示消息消息转发到客服 + MsgTypeTransfer = "transfer_customer_service" + //MsgTypeEvent 表示事件推送消息 + MsgTypeEvent = "event" +) + +const ( + //EventSubscribe 订阅 + EventSubscribe EventType = "subscribe" + //EventUnsubscribe 取消订阅 + EventUnsubscribe = "unsubscribe" + //EventScan 用户已经关注公众号,则微信会将带场景值扫描事件推送给开发者 + EventScan = "SCAN" + //EventLocation 上报地理位置事件 + EventLocation = "LOCATION" + //EventClick 点击菜单拉取消息时的事件推送 + EventClick = "CLICK" + //EventView 点击菜单跳转链接时的事件推送 + EventView = "VIEW" + //EventScancodePush 扫码推事件的事件推送 + EventScancodePush = "scancode_push" + //EventScancodeWaitmsg 扫码推事件且弹出“消息接收中”提示框的事件推送 + EventScancodeWaitmsg = "scancode_waitmsg" + //EventPicSysphoto 弹出系统拍照发图的事件推送 + EventPicSysphoto = "pic_sysphoto" + //EventPicPhotoOrAlbum 弹出拍照或者相册发图的事件推送 + EventPicPhotoOrAlbum = "pic_photo_or_album" + //EventPicWeixin 弹出微信相册发图器的事件推送 + EventPicWeixin = "pic_weixin" + //EventLocationSelect 弹出地理位置选择器的事件推送 + EventLocationSelect = "location_select" + //EventTemplateSendJobFinish 发送模板消息推送通知 + EventTemplateSendJobFinish = "TEMPLATESENDJOBFINISH" + //EventWxaMediaCheck 异步校验图片/音频是否含有违法违规内容推送事件 + EventWxaMediaCheck = "wxa_media_check" +) + +const ( + // InfoTypeVerifyTicket 返回ticket + InfoTypeVerifyTicket InfoType = "component_verify_ticket" + // InfoTypeAuthorized 授权 + InfoTypeAuthorized = "authorized" + // InfoTypeUnauthorized 取消授权 + InfoTypeUnauthorized = "unauthorized" + // InfoTypeUpdateAuthorized 更新授权 + InfoTypeUpdateAuthorized = "updateauthorized" +) + +// MixMessage 存放所有微信发送过来的消息和事件 +type MixMessage struct { + CommonToken + + //基本消息 + MsgID int64 `xml:"MsgId"` + Content string `xml:"Content"` + Recognition string `xml:"Recognition"` + PicURL string `xml:"PicUrl"` + MediaID string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaID string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale float64 `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string `xml:"Description"` + URL string `xml:"Url"` + + //事件相关 + Event EventType `xml:"Event"` + EventKey string `xml:"EventKey"` + Ticket string `xml:"Ticket"` + Latitude string `xml:"Latitude"` + Longitude string `xml:"Longitude"` + Precision string `xml:"Precision"` + MenuID string `xml:"MenuId"` + Status string `xml:"Status"` + SessionFrom string `xml:"SessionFrom"` + + ScanCodeInfo struct { + ScanType string `xml:"ScanType"` + ScanResult string `xml:"ScanResult"` + } `xml:"ScanCodeInfo"` + + ApprovalInfo struct { + SpNo string `xml:"SpNo"` + SpName string `xml:"SpName"` + SpStatus int `xml:"SpStatus"` + TemplateId string `xml:"TemplateId"` + ApplyTime int64 `xml:"ApplyTime"` + Applyer struct { + UserId string `xml:"UserId"` + } `xml:"Applyer"` + } `xml:"ApprovalInfo"` + + SendPicsInfo struct { + Count int32 `xml:"Count"` + PicList []EventPic `xml:"PicList>item"` + } `xml:"SendPicsInfo"` + + SendLocationInfo struct { + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale float64 `xml:"Scale"` + Label string `xml:"Label"` + Poiname string `xml:"Poiname"` + } + + // 第三方平台相关 + InfoType InfoType `xml:"InfoType"` + AppID string `xml:"AppId"` + ComponentVerifyTicket string `xml:"ComponentVerifyTicket"` + AuthorizerAppid string `xml:"AuthorizerAppid"` + AuthorizationCode string `xml:"AuthorizationCode"` + AuthorizationCodeExpiredTime int64 `xml:"AuthorizationCodeExpiredTime"` + PreAuthCode string `xml:"PreAuthCode"` + + // 卡券相关 + CardID string `xml:"CardId"` + RefuseReason string `xml:"RefuseReason"` + IsGiveByFriend int32 `xml:"IsGiveByFriend"` + FriendUserName string `xml:"FriendUserName"` + UserCardCode string `xml:"UserCardCode"` + OldUserCardCode string `xml:"OldUserCardCode"` + OuterStr string `xml:"OuterStr"` + IsRestoreMemberCard int32 `xml:"IsRestoreMemberCard"` + UnionID string `xml:"UnionId"` + + // 内容审核相关 + IsRisky bool `xml:"isrisky"` + ExtraInfoJSON string `xml:"extra_info_json"` + TraceID string `xml:"trace_id"` + StatusCode int `xml:"status_code"` +} + +// EventPic 发图事件推送 +type EventPic struct { + PicMd5Sum string `xml:"PicMd5Sum"` +} + +// EncryptedXMLMsg 安全模式下的消息体 +type EncryptedXMLMsg struct { + XMLName struct{} `xml:"xml" json:"-"` + ToUserName string `xml:"ToUserName" json:"ToUserName"` + EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` +} + +// ResponseEncryptedXMLMsg 需要返回的消息体 +type ResponseEncryptedXMLMsg struct { + XMLName struct{} `xml:"xml" json:"-"` + EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` + MsgSignature string `xml:"MsgSignature" json:"MsgSignature"` + Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"` + Nonce string `xml:"Nonce" json:"Nonce"` +} + +// CDATA 使用该类型,在序列化为 xml 文本时文本会被解析器忽略 +type CDATA string + +// MarshalXML 实现自己的序列化方法 +func (c CDATA) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + return e.EncodeElement(struct { + string `xml:",cdata"` + }{string(c)}, start) +} + +// CommonToken 消息中通用的结构 +type CommonToken struct { + XMLName xml.Name `xml:"xml"` + ToUserName CDATA `xml:"ToUserName"` + FromUserName CDATA `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType MsgType `xml:"MsgType"` +} + +// SetToUserName set ToUserName +func (msg *CommonToken) SetToUserName(toUserName CDATA) { + msg.ToUserName = toUserName +} + +// SetFromUserName set FromUserName +func (msg *CommonToken) SetFromUserName(fromUserName CDATA) { + msg.FromUserName = fromUserName +} + +// SetCreateTime set createTime +func (msg *CommonToken) SetCreateTime(createTime int64) { + msg.CreateTime = createTime +} + +// SetMsgType set MsgType +func (msg *CommonToken) SetMsgType(msgType MsgType) { + msg.MsgType = msgType +} diff --git a/wechat/message/music.go b/wechat/message/music.go new file mode 100644 index 0000000..5d71f5a --- /dev/null +++ b/wechat/message/music.go @@ -0,0 +1,24 @@ +package message + +// Music 音乐消息 +type Music struct { + CommonToken + + Music struct { + Title string `xml:"Title" ` + Description string `xml:"Description" ` + MusicURL string `xml:"MusicUrl" ` + HQMusicURL string `xml:"HQMusicUrl" ` + ThumbMediaID string `xml:"ThumbMediaId"` + } `xml:"Music"` +} + +// NewMusic 回复音乐消息 +func NewMusic(title, description, musicURL, hQMusicURL, thumbMediaID string) *Music { + music := new(Music) + music.Music.Title = title + music.Music.Description = description + music.Music.MusicURL = musicURL + music.Music.ThumbMediaID = thumbMediaID + return music +} diff --git a/wechat/message/news.go b/wechat/message/news.go new file mode 100644 index 0000000..fa00249 --- /dev/null +++ b/wechat/message/news.go @@ -0,0 +1,35 @@ +package message + +// News 图文消息 +type News struct { + CommonToken + + ArticleCount int `xml:"ArticleCount"` + Articles []*Article `xml:"Articles>item,omitempty"` +} + +// NewNews 初始化图文消息 +func NewNews(articles []*Article) *News { + news := new(News) + news.ArticleCount = len(articles) + news.Articles = articles + return news +} + +// Article 单篇文章 +type Article struct { + Title string `xml:"Title,omitempty"` + Description string `xml:"Description,omitempty"` + PicURL string `xml:"PicUrl,omitempty"` + URL string `xml:"Url,omitempty"` +} + +// NewArticle 初始化文章 +func NewArticle(title, description, picURL, url string) *Article { + article := new(Article) + article.Title = title + article.Description = description + article.PicURL = picURL + article.URL = url + return article +} diff --git a/wechat/message/ransfer_customer.go b/wechat/message/ransfer_customer.go new file mode 100644 index 0000000..01b94ee --- /dev/null +++ b/wechat/message/ransfer_customer.go @@ -0,0 +1,24 @@ +package message + +// TransferCustomer 转发客服消息 +type TransferCustomer struct { + CommonToken + + TransInfo *TransInfo `xml:"TransInfo,omitempty"` +} + +// TransInfo 转发到指定客服 +type TransInfo struct { + KfAccount string `xml:"KfAccount"` +} + +// NewTransferCustomer 实例化 +func NewTransferCustomer(KfAccount string) *TransferCustomer { + tc := new(TransferCustomer) + if KfAccount != "" { + transInfo := new(TransInfo) + transInfo.KfAccount = KfAccount + tc.TransInfo = transInfo + } + return tc +} diff --git a/wechat/message/reply.go b/wechat/message/reply.go new file mode 100644 index 0000000..5488fb8 --- /dev/null +++ b/wechat/message/reply.go @@ -0,0 +1,15 @@ +package message + +import "errors" + +// ErrInvalidReply 无效的回复 +var ErrInvalidReply = errors.New("无效的回复消息") + +// ErrUnsupportReply 不支持的回复类型 +var ErrUnsupportReply = errors.New("不支持的回复消息") + +// Reply 消息回复 +type Reply struct { + MsgType MsgType + MsgData interface{} +} diff --git a/wechat/message/template.go b/wechat/message/template.go new file mode 100644 index 0000000..824bff8 --- /dev/null +++ b/wechat/message/template.go @@ -0,0 +1,74 @@ +package message + +import ( + "encoding/json" + "fmt" + + "gitlab.batiao8.com/open/gosdk/wechat/context" + "gitlab.batiao8.com/open/gosdk/wechat/util" +) + +const ( + templateSendURL = "https://api.weixin.qq.com/cgi-bin/message/template/send" +) + +// Template 模板消息 +type Template struct { + *context.Context +} + +// NewTemplate 实例化 +func NewTemplate(context *context.Context) *Template { + tpl := new(Template) + tpl.Context = context + return tpl +} + +// Message 发送的模板消息内容 +type Message struct { + ToUser string `json:"touser"` // 必须, 接受者OpenID + TemplateID string `json:"template_id"` // 必须, 模版ID + URL string `json:"url,omitempty"` // 可选, 用户点击后跳转的URL, 该URL必须处于开发者在公众平台网站中设置的域中 + Color string `json:"color,omitempty"` // 可选, 整个消息的颜色, 可以不设置 + Data map[string]*DataItem `json:"data"` // 必须, 模板数据 + + MiniProgram struct { + AppID string `json:"appid"` //所需跳转到的小程序appid(该小程序appid必须与发模板消息的公众号是绑定关联关系) + PagePath string `json:"pagepath"` //所需跳转到小程序的具体页面路径,支持带参数,(示例index?foo=bar) + } `json:"miniprogram"` //可选,跳转至小程序地址 +} + +// DataItem 模版内某个 .DATA 的值 +type DataItem struct { + Value string `json:"value"` + Color string `json:"color,omitempty"` +} + +type resTemplateSend struct { + util.CommonError + + MsgID int64 `json:"msgid"` +} + +// Send 发送模板消息 +func (tpl *Template) Send(msg *Message) (msgID int64, err error) { + var accessToken string + accessToken, err = tpl.GetAccessToken() + if err != nil { + return + } + uri := fmt.Sprintf("%s?access_token=%s", templateSendURL, accessToken) + response, err := util.PostJSON(uri, msg) + + var result resTemplateSend + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("template msg send error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + msgID = result.MsgID + return +} diff --git a/wechat/message/text.go b/wechat/message/text.go new file mode 100644 index 0000000..a6819d9 --- /dev/null +++ b/wechat/message/text.go @@ -0,0 +1,14 @@ +package message + +// Text 文本消息 +type Text struct { + CommonToken + Content CDATA `xml:"Content"` +} + +// NewText 初始化文本消息 +func NewText(content string) *Text { + text := new(Text) + text.Content = CDATA(content) + return text +} diff --git a/wechat/message/video.go b/wechat/message/video.go new file mode 100644 index 0000000..6f64875 --- /dev/null +++ b/wechat/message/video.go @@ -0,0 +1,21 @@ +package message + +// Video 视频消息 +type Video struct { + CommonToken + + Video struct { + MediaID string `xml:"MediaId"` + Title string `xml:"Title,omitempty"` + Description string `xml:"Description,omitempty"` + } `xml:"Video"` +} + +// NewVideo 回复图片消息 +func NewVideo(mediaID, title, description string) *Video { + video := new(Video) + video.Video.MediaID = mediaID + video.Video.Title = title + video.Video.Description = description + return video +} diff --git a/wechat/message/voice.go b/wechat/message/voice.go new file mode 100644 index 0000000..a9cb662 --- /dev/null +++ b/wechat/message/voice.go @@ -0,0 +1,17 @@ +package message + +// Voice 语音消息 +type Voice struct { + CommonToken + + Voice struct { + MediaID string `xml:"MediaId"` + } `xml:"Voice"` +} + +// NewVoice 回复语音消息 +func NewVoice(mediaID string) *Voice { + voice := new(Voice) + voice.Voice.MediaID = mediaID + return voice +} diff --git a/wechat/readme.txt b/wechat/readme.txt new file mode 100644 index 0000000..66c1f2c --- /dev/null +++ b/wechat/readme.txt @@ -0,0 +1 @@ +主要用到企业微信审批回调功能,开源版本没有ApprovalInfo相关字段,所以复制过来手动添加 \ No newline at end of file diff --git a/wechat/server/server.go b/wechat/server/server.go new file mode 100644 index 0000000..6e140dd --- /dev/null +++ b/wechat/server/server.go @@ -0,0 +1,244 @@ +package server + +import ( + "encoding/xml" + "errors" + "fmt" + "io/ioutil" + "reflect" + "runtime/debug" + "strconv" + + "gitlab.batiao8.com/open/gosdk/wechat/context" + "gitlab.batiao8.com/open/gosdk/wechat/message" + "gitlab.batiao8.com/open/gosdk/wechat/util" +) + +// Server struct +type Server struct { + *context.Context + + debug bool + + openID string + + messageHandler func(message.MixMessage) *message.Reply + + requestRawXMLMsg []byte + requestMsg message.MixMessage + responseRawXMLMsg []byte + responseMsg interface{} + + isSafeMode bool + random []byte + nonce string + timestamp int64 +} + +// NewServer init +func NewServer(context *context.Context) *Server { + srv := new(Server) + srv.Context = context + return srv +} + +// SetDebug set debug field +func (srv *Server) SetDebug(debug bool) { + srv.debug = debug +} + +// Serve 处理微信的请求消息 +func (srv *Server) Serve() error { + if !srv.Validate() { + return fmt.Errorf("请求校验失败") + } + + echostr, exists := srv.GetQuery("echostr") + if exists { + srv.String(echostr) + return nil + } + + response, err := srv.handleRequest() + if err != nil { + return err + } + + //debug + if srv.debug { + fmt.Println("request msg = ", string(srv.requestRawXMLMsg)) + } + + return srv.buildResponse(response) +} + +// Validate 校验请求是否合法 +func (srv *Server) Validate() bool { + if srv.debug { + return true + } + timestamp := srv.Query("timestamp") + nonce := srv.Query("nonce") + signature := srv.Query("signature") + return signature == util.Signature(srv.Token, timestamp, nonce) +} + +// HandleRequest 处理微信的请求 +func (srv *Server) handleRequest() (reply *message.Reply, err error) { + //set isSafeMode + srv.isSafeMode = false + encryptType := srv.Query("encrypt_type") + if encryptType == "aes" { + srv.isSafeMode = true + } + + //set openID + srv.openID = srv.Query("openid") + + var msg interface{} + msg, err = srv.getMessage() + if err != nil { + return + } + mixMessage, success := msg.(message.MixMessage) + if !success { + err = errors.New("消息类型转换失败") + } + srv.requestMsg = mixMessage + reply = srv.messageHandler(mixMessage) + return +} + +// GetOpenID return openID +func (srv *Server) GetOpenID() string { + return srv.openID +} + +// getMessage 解析微信返回的消息 +func (srv *Server) getMessage() (interface{}, error) { + var rawXMLMsgBytes []byte + var err error + if srv.isSafeMode { + var encryptedXMLMsg message.EncryptedXMLMsg + if err := xml.NewDecoder(srv.Request.Body).Decode(&encryptedXMLMsg); err != nil { + return nil, fmt.Errorf("从body中解析xml失败,err=%v", err) + } + + //验证消息签名 + timestamp := srv.Query("timestamp") + srv.timestamp, err = strconv.ParseInt(timestamp, 10, 32) + if err != nil { + return nil, err + } + nonce := srv.Query("nonce") + srv.nonce = nonce + msgSignature := srv.Query("msg_signature") + msgSignatureGen := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg) + if msgSignature != msgSignatureGen { + return nil, fmt.Errorf("消息不合法,验证签名失败") + } + + //解密 + srv.random, rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey) + if err != nil { + return nil, fmt.Errorf("消息解密失败, err=%v", err) + } + } else { + rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body) + if err != nil { + return nil, fmt.Errorf("从body中解析xml失败, err=%v", err) + } + } + + srv.requestRawXMLMsg = rawXMLMsgBytes + + return srv.parseRequestMessage(rawXMLMsgBytes) +} + +func (srv *Server) parseRequestMessage(rawXMLMsgBytes []byte) (msg message.MixMessage, err error) { + msg = message.MixMessage{} + err = xml.Unmarshal(rawXMLMsgBytes, &msg) + return +} + +// SetMessageHandler 设置用户自定义的回调方法 +func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.Reply) { + srv.messageHandler = handler +} + +func (srv *Server) buildResponse(reply *message.Reply) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: %v\n%s", e, debug.Stack()) + } + }() + if reply == nil { + //do nothing + return nil + } + msgType := reply.MsgType + switch msgType { + case message.MsgTypeText: + case message.MsgTypeImage: + case message.MsgTypeVoice: + case message.MsgTypeVideo: + case message.MsgTypeMusic: + case message.MsgTypeNews: + case message.MsgTypeTransfer: + default: + err = message.ErrUnsupportReply + return + } + + msgData := reply.MsgData + value := reflect.ValueOf(msgData) + //msgData must be a ptr + kind := value.Kind().String() + if "ptr" != kind { + return message.ErrUnsupportReply + } + + params := make([]reflect.Value, 1) + params[0] = reflect.ValueOf(srv.requestMsg.FromUserName) + value.MethodByName("SetToUserName").Call(params) + + params[0] = reflect.ValueOf(srv.requestMsg.ToUserName) + value.MethodByName("SetFromUserName").Call(params) + + params[0] = reflect.ValueOf(msgType) + value.MethodByName("SetMsgType").Call(params) + + params[0] = reflect.ValueOf(util.GetCurrTs()) + value.MethodByName("SetCreateTime").Call(params) + + srv.responseMsg = msgData + srv.responseRawXMLMsg, err = xml.Marshal(msgData) + return +} + +// Send 将自定义的消息发送 +func (srv *Server) Send() (err error) { + replyMsg := srv.responseMsg + if srv.isSafeMode { + //安全模式下对消息进行加密 + var encryptedMsg []byte + encryptedMsg, err = util.EncryptMsg(srv.random, srv.responseRawXMLMsg, srv.AppID, srv.EncodingAESKey) + if err != nil { + return + } + //TODO 如果获取不到timestamp nonce 则自己生成 + timestamp := srv.timestamp + timestampStr := strconv.FormatInt(timestamp, 10) + msgSignature := util.Signature(srv.Token, timestampStr, srv.nonce, string(encryptedMsg)) + replyMsg = message.ResponseEncryptedXMLMsg{ + EncryptedMsg: string(encryptedMsg), + MsgSignature: msgSignature, + Timestamp: timestamp, + Nonce: srv.nonce, + } + } + if replyMsg != nil { + srv.XML(replyMsg) + } + return +} diff --git a/wechat/util/crypto.go b/wechat/util/crypto.go new file mode 100644 index 0000000..0f07ead --- /dev/null +++ b/wechat/util/crypto.go @@ -0,0 +1,199 @@ +package util + +import ( + "bufio" + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "encoding/base64" + "encoding/hex" + "fmt" +) + +// EncryptMsg 加密消息 +func EncryptMsg(random, rawXMLMsg []byte, appID, aesKey string) (encrtptMsg []byte, err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: err=%v", e) + return + } + }() + var key []byte + key, err = aesKeyDecode(aesKey) + if err != nil { + panic(err) + } + ciphertext := AESEncryptMsg(random, rawXMLMsg, appID, key) + encrtptMsg = []byte(base64.StdEncoding.EncodeToString(ciphertext)) + return +} + +// AESEncryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +// 参考:github.com/chanxuehong/wechat.v2 +func AESEncryptMsg(random, rawXMLMsg []byte, appID string, aesKey []byte) (ciphertext []byte) { + const ( + BlockSize = 32 // PKCS#7 + BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数 + ) + + appIDOffset := 20 + len(rawXMLMsg) + contentLen := appIDOffset + len(appID) + amountToPad := BlockSize - contentLen&BlockMask + plaintextLen := contentLen + amountToPad + + plaintext := make([]byte, plaintextLen) + + // 拼接 + copy(plaintext[:16], random) + encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg))) + copy(plaintext[20:], rawXMLMsg) + copy(plaintext[appIDOffset:], appID) + + // PKCS#7 补位 + for i := contentLen; i < plaintextLen; i++ { + plaintext[i] = byte(amountToPad) + } + + // 加密 + block, err := aes.NewCipher(aesKey[:]) + if err != nil { + panic(err) + } + mode := cipher.NewCBCEncrypter(block, aesKey[:16]) + mode.CryptBlocks(plaintext, plaintext) + + ciphertext = plaintext + return +} + +// DecryptMsg 消息解密 +func DecryptMsg(appID, encryptedMsg, aesKey string) (random, rawMsgXMLBytes []byte, err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: err=%v", e) + return + } + }() + var encryptedMsgBytes, key, getAppIDBytes []byte + encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return + } + key, err = aesKeyDecode(aesKey) + if err != nil { + panic(err) + } + random, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key) + if err != nil { + err = fmt.Errorf("消息解密失败,%v", err) + return + } + if appID != string(getAppIDBytes) { + err = fmt.Errorf("消息解密校验APPID失败") + return + } + return +} + +func aesKeyDecode(encodedAESKey string) (key []byte, err error) { + if len(encodedAESKey) != 43 { + err = fmt.Errorf("the length of encodedAESKey must be equal to 43") + return + } + key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=") + if err != nil { + return + } + if len(key) != 32 { + err = fmt.Errorf("encodingAESKey invalid") + return + } + return +} + +// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +// 参考:github.com/chanxuehong/wechat.v2 +func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) { + const ( + BlockSize = 32 // PKCS#7 + BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数 + ) + + if len(ciphertext) < BlockSize { + err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext)) + return + } + if len(ciphertext)&BlockMask != 0 { + err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext)) + return + } + + plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE + + // 解密 + block, err := aes.NewCipher(aesKey) + if err != nil { + panic(err) + } + mode := cipher.NewCBCDecrypter(block, aesKey[:16]) + mode.CryptBlocks(plaintext, ciphertext) + + // PKCS#7 去除补位 + amountToPad := int(plaintext[len(plaintext)-1]) + if amountToPad < 1 || amountToPad > BlockSize { + err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad) + return + } + plaintext = plaintext[:len(plaintext)-amountToPad] + + // 反拼接 + // len(plaintext) == 16+4+len(rawXMLMsg)+len(appId) + if len(plaintext) <= 20 { + err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext)) + return + } + rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20])) + if rawXMLMsgLen < 0 { + err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen) + return + } + appIDOffset := 20 + rawXMLMsgLen + if len(plaintext) <= appIDOffset { + err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen) + return + } + + random = plaintext[:16:20] + rawXMLMsg = plaintext[20:appIDOffset:appIDOffset] + appID = plaintext[appIDOffset:] + return +} + +// 把整数 n 格式化成 4 字节的网络字节序 +func encodeNetworkByteOrder(orderBytes []byte, n uint32) { + orderBytes[0] = byte(n >> 24) + orderBytes[1] = byte(n >> 16) + orderBytes[2] = byte(n >> 8) + orderBytes[3] = byte(n) +} + +// 从 4 字节的网络字节序里解析出整数 +func decodeNetworkByteOrder(orderBytes []byte) (n uint32) { + return uint32(orderBytes[0])<<24 | + uint32(orderBytes[1])<<16 | + uint32(orderBytes[2])<<8 | + uint32(orderBytes[3]) +} + +// MD5Sum 计算 32 位长度的 MD5 sum +func MD5Sum(txt string) (sum string) { + h := md5.New() + buf := bufio.NewWriterSize(h, 128) + buf.WriteString(txt) + buf.Flush() + sign := make([]byte, hex.EncodedLen(h.Size())) + hex.Encode(sign, h.Sum(nil)) + sum = string(bytes.ToUpper(sign)) + return +} diff --git a/wechat/util/error.go b/wechat/util/error.go new file mode 100644 index 0000000..b971c47 --- /dev/null +++ b/wechat/util/error.go @@ -0,0 +1,51 @@ +package util + +import ( + "encoding/json" + "fmt" + "reflect" +) + +// CommonError 微信返回的通用错误json +type CommonError struct { + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +// DecodeWithCommonError 将返回值按照CommonError解析 +func DecodeWithCommonError(response []byte, apiName string) (err error) { + var commError CommonError + err = json.Unmarshal(response, &commError) + if err != nil { + return + } + if commError.ErrCode != 0 { + return fmt.Errorf("%s Error , errcode=%d , errmsg=%s", apiName, commError.ErrCode, commError.ErrMsg) + } + return nil +} + +// DecodeWithError 将返回值按照解析 +func DecodeWithError(response []byte, obj interface{}, apiName string) error { + err := json.Unmarshal(response, obj) + if err != nil { + return fmt.Errorf("json Unmarshal Error, err=%v", err) + } + responseObj := reflect.ValueOf(obj) + if !responseObj.IsValid() { + return fmt.Errorf("obj is invalid") + } + commonError := responseObj.Elem().FieldByName("CommonError") + if !commonError.IsValid() || commonError.Kind() != reflect.Struct { + return fmt.Errorf("commonError is invalid or not struct") + } + errCode := commonError.FieldByName("ErrCode") + errMsg := commonError.FieldByName("ErrMsg") + if !errCode.IsValid() || !errMsg.IsValid() { + return fmt.Errorf("errcode or errmsg is invalid") + } + if errCode.Int() != 0 { + return fmt.Errorf("%s Error , errcode=%d , errmsg=%s", apiName, errCode.Int(), errMsg.String()) + } + return nil +} diff --git a/wechat/util/http.go b/wechat/util/http.go new file mode 100644 index 0000000..3ce872c --- /dev/null +++ b/wechat/util/http.go @@ -0,0 +1,252 @@ +package util + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "encoding/pem" + "encoding/xml" + "fmt" + "io" + "io/ioutil" + "log" + "mime/multipart" + "net/http" + "os" + + "golang.org/x/crypto/pkcs12" +) + +// HTTPGet get 请求 +func HTTPGet(uri string) ([]byte, error) { + response, err := http.Get(uri) + if err != nil { + return nil, err + } + + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +// HTTPPost post 请求 +func HTTPPost(uri string, data string) ([]byte, error) { + body := bytes.NewBuffer([]byte(data)) + response, err := http.Post(uri, "", body) + if err != nil { + return nil, err + } + + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +// PostJSON post json 数据请求 +func PostJSON(uri string, obj interface{}) ([]byte, error) { + jsonData, err := json.Marshal(obj) + if err != nil { + return nil, err + } + jsonData = bytes.Replace(jsonData, []byte("\\u003c"), []byte("<"), -1) + jsonData = bytes.Replace(jsonData, []byte("\\u003e"), []byte(">"), -1) + jsonData = bytes.Replace(jsonData, []byte("\\u0026"), []byte("&"), -1) + body := bytes.NewBuffer(jsonData) + response, err := http.Post(uri, "application/json;charset=utf-8", body) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +// PostJSONWithRespContentType post json数据请求,且返回数据类型 +func PostJSONWithRespContentType(uri string, obj interface{}) ([]byte, string, error) { + jsonData, err := json.Marshal(obj) + if err != nil { + return nil, "", err + } + + jsonData = bytes.Replace(jsonData, []byte("\\u003c"), []byte("<"), -1) + jsonData = bytes.Replace(jsonData, []byte("\\u003e"), []byte(">"), -1) + jsonData = bytes.Replace(jsonData, []byte("\\u0026"), []byte("&"), -1) + + body := bytes.NewBuffer(jsonData) + response, err := http.Post(uri, "application/json;charset=utf-8", body) + if err != nil { + return nil, "", err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + responseData, err := ioutil.ReadAll(response.Body) + contentType := response.Header.Get("Content-Type") + return responseData, contentType, err +} + +// PostFile 上传文件 +func PostFile(fieldname, filename, uri string) ([]byte, error) { + fields := []MultipartFormField{ + { + IsFile: true, + Fieldname: fieldname, + Filename: filename, + }, + } + return PostMultipartForm(fields, uri) +} + +// MultipartFormField 保存文件或其他字段信息 +type MultipartFormField struct { + IsFile bool + Fieldname string + Value []byte + Filename string +} + +// PostMultipartForm 上传文件或其他多个字段 +func PostMultipartForm(fields []MultipartFormField, uri string) (respBody []byte, err error) { + bodyBuf := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuf) + + for _, field := range fields { + if field.IsFile { + fileWriter, e := bodyWriter.CreateFormFile(field.Fieldname, field.Filename) + if e != nil { + err = fmt.Errorf("error writing to buffer , err=%v", e) + return + } + + fh, e := os.Open(field.Filename) + if e != nil { + err = fmt.Errorf("error opening file , err=%v", e) + return + } + defer fh.Close() + + if _, err = io.Copy(fileWriter, fh); err != nil { + return + } + } else { + partWriter, e := bodyWriter.CreateFormField(field.Fieldname) + if e != nil { + err = e + return + } + valueReader := bytes.NewReader(field.Value) + if _, err = io.Copy(partWriter, valueReader); err != nil { + return + } + } + } + + contentType := bodyWriter.FormDataContentType() + bodyWriter.Close() + + resp, e := http.Post(uri, contentType, bodyBuf) + if e != nil { + err = e + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, err + } + respBody, err = ioutil.ReadAll(resp.Body) + return +} + +// PostXML perform a HTTP/POST request with XML body +func PostXML(uri string, obj interface{}) ([]byte, error) { + xmlData, err := xml.Marshal(obj) + if err != nil { + return nil, err + } + + body := bytes.NewBuffer(xmlData) + response, err := http.Post(uri, "application/xml;charset=utf-8", body) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http code error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +// httpWithTLS CA证书 +func httpWithTLS(rootCa, key string) (*http.Client, error) { + var client *http.Client + certData, err := ioutil.ReadFile(rootCa) + if err != nil { + return nil, fmt.Errorf("unable to find cert path=%s, error=%v", rootCa, err) + } + cert := pkcs12ToPem(certData, key) + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + tr := &http.Transport{ + TLSClientConfig: config, + DisableCompression: true, + } + client = &http.Client{Transport: tr} + return client, nil +} + +// pkcs12ToPem 将Pkcs12转成Pem +func pkcs12ToPem(p12 []byte, password string) tls.Certificate { + blocks, err := pkcs12.ToPEM(p12, password) + defer func() { + if x := recover(); x != nil { + log.Print(x) + } + }() + if err != nil { + panic(err) + } + var pemData []byte + for _, b := range blocks { + pemData = append(pemData, pem.EncodeToMemory(b)...) + } + cert, err := tls.X509KeyPair(pemData, pemData) + if err != nil { + panic(err) + } + return cert +} + +// PostXMLWithTLS perform a HTTP/POST request with XML body and TLS +func PostXMLWithTLS(uri string, obj interface{}, ca, key string) ([]byte, error) { + xmlData, err := xml.Marshal(obj) + if err != nil { + return nil, err + } + + body := bytes.NewBuffer(xmlData) + client, err := httpWithTLS(ca, key) + if err != nil { + return nil, err + } + response, err := client.Post(uri, "application/xml;charset=utf-8", body) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http code error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} diff --git a/wechat/util/signature.go b/wechat/util/signature.go new file mode 100644 index 0000000..2d9b3fa --- /dev/null +++ b/wechat/util/signature.go @@ -0,0 +1,18 @@ +package util + +import ( + "crypto/sha1" + "fmt" + "io" + "sort" +) + +// Signature sha1签名 +func Signature(params ...string) string { + sort.Strings(params) + h := sha1.New() + for _, s := range params { + io.WriteString(h, s) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/wechat/util/signature_test.go b/wechat/util/signature_test.go new file mode 100644 index 0000000..9aa2f7f --- /dev/null +++ b/wechat/util/signature_test.go @@ -0,0 +1,11 @@ +package util + +import "testing" + +func TestSignature(t *testing.T) { + //abc sig + abc := "a9993e364706816aba3e25717850c26c9cd0d89d" + if abc != Signature("a", "b", "c") { + t.Error("test Signature Error") + } +} diff --git a/wechat/util/string.go b/wechat/util/string.go new file mode 100644 index 0000000..8179b70 --- /dev/null +++ b/wechat/util/string.go @@ -0,0 +1,18 @@ +package util + +import ( + "math/rand" + "time" +) + +// RandomStr 随机生成字符串 +func RandomStr(length int) string { + str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + bytes := []byte(str) + result := []byte{} + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < length; i++ { + result = append(result, bytes[r.Intn(len(bytes))]) + } + return string(result) +} diff --git a/wechat/util/time.go b/wechat/util/time.go new file mode 100644 index 0000000..4deeaf9 --- /dev/null +++ b/wechat/util/time.go @@ -0,0 +1,8 @@ +package util + +import "time" + +// GetCurrTs return current timestamps +func GetCurrTs() int64 { + return time.Now().Unix() +} diff --git a/wechat/wechat.go b/wechat/wechat.go new file mode 100644 index 0000000..a3a6f72 --- /dev/null +++ b/wechat/wechat.go @@ -0,0 +1,59 @@ +package wechat + +import ( + "net/http" + "sync" + + "gitlab.batiao8.com/open/gosdk/wechat/cache" + "gitlab.batiao8.com/open/gosdk/wechat/context" + "gitlab.batiao8.com/open/gosdk/wechat/server" +) + +// Wechat struct +type Wechat struct { + Context *context.Context +} + +// Config for user +type Config struct { + AppID string + AppSecret string + Token string + EncodingAESKey string + PayMchID string //支付 - 商户 ID + PayNotifyURL string //支付 - 接受微信支付结果通知的接口地址 + PayKey string //支付 - 商户后台设置的支付 key + Cache cache.Cache +} + +// NewWechat init +func NewWechat(cfg *Config) *Wechat { + context := new(context.Context) + copyConfigToContext(cfg, context) + return &Wechat{context} +} + +func copyConfigToContext(cfg *Config, context *context.Context) { + context.AppID = cfg.AppID + context.AppSecret = cfg.AppSecret + context.Token = cfg.Token + context.EncodingAESKey = cfg.EncodingAESKey + context.PayMchID = cfg.PayMchID + context.PayKey = cfg.PayKey + context.PayNotifyURL = cfg.PayNotifyURL + context.Cache = cfg.Cache + context.SetAccessTokenLock(new(sync.RWMutex)) + context.SetJsAPITicketLock(new(sync.RWMutex)) +} + +// GetServer 消息管理 +func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *server.Server { + wc.Context.Request = req + wc.Context.Writer = writer + return server.NewServer(wc.Context) +} + +// GetAccessToken 获取access_token +func (wc *Wechat) GetAccessToken() (string, error) { + return wc.Context.GetAccessToken() +}