From a6db607812147f16ef1657515160d2a0bc57ea40 Mon Sep 17 00:00:00 2001 From: system Date: Mon, 19 Aug 2024 16:03:32 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gitea.go | 424 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 424 insertions(+) create mode 100644 gitea.go diff --git a/gitea.go b/gitea.go new file mode 100644 index 0000000..40fea5a --- /dev/null +++ b/gitea.go @@ -0,0 +1,424 @@ +package gitea + +import ( + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "model-repository/domain" + "net/http" + "regexp" + "strconv" + "time" + + "github.com/go-resty/resty/v2" + "gopkg.in/yaml.v2" +) + +const ( + DEFALUT_REVISION = "main" + TIME_FORMAT = "2006-01-02T15:04:05.000Z" + REF_VERSION_LIMIT = 20 +) + +var ( + TagList = []string{ + "audio-classification", + "automatic-speech-recognition", + "conversational", + "depth-estimation", + "document-question-answering", + "feature-extraction", + "fill-mask", + "image-classification", + "image-feature-extraction", + "image-segmentation", + "image-to-image", + "image-to-text", + "mask-generation", + "ner", + "object-detection", + "question-answering", + "sentiment-analysis", + "summarization", + "table-question-answering", + "text-classification", + "text-generation", + "text-to-audio", + "text-to-speech", + "text2text-generation", + "token-classification", + "translation", + "video-classification", + "visual-question-answering", + "vqa", + "zero-shot-audio-classification", + "zero-shot-classification", + "zero-shot-image-classification", + "zero-shot-object-detection", + "translation_XX_to_YY", + } +) + +type GiteaStorage struct { + client *resty.Client +} + +func NewGiteaStorage(baseUrl string) *GiteaStorage { + client := resty.New().SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).SetBaseURL(baseUrl).SetTimeout(5 * time.Minute) + return &GiteaStorage{ + client: client, + } +} + +func (gs *GiteaStorage) GetPipelineTag(hfModel *domain.HuggingFaceModelRequest) (*Metadata, error) { + url := fmt.Sprintf("/api/v1/repos/%s/%s/raw/README.md", hfModel.Project, hfModel.Name) + resp, err := gs.client.R(). + SetHeader("Accept", "application/json"). + SetHeader("access_token", hfModel.AccessToken). + SetQueryParams(map[string]string{ + "ref": safeSubstring(hfModel.Version, REF_VERSION_LIMIT), + }). + Get(url) + if err != nil { + return nil, err + } + var metadata Metadata + err = yaml.Unmarshal(resp.Body(), &metadata) + if err != nil { + fmt.Println("Error parsing YAML:", err) + return nil, err + } + return &metadata, nil +} + +func (gs *GiteaStorage) GetModelInfo(hfModel *domain.HuggingFaceModelRequest) (*domain.HuggingFaceModelInfo, error) { + url := "" + mir := &ModelInfoResp{} + hfmi := &domain.HuggingFaceModelInfo{} + + if hfModel.Version == "" { + url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, DEFALUT_REVISION) + } else { + url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, hfModel.Version) + } + resp, err := gs.client.R(). + SetHeader("Accept", "application/json"). + SetHeader("access_token", hfModel.AccessToken). + SetQueryParams(map[string]string{ + "recursive": "true", + }). + SetResult(mir). + Get(url) + defer resp.RawBody().Close() + + if err != nil { + return nil, err + } + if resp.StatusCode() == 404 { + return nil, fmt.Errorf(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_NotFound) + } + for _, file := range mir.Tree { + hfmi.Siblings = append(hfmi.Siblings, domain.Sibling{Rfilename: file.Path}) + } + hfmi.LastModified = time.Now().Format(TIME_FORMAT) + hfmi.CreatedAt = time.Now().Format(TIME_FORMAT) + hfmi.ModelInfoID = hfModel.Project + "/" + hfModel.Name + hfmi.ModelID = hfModel.Project + "/" + hfModel.Name + hfmi.SHA = mir.SHA + hfmi.LibraryName = "transformers" + readmeMeta, err := gs.GetPipelineTag(hfModel) + if err != nil { + return nil, err + } + hfmi.Tags = readmeMeta.Tags + if readmeMeta.PipLinetag == "" { + for _, rtag := range readmeMeta.Tags { + for _, tag := range TagList { + if rtag == tag { + hfmi.PipelineTag = tag + break + } + } + } + } else { + hfmi.PipelineTag = readmeMeta.PipLinetag + } + return hfmi, nil +} + +func (gs *GiteaStorage) GetModelFileTree(hfModel *domain.HuggingFaceModelRequest) ([]domain.HuggingFaceModelFileInfo, error) { + url := "" + mir := &ModelInfoResp{} + var tree []domain.HuggingFaceModelFileInfo + + if hfModel.Version == "" { + url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, DEFALUT_REVISION) + } else { + url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, hfModel.Version) + } + resp, err := gs.client.R(). + SetHeader("Accept", "application/json"). + SetHeader("access_token", hfModel.AccessToken). + SetQueryParams(map[string]string{ + "recursive": "true", + }). + SetResult(mir). + Get(url) + if err != nil { + return nil, err + } + defer resp.RawBody().Close() + for _, file := range mir.Tree { + if domain.IsLfsFile(file.Path) { + content, err := gs.GetFileContent(&domain.HuggingFaceFileRequest{ + Project: hfModel.Project, + Name: hfModel.Name, + File: file.Path, + Version: hfModel.Version, + AccessToken: hfModel.AccessToken, + }) + if err != nil { + return nil, err + } + size, hash := gs.ExtractLfsMetadata(content) + tree = append(tree, domain.HuggingFaceModelFileInfo{ + Type: string(file.Type), + Oid: file.SHA, + Size: size, + Path: file.Path, + Lfs: &domain.HuggingFaceModelLfsFileInfo{ + Oid: hash, + Size: size, + PointerSize: file.Size, + }, + }) + } else { + tree = append(tree, domain.HuggingFaceModelFileInfo{ + Type: string(file.Type), + Oid: file.SHA, + Size: file.Size, + Path: file.Path, + }) + } + } + return tree, nil +} + +func (gs *GiteaStorage) GetFileContent(hfi *domain.HuggingFaceFileRequest) (string, error) { + url := fmt.Sprintf("/api/v1/repos/%s/%s/raw/%s", hfi.Project, hfi.Name, hfi.File) + resp, err := gs.client.R(). + SetHeader("access_token", hfi.AccessToken). + SetQueryParams(map[string]string{ + "ref": safeSubstring(hfi.Version, REF_VERSION_LIMIT), + }). + Get(url) + if err != nil { + return "", err + } + return string(resp.Body()), nil +} + +func (gs *GiteaStorage) GetFileMetadata(hfFile *domain.HuggingFaceFileRequest) (*domain.GitFileMetadata, error) { + // TODO Implement me + gfd := &domain.GitFileMetadata{} + metaresp := &FileMetaResp{} + url := fmt.Sprintf("/api/v1/repos/%s/%s/contents/%s", hfFile.Project, hfFile.Name, hfFile.File) + resp, err := gs.client.R(). + SetHeader("Accept", "application/json"). + SetHeader("access_token", hfFile.AccessToken). + SetQueryParams(map[string]string{ + // 支持commitId(前10位),branch,tag + "ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT), + }). + SetResult(metaresp). + Get(url) + if err != nil { + return nil, err + } + defer resp.RawBody().Close() + + if domain.IsLfsFile(hfFile.File) { + decodeContent, err := base64.StdEncoding.DecodeString(metaresp.Content) + if err != nil { + return nil, err + } + match := Regexp(`size (\d+)`, string(decodeContent)) + if len(match) < 2 { + return nil, errors.New("size not found") + } + size, err := strconv.Atoi(match[1]) + if err != nil { + return nil, err + } + sha256 := Regexp("sha256:([a-f0-9]+)", string(decodeContent)) + gfd.Size = int64(size) + gfd.SHA256 = sha256[1] + + } else { + gfd.Size = int64(metaresp.Size) + } + gfd.LastCommitID = metaresp.LastCommitSHA + gfd.BlobID = metaresp.SHA + gfd.FileName = metaresp.Name + return gfd, nil + +} + +func safeSubstring(s string, n int) string { + runes := []rune(s) + if len(runes) > n { + return string(runes[:n]) + } + return s +} + +func (gs *GiteaStorage) ExtractLfsMetadata(content string) (int64, string) { + match := Regexp(`size (\d+)`, content) + if len(match) < 2 { + return 0, "" + } + size, err := strconv.Atoi(match[1]) + if err != nil { + return 0, "" + } + sha256 := Regexp("sha256:([a-f0-9]+)", content) + return int64(size), sha256[1] +} + +func (gs *GiteaStorage) GetLfsRawFile(hfFile *domain.HuggingFaceFileRequest, writer io.Writer) error { + url := fmt.Sprintf("/api/v1/repos/%s/%s/media/%s", hfFile.Project, hfFile.Name, hfFile.File) + resp, _ := gs.client.R().SetDoNotParseResponse(true). + SetHeader("Accept", "application/json"). + SetHeader("access_token", hfFile.AccessToken). + SetQueryParams(map[string]string{ + "ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT), + }). + Get(url) + defer resp.RawBody().Close() // 确保关闭流 + + // 使用 buffer 进行流式读写,提高性能 + buffer := make([]byte, 4096*1024) // 4M buffer + _, err := io.CopyBuffer(writer, resp.RawBody(), buffer) + return err +} + +func (gs *GiteaStorage) CreateRepo(repo *domain.CreateRepoInput, token string) (*domain.CreateRepoOutput, error) { + url := fmt.Sprintf("/api/v1/user/repos") + + // 构建请求体 + requestBody := map[string]interface{}{ + "auto_init": false, + "default_branch": DEFALUT_REVISION, + "name": repo.Name, + "private": true, + "template": false, + "trust_model": "default", + } + requestBodyBytes, _ := json.Marshal(requestBody) + + resp, err := gs.client.R(). + SetHeader("Authorization", fmt.Sprintf("token %s", token)). + SetHeader("Content-Type", "application/json"). + SetBody(requestBodyBytes). + Post(url) + + if err != nil { + return nil, err + } + defer resp.RawBody().Close() // 确保关闭流 + + if resp.StatusCode() == http.StatusConflict { + output := &domain.CreateRepoOutput{ + URL: domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name, + } + return output, fmt.Errorf(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_EXIST) + } + + if resp.StatusCode() != http.StatusCreated { + return nil, fmt.Errorf("failed to create repository, status code: %d", resp.StatusCode()) + } + // 解析响应体 + var responseMap map[string]interface{} + if err := json.Unmarshal(resp.Body(), &responseMap); err != nil { + return nil, fmt.Errorf("failed to parse response body: %w", err) + } + + // 提取 id 字段 + id, ok := responseMap["id"].(float64) + if !ok { + return nil, fmt.Errorf("id field not found or not a float64 in response") + } + output := &domain.CreateRepoOutput{ + URL: domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name, + Name: repo.Organization + "/" + repo.Name, + //从 resp id中获取 + ID: fmt.Sprintf("%f", id), + } + return output, nil +} + +func Regexp(reg string, input string) []string { + re := regexp.MustCompile(reg) + return re.FindStringSubmatch(input) +} + +type FileMetaResp struct { + Name string `json:"name"` + Path string `json:"path"` + SHA string `json:"sha"` + LastCommitSHA string `json:"last_commit_sha"` + Type string `json:"type"` + Size int64 `json:"size"` + Encoding string `json:"encoding"` + Content string `json:"content"` + Target string `json:"target"` + URL string `json:"url"` + HTMLURL string `json:"html_url"` + GitURL string `json:"git_url"` + DownloadURL string `json:"download_url"` + SubmoduleGitURL string `json:"submodule_git_url"` + Links Links `json:"_links"` +} + +type Links struct { + Self string `json:"self"` + Git string `json:"git"` + HTML string `json:"html"` +} + +type ModelInfoResp struct { + SHA string `json:"sha"` + URL string `json:"url"` + Tree []Tree `json:"tree"` + Truncated bool `json:"truncated"` + Page int64 `json:"page"` + TotalCount int64 `json:"total_count"` +} + +type Tree struct { + Path string `json:"path"` + Mode string `json:"mode"` + Type Type `json:"type"` + Size int64 `json:"size"` + SHA string `json:"sha"` + URL string `json:"url"` +} + +type Type string + +const ( + Blob Type = "blob" +) + +type Metadata struct { + License string `yaml:"license"` + LicenseName string `yaml:"license_name"` + LicenseLink string `yaml:"license_link"` + Tags []string `yaml:"tags"` + Datasets []string `yaml:"datasets"` + PipLinetag string `yaml:"pipeline_tag"` + Language []string `yaml:"language"` +}