diff --git a/gitea.go b/gitea.go deleted file mode 100644 index 40fea5a..0000000 --- a/gitea.go +++ /dev/null @@ -1,424 +0,0 @@ -package gitea - -import ( - "crypto/tls" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "model-repository/domain" - "net/http" - "regexp" - "strconv" - "time" - - "github.com/go-resty/resty/v2" - "gopkg.in/yaml.v2" -) - -const ( - DEFALUT_REVISION = "main" - TIME_FORMAT = "2006-01-02T15:04:05.000Z" - REF_VERSION_LIMIT = 20 -) - -var ( - TagList = []string{ - "audio-classification", - "automatic-speech-recognition", - "conversational", - "depth-estimation", - "document-question-answering", - "feature-extraction", - "fill-mask", - "image-classification", - "image-feature-extraction", - "image-segmentation", - "image-to-image", - "image-to-text", - "mask-generation", - "ner", - "object-detection", - "question-answering", - "sentiment-analysis", - "summarization", - "table-question-answering", - "text-classification", - "text-generation", - "text-to-audio", - "text-to-speech", - "text2text-generation", - "token-classification", - "translation", - "video-classification", - "visual-question-answering", - "vqa", - "zero-shot-audio-classification", - "zero-shot-classification", - "zero-shot-image-classification", - "zero-shot-object-detection", - "translation_XX_to_YY", - } -) - -type GiteaStorage struct { - client *resty.Client -} - -func NewGiteaStorage(baseUrl string) *GiteaStorage { - client := resty.New().SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).SetBaseURL(baseUrl).SetTimeout(5 * time.Minute) - return &GiteaStorage{ - client: client, - } -} - -func (gs *GiteaStorage) GetPipelineTag(hfModel *domain.HuggingFaceModelRequest) (*Metadata, error) { - url := fmt.Sprintf("/api/v1/repos/%s/%s/raw/README.md", hfModel.Project, hfModel.Name) - resp, err := gs.client.R(). - SetHeader("Accept", "application/json"). - SetHeader("access_token", hfModel.AccessToken). - SetQueryParams(map[string]string{ - "ref": safeSubstring(hfModel.Version, REF_VERSION_LIMIT), - }). - Get(url) - if err != nil { - return nil, err - } - var metadata Metadata - err = yaml.Unmarshal(resp.Body(), &metadata) - if err != nil { - fmt.Println("Error parsing YAML:", err) - return nil, err - } - return &metadata, nil -} - -func (gs *GiteaStorage) GetModelInfo(hfModel *domain.HuggingFaceModelRequest) (*domain.HuggingFaceModelInfo, error) { - url := "" - mir := &ModelInfoResp{} - hfmi := &domain.HuggingFaceModelInfo{} - - if hfModel.Version == "" { - url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, DEFALUT_REVISION) - } else { - url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, hfModel.Version) - } - resp, err := gs.client.R(). - SetHeader("Accept", "application/json"). - SetHeader("access_token", hfModel.AccessToken). - SetQueryParams(map[string]string{ - "recursive": "true", - }). - SetResult(mir). - Get(url) - defer resp.RawBody().Close() - - if err != nil { - return nil, err - } - if resp.StatusCode() == 404 { - return nil, fmt.Errorf(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_NotFound) - } - for _, file := range mir.Tree { - hfmi.Siblings = append(hfmi.Siblings, domain.Sibling{Rfilename: file.Path}) - } - hfmi.LastModified = time.Now().Format(TIME_FORMAT) - hfmi.CreatedAt = time.Now().Format(TIME_FORMAT) - hfmi.ModelInfoID = hfModel.Project + "/" + hfModel.Name - hfmi.ModelID = hfModel.Project + "/" + hfModel.Name - hfmi.SHA = mir.SHA - hfmi.LibraryName = "transformers" - readmeMeta, err := gs.GetPipelineTag(hfModel) - if err != nil { - return nil, err - } - hfmi.Tags = readmeMeta.Tags - if readmeMeta.PipLinetag == "" { - for _, rtag := range readmeMeta.Tags { - for _, tag := range TagList { - if rtag == tag { - hfmi.PipelineTag = tag - break - } - } - } - } else { - hfmi.PipelineTag = readmeMeta.PipLinetag - } - return hfmi, nil -} - -func (gs *GiteaStorage) GetModelFileTree(hfModel *domain.HuggingFaceModelRequest) ([]domain.HuggingFaceModelFileInfo, error) { - url := "" - mir := &ModelInfoResp{} - var tree []domain.HuggingFaceModelFileInfo - - if hfModel.Version == "" { - url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, DEFALUT_REVISION) - } else { - url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, hfModel.Version) - } - resp, err := gs.client.R(). - SetHeader("Accept", "application/json"). - SetHeader("access_token", hfModel.AccessToken). - SetQueryParams(map[string]string{ - "recursive": "true", - }). - SetResult(mir). - Get(url) - if err != nil { - return nil, err - } - defer resp.RawBody().Close() - for _, file := range mir.Tree { - if domain.IsLfsFile(file.Path) { - content, err := gs.GetFileContent(&domain.HuggingFaceFileRequest{ - Project: hfModel.Project, - Name: hfModel.Name, - File: file.Path, - Version: hfModel.Version, - AccessToken: hfModel.AccessToken, - }) - if err != nil { - return nil, err - } - size, hash := gs.ExtractLfsMetadata(content) - tree = append(tree, domain.HuggingFaceModelFileInfo{ - Type: string(file.Type), - Oid: file.SHA, - Size: size, - Path: file.Path, - Lfs: &domain.HuggingFaceModelLfsFileInfo{ - Oid: hash, - Size: size, - PointerSize: file.Size, - }, - }) - } else { - tree = append(tree, domain.HuggingFaceModelFileInfo{ - Type: string(file.Type), - Oid: file.SHA, - Size: file.Size, - Path: file.Path, - }) - } - } - return tree, nil -} - -func (gs *GiteaStorage) GetFileContent(hfi *domain.HuggingFaceFileRequest) (string, error) { - url := fmt.Sprintf("/api/v1/repos/%s/%s/raw/%s", hfi.Project, hfi.Name, hfi.File) - resp, err := gs.client.R(). - SetHeader("access_token", hfi.AccessToken). - SetQueryParams(map[string]string{ - "ref": safeSubstring(hfi.Version, REF_VERSION_LIMIT), - }). - Get(url) - if err != nil { - return "", err - } - return string(resp.Body()), nil -} - -func (gs *GiteaStorage) GetFileMetadata(hfFile *domain.HuggingFaceFileRequest) (*domain.GitFileMetadata, error) { - // TODO Implement me - gfd := &domain.GitFileMetadata{} - metaresp := &FileMetaResp{} - url := fmt.Sprintf("/api/v1/repos/%s/%s/contents/%s", hfFile.Project, hfFile.Name, hfFile.File) - resp, err := gs.client.R(). - SetHeader("Accept", "application/json"). - SetHeader("access_token", hfFile.AccessToken). - SetQueryParams(map[string]string{ - // 支持commitId(前10位),branch,tag - "ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT), - }). - SetResult(metaresp). - Get(url) - if err != nil { - return nil, err - } - defer resp.RawBody().Close() - - if domain.IsLfsFile(hfFile.File) { - decodeContent, err := base64.StdEncoding.DecodeString(metaresp.Content) - if err != nil { - return nil, err - } - match := Regexp(`size (\d+)`, string(decodeContent)) - if len(match) < 2 { - return nil, errors.New("size not found") - } - size, err := strconv.Atoi(match[1]) - if err != nil { - return nil, err - } - sha256 := Regexp("sha256:([a-f0-9]+)", string(decodeContent)) - gfd.Size = int64(size) - gfd.SHA256 = sha256[1] - - } else { - gfd.Size = int64(metaresp.Size) - } - gfd.LastCommitID = metaresp.LastCommitSHA - gfd.BlobID = metaresp.SHA - gfd.FileName = metaresp.Name - return gfd, nil - -} - -func safeSubstring(s string, n int) string { - runes := []rune(s) - if len(runes) > n { - return string(runes[:n]) - } - return s -} - -func (gs *GiteaStorage) ExtractLfsMetadata(content string) (int64, string) { - match := Regexp(`size (\d+)`, content) - if len(match) < 2 { - return 0, "" - } - size, err := strconv.Atoi(match[1]) - if err != nil { - return 0, "" - } - sha256 := Regexp("sha256:([a-f0-9]+)", content) - return int64(size), sha256[1] -} - -func (gs *GiteaStorage) GetLfsRawFile(hfFile *domain.HuggingFaceFileRequest, writer io.Writer) error { - url := fmt.Sprintf("/api/v1/repos/%s/%s/media/%s", hfFile.Project, hfFile.Name, hfFile.File) - resp, _ := gs.client.R().SetDoNotParseResponse(true). - SetHeader("Accept", "application/json"). - SetHeader("access_token", hfFile.AccessToken). - SetQueryParams(map[string]string{ - "ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT), - }). - Get(url) - defer resp.RawBody().Close() // 确保关闭流 - - // 使用 buffer 进行流式读写,提高性能 - buffer := make([]byte, 4096*1024) // 4M buffer - _, err := io.CopyBuffer(writer, resp.RawBody(), buffer) - return err -} - -func (gs *GiteaStorage) CreateRepo(repo *domain.CreateRepoInput, token string) (*domain.CreateRepoOutput, error) { - url := fmt.Sprintf("/api/v1/user/repos") - - // 构建请求体 - requestBody := map[string]interface{}{ - "auto_init": false, - "default_branch": DEFALUT_REVISION, - "name": repo.Name, - "private": true, - "template": false, - "trust_model": "default", - } - requestBodyBytes, _ := json.Marshal(requestBody) - - resp, err := gs.client.R(). - SetHeader("Authorization", fmt.Sprintf("token %s", token)). - SetHeader("Content-Type", "application/json"). - SetBody(requestBodyBytes). - Post(url) - - if err != nil { - return nil, err - } - defer resp.RawBody().Close() // 确保关闭流 - - if resp.StatusCode() == http.StatusConflict { - output := &domain.CreateRepoOutput{ - URL: domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name, - } - return output, fmt.Errorf(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_EXIST) - } - - if resp.StatusCode() != http.StatusCreated { - return nil, fmt.Errorf("failed to create repository, status code: %d", resp.StatusCode()) - } - // 解析响应体 - var responseMap map[string]interface{} - if err := json.Unmarshal(resp.Body(), &responseMap); err != nil { - return nil, fmt.Errorf("failed to parse response body: %w", err) - } - - // 提取 id 字段 - id, ok := responseMap["id"].(float64) - if !ok { - return nil, fmt.Errorf("id field not found or not a float64 in response") - } - output := &domain.CreateRepoOutput{ - URL: domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name, - Name: repo.Organization + "/" + repo.Name, - //从 resp id中获取 - ID: fmt.Sprintf("%f", id), - } - return output, nil -} - -func Regexp(reg string, input string) []string { - re := regexp.MustCompile(reg) - return re.FindStringSubmatch(input) -} - -type FileMetaResp struct { - Name string `json:"name"` - Path string `json:"path"` - SHA string `json:"sha"` - LastCommitSHA string `json:"last_commit_sha"` - Type string `json:"type"` - Size int64 `json:"size"` - Encoding string `json:"encoding"` - Content string `json:"content"` - Target string `json:"target"` - URL string `json:"url"` - HTMLURL string `json:"html_url"` - GitURL string `json:"git_url"` - DownloadURL string `json:"download_url"` - SubmoduleGitURL string `json:"submodule_git_url"` - Links Links `json:"_links"` -} - -type Links struct { - Self string `json:"self"` - Git string `json:"git"` - HTML string `json:"html"` -} - -type ModelInfoResp struct { - SHA string `json:"sha"` - URL string `json:"url"` - Tree []Tree `json:"tree"` - Truncated bool `json:"truncated"` - Page int64 `json:"page"` - TotalCount int64 `json:"total_count"` -} - -type Tree struct { - Path string `json:"path"` - Mode string `json:"mode"` - Type Type `json:"type"` - Size int64 `json:"size"` - SHA string `json:"sha"` - URL string `json:"url"` -} - -type Type string - -const ( - Blob Type = "blob" -) - -type Metadata struct { - License string `yaml:"license"` - LicenseName string `yaml:"license_name"` - LicenseLink string `yaml:"license_link"` - Tags []string `yaml:"tags"` - Datasets []string `yaml:"datasets"` - PipLinetag string `yaml:"pipeline_tag"` - Language []string `yaml:"language"` -}