restNet50-sss20240819140609/gitea.go

425 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gitea
import (
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"model-repository/domain"
"net/http"
"regexp"
"strconv"
"time"
"github.com/go-resty/resty/v2"
"gopkg.in/yaml.v2"
)
// Defaults shared by the Gitea requests in this file.
const (
// DEFALUT_REVISION (sic) is the revision used when a request carries no
// version, and the default branch given to newly created repositories.
DEFALUT_REVISION = "main"
// TIME_FORMAT is the layout used to format the timestamps reported by
// GetModelInfo.
TIME_FORMAT = "2006-01-02T15:04:05.000Z"
// REF_VERSION_LIMIT is the maximum number of characters of a version string
// that is sent as the "ref" query parameter.
REF_VERSION_LIMIT = 20
)
var (
// TagList enumerates the task names that GetModelInfo accepts as a pipeline
// tag when a README declares no explicit pipeline_tag: a README tag is only
// promoted to the pipeline tag if it appears in this list.
TagList = []string{
"audio-classification",
"automatic-speech-recognition",
"conversational",
"depth-estimation",
"document-question-answering",
"feature-extraction",
"fill-mask",
"image-classification",
"image-feature-extraction",
"image-segmentation",
"image-to-image",
"image-to-text",
"mask-generation",
"ner",
"object-detection",
"question-answering",
"sentiment-analysis",
"summarization",
"table-question-answering",
"text-classification",
"text-generation",
"text-to-audio",
"text-to-speech",
"text2text-generation",
"token-classification",
"translation",
"video-classification",
"visual-question-answering",
"vqa",
"zero-shot-audio-classification",
"zero-shot-classification",
"zero-shot-image-classification",
"zero-shot-object-detection",
"translation_XX_to_YY",
}
)
// GiteaStorage issues requests to a Gitea server's REST API through a resty
// client.
type GiteaStorage struct {
// client is the HTTP client used for every request; NewGiteaStorage
// configures it.
client *resty.Client
}
// NewGiteaStorage builds a GiteaStorage whose HTTP client resolves every
// request path against baseUrl and gives up after five minutes.
//
// NOTE(review): TLS certificate verification is switched off
// (InsecureSkipVerify), so the identity of the server is never checked.
// Confirm this is intentional before using it outside a trusted network.
func NewGiteaStorage(baseUrl string) *GiteaStorage {
	httpClient := resty.New()
	httpClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
	httpClient.SetBaseURL(baseUrl)
	httpClient.SetTimeout(5 * time.Minute)

	return &GiteaStorage{client: httpClient}
}
// GetPipelineTag downloads README.md from the repository named by hfModel and
// decodes the response body as YAML into a Metadata value.
//
// The "ref" query parameter is hfModel.Version cut to REF_VERSION_LIMIT
// characters. The HTTP status code is not inspected: only a transport error
// or a YAML decoding error is reported to the caller.
func (gs *GiteaStorage) GetPipelineTag(hfModel *domain.HuggingFaceModelRequest) (*Metadata, error) {
	readmePath := fmt.Sprintf("/api/v1/repos/%s/%s/raw/README.md", hfModel.Project, hfModel.Name)

	req := gs.client.R()
	req.SetHeader("Accept", "application/json")
	req.SetHeader("access_token", hfModel.AccessToken)
	req.SetQueryParams(map[string]string{
		"ref": safeSubstring(hfModel.Version, REF_VERSION_LIMIT),
	})

	resp, err := req.Get(readmePath)
	if err != nil {
		return nil, err
	}

	meta := new(Metadata)
	if err := yaml.Unmarshal(resp.Body(), meta); err != nil {
		fmt.Println("Error parsing YAML:", err)
		return nil, err
	}
	return meta, nil
}
// GetModelInfo builds a HuggingFace-style description of the repository named
// by hfModel.
//
// It lists the repository's git tree recursively at hfModel.Version (or
// DEFALUT_REVISION when the version is empty) and records every tree path as a
// sibling. It then reads the README metadata through GetPipelineTag to fill in
// the tags and the pipeline tag.
//
// A transport error is returned as-is. A 404 response is reported as an error
// whose text is domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_NotFound. Any error
// from GetPipelineTag is returned unchanged.
func (gs *GiteaStorage) GetModelInfo(hfModel *domain.HuggingFaceModelRequest) (*domain.HuggingFaceModelInfo, error) {
	revision := hfModel.Version
	if revision == "" {
		revision = DEFALUT_REVISION
	}
	url := fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, revision)

	mir := &ModelInfoResp{}
	resp, err := gs.client.R().
		SetHeader("Accept", "application/json").
		SetHeader("access_token", hfModel.AccessToken).
		SetQueryParams(map[string]string{
			"recursive": "true",
		}).
		SetResult(mir).
		Get(url)
	// The error must be checked before the response is touched: on a failed
	// request there may be no response, or no body, to close.
	if err != nil {
		return nil, err
	}
	if body := resp.RawBody(); body != nil {
		defer body.Close()
	}
	if resp.StatusCode() == 404 {
		return nil, errors.New(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_NotFound)
	}

	hfmi := &domain.HuggingFaceModelInfo{}
	for _, file := range mir.Tree {
		hfmi.Siblings = append(hfmi.Siblings, domain.Sibling{Rfilename: file.Path})
	}

	// Both timestamps are the time of this call, not values read from the
	// repository; taking the clock once keeps them identical.
	now := time.Now().Format(TIME_FORMAT)
	hfmi.LastModified = now
	hfmi.CreatedAt = now

	modelID := hfModel.Project + "/" + hfModel.Name
	hfmi.ModelInfoID = modelID
	hfmi.ModelID = modelID
	hfmi.SHA = mir.SHA
	hfmi.LibraryName = "transformers"

	readmeMeta, err := gs.GetPipelineTag(hfModel)
	if err != nil {
		return nil, err
	}
	hfmi.Tags = readmeMeta.Tags

	if readmeMeta.PipLinetag != "" {
		hfmi.PipelineTag = readmeMeta.PipLinetag
		return hfmi, nil
	}

	// No explicit pipeline_tag: fall back to the README tags that are listed
	// in TagList. The break only leaves the TagList scan, so when several
	// README tags are recognised the last one wins.
	for _, rtag := range readmeMeta.Tags {
		for _, tag := range TagList {
			if rtag == tag {
				hfmi.PipelineTag = tag
				break
			}
		}
	}
	return hfmi, nil
}
// GetModelFileTree lists every entry of the repository's git tree at
// hfModel.Version (DEFALUT_REVISION when the version is empty) and converts
// each one to a domain.HuggingFaceModelFileInfo.
//
// For paths that domain.IsLfsFile accepts, the file content is fetched with
// GetFileContent and parsed with ExtractLfsMetadata. The entry then reports
// the size and hash taken from that content, and keeps the size from the tree
// listing as the LFS pointer size. Any fetch error aborts the whole listing.
func (gs *GiteaStorage) GetModelFileTree(hfModel *domain.HuggingFaceModelRequest) ([]domain.HuggingFaceModelFileInfo, error) {
	revision := hfModel.Version
	if revision == "" {
		revision = DEFALUT_REVISION
	}
	treeURL := fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, revision)

	listing := &ModelInfoResp{}
	req := gs.client.R()
	req.SetHeader("Accept", "application/json")
	req.SetHeader("access_token", hfModel.AccessToken)
	req.SetQueryParams(map[string]string{
		"recursive": "true",
	})
	req.SetResult(listing)

	resp, err := req.Get(treeURL)
	if err != nil {
		return nil, err
	}
	defer resp.RawBody().Close()

	var entries []domain.HuggingFaceModelFileInfo
	for _, node := range listing.Tree {
		entry := domain.HuggingFaceModelFileInfo{
			Type: string(node.Type),
			Oid:  node.SHA,
			Size: node.Size,
			Path: node.Path,
		}

		if domain.IsLfsFile(node.Path) {
			// The content is requested with the caller's version as given,
			// even when it is empty.
			pointer, err := gs.GetFileContent(&domain.HuggingFaceFileRequest{
				Project:     hfModel.Project,
				Name:        hfModel.Name,
				File:        node.Path,
				Version:     hfModel.Version,
				AccessToken: hfModel.AccessToken,
			})
			if err != nil {
				return nil, err
			}

			realSize, oid := gs.ExtractLfsMetadata(pointer)
			entry.Size = realSize
			entry.Lfs = &domain.HuggingFaceModelLfsFileInfo{
				Oid:         oid,
				Size:        realSize,
				PointerSize: node.Size,
			}
		}

		entries = append(entries, entry)
	}
	return entries, nil
}
// GetFileContent fetches hfi.File from the repository's raw endpoint and
// returns the response body as a string.
//
// The "ref" query parameter is hfi.Version cut to REF_VERSION_LIMIT
// characters. The HTTP status code is not inspected; only a transport error is
// reported.
func (gs *GiteaStorage) GetFileContent(hfi *domain.HuggingFaceFileRequest) (string, error) {
	rawPath := fmt.Sprintf("/api/v1/repos/%s/%s/raw/%s", hfi.Project, hfi.Name, hfi.File)

	req := gs.client.R()
	req.SetHeader("access_token", hfi.AccessToken)
	req.SetQueryParams(map[string]string{
		"ref": safeSubstring(hfi.Version, REF_VERSION_LIMIT),
	})

	resp, err := req.Get(rawPath)
	if err != nil {
		return "", err
	}

	body := resp.Body()
	return string(body), nil
}
// GetFileMetadata reads the contents-API entry for hfFile.File and converts it
// to a domain.GitFileMetadata.
//
// For paths that domain.IsLfsFile accepts, the base64 content of the entry is
// decoded and parsed for a "size <n>" line and a "sha256:<hex>" oid; those
// values become the reported size and SHA256. For every other path the size
// comes straight from the API response.
//
// Errors: a transport error, a base64 decoding error, "size not found" or
// "sha256 not found" when the decoded content lacks the field, or a parse
// error when the size does not fit in an int64. The HTTP status code is not
// inspected.
func (gs *GiteaStorage) GetFileMetadata(hfFile *domain.HuggingFaceFileRequest) (*domain.GitFileMetadata, error) {
	gfd := &domain.GitFileMetadata{}
	metaresp := &FileMetaResp{}
	url := fmt.Sprintf("/api/v1/repos/%s/%s/contents/%s", hfFile.Project, hfFile.Name, hfFile.File)
	resp, err := gs.client.R().
		SetHeader("Accept", "application/json").
		SetHeader("access_token", hfFile.AccessToken).
		SetQueryParams(map[string]string{
			// The ref may be a commit ID prefix, a branch or a tag; it is cut
			// to REF_VERSION_LIMIT characters before being sent.
			"ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT),
		}).
		SetResult(metaresp).
		Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.RawBody().Close()

	if domain.IsLfsFile(hfFile.File) {
		decodeContent, err := base64.StdEncoding.DecodeString(metaresp.Content)
		if err != nil {
			return nil, err
		}
		pointer := string(decodeContent)

		sizeMatch := Regexp(`size (\d+)`, pointer)
		if len(sizeMatch) < 2 {
			return nil, errors.New("size not found")
		}
		// ParseInt with a 64-bit width keeps multi-gigabyte sizes intact on
		// platforms where int is 32 bits wide.
		size, err := strconv.ParseInt(sizeMatch[1], 10, 64)
		if err != nil {
			return nil, err
		}

		// Content without a sha256 oid must yield an error, not an
		// out-of-range index.
		oidMatch := Regexp("sha256:([a-f0-9]+)", pointer)
		if len(oidMatch) < 2 {
			return nil, errors.New("sha256 not found")
		}

		gfd.Size = size
		gfd.SHA256 = oidMatch[1]
	} else {
		gfd.Size = metaresp.Size
	}

	gfd.LastCommitID = metaresp.LastCommitSHA
	gfd.BlobID = metaresp.SHA
	gfd.FileName = metaresp.Name
	return gfd, nil
}
// safeSubstring returns s limited to its first n characters. The cut is made
// on rune boundaries, so a multi-byte character is never split. A string that
// already has n characters or fewer is returned unchanged.
func safeSubstring(s string, n int) string {
	chars := []rune(s)
	if len(chars) <= n {
		return s
	}
	return string(chars[:n])
}
// ExtractLfsMetadata parses content for a "size <n>" line and a
// "sha256:<hex>" oid and returns the size and the hex hash.
//
// It returns (0, "") when either field is missing or when the size does not
// fit in an int64, so callers never receive a half-populated result.
func (gs *GiteaStorage) ExtractLfsMetadata(content string) (int64, string) {
	sizeMatch := Regexp(`size (\d+)`, content)
	if len(sizeMatch) < 2 {
		return 0, ""
	}
	// ParseInt with a 64-bit width keeps multi-gigabyte sizes intact on
	// platforms where int is 32 bits wide.
	size, err := strconv.ParseInt(sizeMatch[1], 10, 64)
	if err != nil {
		return 0, ""
	}

	// Content without a sha256 oid must not cause an out-of-range index.
	oidMatch := Regexp("sha256:([a-f0-9]+)", content)
	if len(oidMatch) < 2 {
		return 0, ""
	}
	return size, oidMatch[1]
}
// GetLfsRawFile streams hfFile.File from the repository's media endpoint into
// writer without buffering the whole file in memory.
//
// The "ref" query parameter is hfFile.Version cut to REF_VERSION_LIMIT
// characters. It returns the transport error when the request fails, an error
// naming the status code when the server answers with a status of 400 or
// above (so an error payload is never written out as file content), or the
// error produced while copying.
func (gs *GiteaStorage) GetLfsRawFile(hfFile *domain.HuggingFaceFileRequest, writer io.Writer) error {
	url := fmt.Sprintf("/api/v1/repos/%s/%s/media/%s", hfFile.Project, hfFile.Name, hfFile.File)
	resp, err := gs.client.R().SetDoNotParseResponse(true).
		SetHeader("Accept", "application/json").
		SetHeader("access_token", hfFile.AccessToken).
		SetQueryParams(map[string]string{
			"ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT),
		}).
		Get(url)
	// A failed request may come back without a response or body, so the
	// error has to be handled before anything is closed or read.
	if err != nil {
		return err
	}
	body := resp.RawBody()
	if body == nil {
		return errors.New("empty response body")
	}
	defer body.Close() // make sure the stream is closed

	if code := resp.StatusCode(); code >= http.StatusBadRequest {
		return fmt.Errorf("fetching LFS file %q: unexpected status code %d", hfFile.File, code)
	}

	// Copy through an explicit buffer so the transfer is streamed.
	buffer := make([]byte, 4096*1024) // 4 MiB buffer
	_, err = io.CopyBuffer(writer, body, buffer)
	return err
}
// CreateRepo creates a private repository named repo.Name for the user that
// owns token, with DEFALUT_REVISION as its default branch and no initial
// commit.
//
// On success the output carries the repository URL, its
// "<organization>/<name>" label and the numeric ID from the response,
// rendered as a plain decimal integer.
//
// On a 409 Conflict response it returns BOTH an output holding only the URL
// and an error whose text is domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_EXIST.
// Any status other than 201 Created, an unreadable response body, or a missing
// or non-numeric "id" field produces a nil output and an error.
func (gs *GiteaStorage) CreateRepo(repo *domain.CreateRepoInput, token string) (*domain.CreateRepoOutput, error) {
	const url = "/api/v1/user/repos"

	// Build the request body.
	requestBody := map[string]interface{}{
		"auto_init":      false,
		"default_branch": DEFALUT_REVISION,
		"name":           repo.Name,
		"private":        true,
		"template":       false,
		"trust_model":    "default",
	}
	requestBodyBytes, err := json.Marshal(requestBody)
	if err != nil {
		return nil, fmt.Errorf("encoding create-repository request: %w", err)
	}

	resp, err := gs.client.R().
		SetHeader("Authorization", "token "+token).
		SetHeader("Content-Type", "application/json").
		SetBody(requestBodyBytes).
		Post(url)
	if err != nil {
		return nil, err
	}
	defer resp.RawBody().Close() // make sure the stream is closed

	repoURL := domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name

	if resp.StatusCode() == http.StatusConflict {
		// The URL is still returned so callers can point at the repository
		// that already exists.
		output := &domain.CreateRepoOutput{
			URL: repoURL,
		}
		return output, errors.New(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_EXIST)
	}
	if resp.StatusCode() != http.StatusCreated {
		return nil, fmt.Errorf("failed to create repository, status code: %d", resp.StatusCode())
	}

	// Parse the response body.
	var responseMap map[string]interface{}
	if err := json.Unmarshal(resp.Body(), &responseMap); err != nil {
		return nil, fmt.Errorf("failed to parse response body: %w", err)
	}

	// Extract the id field; encoding/json decodes JSON numbers held in an
	// interface{} as float64.
	id, ok := responseMap["id"].(float64)
	if !ok {
		return nil, fmt.Errorf("id field not found or not a float64 in response")
	}

	output := &domain.CreateRepoOutput{
		URL:  repoURL,
		Name: repo.Organization + "/" + repo.Name,
		// Taken from the response id. Formatting as an integer avoids the
		// "42.000000" that the %f verb produces.
		ID: strconv.FormatInt(int64(id), 10),
	}
	return output, nil
}
// Regexp compiles reg and returns the leftmost match in input together with
// its capture groups: element 0 is the whole match and element i is the i-th
// group. It returns nil when nothing matches.
//
// The pattern is compiled on every call, and an invalid pattern panics.
func Regexp(reg string, input string) []string {
	return regexp.MustCompile(reg).FindStringSubmatch(input)
}
// FileMetaResp is the JSON document that GetFileMetadata decodes from the
// repository contents endpoint. GetFileMetadata reads Name, SHA,
// LastCommitSHA, Size and Content; the remaining fields are decoded but not
// used in this file.
type FileMetaResp struct {
Name string `json:"name"`
Path string `json:"path"`
SHA string `json:"sha"`
LastCommitSHA string `json:"last_commit_sha"`
Type string `json:"type"`
Size int64 `json:"size"`
Encoding string `json:"encoding"`
// Content is base64-decoded by GetFileMetadata for LFS files.
Content string `json:"content"`
Target string `json:"target"`
URL string `json:"url"`
HTMLURL string `json:"html_url"`
GitURL string `json:"git_url"`
DownloadURL string `json:"download_url"`
SubmoduleGitURL string `json:"submodule_git_url"`
Links Links `json:"_links"`
}
// Links holds the "_links" object of a FileMetaResp.
type Links struct {
Self string `json:"self"`
Git string `json:"git"`
HTML string `json:"html"`
}
// ModelInfoResp is the JSON document decoded from the git trees endpoint by
// GetModelInfo and GetModelFileTree.
type ModelInfoResp struct {
// SHA is copied into the model info returned by GetModelInfo.
SHA string `json:"sha"`
URL string `json:"url"`
// Tree lists the entries of the requested tree.
Tree []Tree `json:"tree"`
// NOTE(review): Truncated, Page and TotalCount are decoded but never read
// in this file — presumably a truncated listing is returned as if complete;
// confirm against the API's paging behaviour.
Truncated bool `json:"truncated"`
Page int64 `json:"page"`
TotalCount int64 `json:"total_count"`
}
// Tree is one entry of a ModelInfoResp tree listing.
type Tree struct {
// Path is reported as the file name of the entry.
Path string `json:"path"`
Mode string `json:"mode"`
Type Type `json:"type"`
// Size is the size from the listing; for LFS files GetModelFileTree reports
// it as the pointer size.
Size int64 `json:"size"`
SHA string `json:"sha"`
URL string `json:"url"`
}
// Type is the kind of a Tree entry as it appears in the "type" JSON field.
type Type string
const (
// Blob is the Type value "blob".
Blob Type = "blob"
)
// Metadata is the set of YAML keys that GetPipelineTag decodes from a
// repository's README.md.
type Metadata struct {
License string `yaml:"license"`
LicenseName string `yaml:"license_name"`
LicenseLink string `yaml:"license_link"`
// Tags is copied into the model info and searched for a pipeline tag when
// PipLinetag is empty.
Tags []string `yaml:"tags"`
Datasets []string `yaml:"datasets"`
// PipLinetag holds the "pipeline_tag" key; when non-empty GetModelInfo uses
// it as the pipeline tag directly.
PipLinetag string `yaml:"pipeline_tag"`
Language []string `yaml:"language"`
}