425 lines
11 KiB
Go
425 lines
11 KiB
Go
package gitea
|
||
|
||
import (
|
||
"crypto/tls"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"model-repository/domain"
|
||
"net/http"
|
||
"regexp"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/go-resty/resty/v2"
|
||
"gopkg.in/yaml.v2"
|
||
)
|
||
|
||
// DEFALUT_REVISION is the branch used when a request carries no revision.
// The misspelled name is kept because other files may reference it.
const DEFALUT_REVISION = "main"

// TIME_FORMAT is a Go reference-time layout producing millisecond timestamps
// that end in a literal "Z". The layout does not convert time zones, so the
// formatted time should already be in UTC for the "Z" to be truthful.
const TIME_FORMAT = "2006-01-02T15:04:05.000Z"

// REF_VERSION_LIMIT is the maximum number of runes of a requested version that
// is forwarded to Gitea as the "ref" query parameter.
const REF_VERSION_LIMIT = 20
|
||
|
||
// TagList is the set of Hugging Face pipeline task names that GetModelInfo
// accepts when it has to infer a pipeline tag from a README's tags.
var TagList = []string{
	"audio-classification", "automatic-speech-recognition",
	"conversational",
	"depth-estimation", "document-question-answering",
	"feature-extraction", "fill-mask",
	"image-classification", "image-feature-extraction", "image-segmentation",
	"image-to-image", "image-to-text",
	"mask-generation",
	"ner",
	"object-detection",
	"question-answering",
	"sentiment-analysis", "summarization",
	"table-question-answering", "text-classification", "text-generation",
	"text-to-audio", "text-to-speech", "text2text-generation",
	"token-classification", "translation",
	"video-classification", "visual-question-answering", "vqa",
	"zero-shot-audio-classification", "zero-shot-classification",
	"zero-shot-image-classification", "zero-shot-object-detection",
	// Appears to be a placeholder for translation_<src>_to_<dst> tasks — TODO confirm.
	"translation_XX_to_YY",
}
|
||
|
||
type GiteaStorage struct {
|
||
client *resty.Client
|
||
}
|
||
|
||
func NewGiteaStorage(baseUrl string) *GiteaStorage {
|
||
client := resty.New().SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).SetBaseURL(baseUrl).SetTimeout(5 * time.Minute)
|
||
return &GiteaStorage{
|
||
client: client,
|
||
}
|
||
}
|
||
|
||
func (gs *GiteaStorage) GetPipelineTag(hfModel *domain.HuggingFaceModelRequest) (*Metadata, error) {
|
||
url := fmt.Sprintf("/api/v1/repos/%s/%s/raw/README.md", hfModel.Project, hfModel.Name)
|
||
resp, err := gs.client.R().
|
||
SetHeader("Accept", "application/json").
|
||
SetHeader("access_token", hfModel.AccessToken).
|
||
SetQueryParams(map[string]string{
|
||
"ref": safeSubstring(hfModel.Version, REF_VERSION_LIMIT),
|
||
}).
|
||
Get(url)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
var metadata Metadata
|
||
err = yaml.Unmarshal(resp.Body(), &metadata)
|
||
if err != nil {
|
||
fmt.Println("Error parsing YAML:", err)
|
||
return nil, err
|
||
}
|
||
return &metadata, nil
|
||
}
|
||
|
||
func (gs *GiteaStorage) GetModelInfo(hfModel *domain.HuggingFaceModelRequest) (*domain.HuggingFaceModelInfo, error) {
|
||
url := ""
|
||
mir := &ModelInfoResp{}
|
||
hfmi := &domain.HuggingFaceModelInfo{}
|
||
|
||
if hfModel.Version == "" {
|
||
url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, DEFALUT_REVISION)
|
||
} else {
|
||
url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, hfModel.Version)
|
||
}
|
||
resp, err := gs.client.R().
|
||
SetHeader("Accept", "application/json").
|
||
SetHeader("access_token", hfModel.AccessToken).
|
||
SetQueryParams(map[string]string{
|
||
"recursive": "true",
|
||
}).
|
||
SetResult(mir).
|
||
Get(url)
|
||
defer resp.RawBody().Close()
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if resp.StatusCode() == 404 {
|
||
return nil, fmt.Errorf(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_NotFound)
|
||
}
|
||
for _, file := range mir.Tree {
|
||
hfmi.Siblings = append(hfmi.Siblings, domain.Sibling{Rfilename: file.Path})
|
||
}
|
||
hfmi.LastModified = time.Now().Format(TIME_FORMAT)
|
||
hfmi.CreatedAt = time.Now().Format(TIME_FORMAT)
|
||
hfmi.ModelInfoID = hfModel.Project + "/" + hfModel.Name
|
||
hfmi.ModelID = hfModel.Project + "/" + hfModel.Name
|
||
hfmi.SHA = mir.SHA
|
||
hfmi.LibraryName = "transformers"
|
||
readmeMeta, err := gs.GetPipelineTag(hfModel)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
hfmi.Tags = readmeMeta.Tags
|
||
if readmeMeta.PipLinetag == "" {
|
||
for _, rtag := range readmeMeta.Tags {
|
||
for _, tag := range TagList {
|
||
if rtag == tag {
|
||
hfmi.PipelineTag = tag
|
||
break
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
hfmi.PipelineTag = readmeMeta.PipLinetag
|
||
}
|
||
return hfmi, nil
|
||
}
|
||
|
||
func (gs *GiteaStorage) GetModelFileTree(hfModel *domain.HuggingFaceModelRequest) ([]domain.HuggingFaceModelFileInfo, error) {
|
||
url := ""
|
||
mir := &ModelInfoResp{}
|
||
var tree []domain.HuggingFaceModelFileInfo
|
||
|
||
if hfModel.Version == "" {
|
||
url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, DEFALUT_REVISION)
|
||
} else {
|
||
url = fmt.Sprintf("/api/v1/repos/%s/%s/git/trees/%s", hfModel.Project, hfModel.Name, hfModel.Version)
|
||
}
|
||
resp, err := gs.client.R().
|
||
SetHeader("Accept", "application/json").
|
||
SetHeader("access_token", hfModel.AccessToken).
|
||
SetQueryParams(map[string]string{
|
||
"recursive": "true",
|
||
}).
|
||
SetResult(mir).
|
||
Get(url)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.RawBody().Close()
|
||
for _, file := range mir.Tree {
|
||
if domain.IsLfsFile(file.Path) {
|
||
content, err := gs.GetFileContent(&domain.HuggingFaceFileRequest{
|
||
Project: hfModel.Project,
|
||
Name: hfModel.Name,
|
||
File: file.Path,
|
||
Version: hfModel.Version,
|
||
AccessToken: hfModel.AccessToken,
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
size, hash := gs.ExtractLfsMetadata(content)
|
||
tree = append(tree, domain.HuggingFaceModelFileInfo{
|
||
Type: string(file.Type),
|
||
Oid: file.SHA,
|
||
Size: size,
|
||
Path: file.Path,
|
||
Lfs: &domain.HuggingFaceModelLfsFileInfo{
|
||
Oid: hash,
|
||
Size: size,
|
||
PointerSize: file.Size,
|
||
},
|
||
})
|
||
} else {
|
||
tree = append(tree, domain.HuggingFaceModelFileInfo{
|
||
Type: string(file.Type),
|
||
Oid: file.SHA,
|
||
Size: file.Size,
|
||
Path: file.Path,
|
||
})
|
||
}
|
||
}
|
||
return tree, nil
|
||
}
|
||
|
||
func (gs *GiteaStorage) GetFileContent(hfi *domain.HuggingFaceFileRequest) (string, error) {
|
||
url := fmt.Sprintf("/api/v1/repos/%s/%s/raw/%s", hfi.Project, hfi.Name, hfi.File)
|
||
resp, err := gs.client.R().
|
||
SetHeader("access_token", hfi.AccessToken).
|
||
SetQueryParams(map[string]string{
|
||
"ref": safeSubstring(hfi.Version, REF_VERSION_LIMIT),
|
||
}).
|
||
Get(url)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return string(resp.Body()), nil
|
||
}
|
||
|
||
func (gs *GiteaStorage) GetFileMetadata(hfFile *domain.HuggingFaceFileRequest) (*domain.GitFileMetadata, error) {
|
||
// TODO Implement me
|
||
gfd := &domain.GitFileMetadata{}
|
||
metaresp := &FileMetaResp{}
|
||
url := fmt.Sprintf("/api/v1/repos/%s/%s/contents/%s", hfFile.Project, hfFile.Name, hfFile.File)
|
||
resp, err := gs.client.R().
|
||
SetHeader("Accept", "application/json").
|
||
SetHeader("access_token", hfFile.AccessToken).
|
||
SetQueryParams(map[string]string{
|
||
// 支持commitId(前10位),branch,tag
|
||
"ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT),
|
||
}).
|
||
SetResult(metaresp).
|
||
Get(url)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.RawBody().Close()
|
||
|
||
if domain.IsLfsFile(hfFile.File) {
|
||
decodeContent, err := base64.StdEncoding.DecodeString(metaresp.Content)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
match := Regexp(`size (\d+)`, string(decodeContent))
|
||
if len(match) < 2 {
|
||
return nil, errors.New("size not found")
|
||
}
|
||
size, err := strconv.Atoi(match[1])
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
sha256 := Regexp("sha256:([a-f0-9]+)", string(decodeContent))
|
||
gfd.Size = int64(size)
|
||
gfd.SHA256 = sha256[1]
|
||
|
||
} else {
|
||
gfd.Size = int64(metaresp.Size)
|
||
}
|
||
gfd.LastCommitID = metaresp.LastCommitSHA
|
||
gfd.BlobID = metaresp.SHA
|
||
gfd.FileName = metaresp.Name
|
||
return gfd, nil
|
||
|
||
}
|
||
|
||
// safeSubstring returns the prefix of s holding at most n runes (Unicode code
// points). It never splits a multi-byte character.
//
// s is returned unchanged when it has n runes or fewer. A non-positive n
// yields "" (previously a negative n panicked). The prefix is a byte slice of
// s, so invalid UTF-8 bytes are preserved rather than replaced with U+FFFD.
// No rune slice is allocated.
func safeSubstring(s string, n int) string {
	if n <= 0 {
		return ""
	}
	runes := 0
	for i := range s { // i is the byte offset of each rune
		if runes == n {
			return s[:i]
		}
		runes++
	}
	return s
}
|
||
|
||
func (gs *GiteaStorage) ExtractLfsMetadata(content string) (int64, string) {
|
||
match := Regexp(`size (\d+)`, content)
|
||
if len(match) < 2 {
|
||
return 0, ""
|
||
}
|
||
size, err := strconv.Atoi(match[1])
|
||
if err != nil {
|
||
return 0, ""
|
||
}
|
||
sha256 := Regexp("sha256:([a-f0-9]+)", content)
|
||
return int64(size), sha256[1]
|
||
}
|
||
|
||
func (gs *GiteaStorage) GetLfsRawFile(hfFile *domain.HuggingFaceFileRequest, writer io.Writer) error {
|
||
url := fmt.Sprintf("/api/v1/repos/%s/%s/media/%s", hfFile.Project, hfFile.Name, hfFile.File)
|
||
resp, _ := gs.client.R().SetDoNotParseResponse(true).
|
||
SetHeader("Accept", "application/json").
|
||
SetHeader("access_token", hfFile.AccessToken).
|
||
SetQueryParams(map[string]string{
|
||
"ref": safeSubstring(hfFile.Version, REF_VERSION_LIMIT),
|
||
}).
|
||
Get(url)
|
||
defer resp.RawBody().Close() // 确保关闭流
|
||
|
||
// 使用 buffer 进行流式读写,提高性能
|
||
buffer := make([]byte, 4096*1024) // 4M buffer
|
||
_, err := io.CopyBuffer(writer, resp.RawBody(), buffer)
|
||
return err
|
||
}
|
||
|
||
func (gs *GiteaStorage) CreateRepo(repo *domain.CreateRepoInput, token string) (*domain.CreateRepoOutput, error) {
|
||
url := fmt.Sprintf("/api/v1/user/repos")
|
||
|
||
// 构建请求体
|
||
requestBody := map[string]interface{}{
|
||
"auto_init": false,
|
||
"default_branch": DEFALUT_REVISION,
|
||
"name": repo.Name,
|
||
"private": true,
|
||
"template": false,
|
||
"trust_model": "default",
|
||
}
|
||
requestBodyBytes, _ := json.Marshal(requestBody)
|
||
|
||
resp, err := gs.client.R().
|
||
SetHeader("Authorization", fmt.Sprintf("token %s", token)).
|
||
SetHeader("Content-Type", "application/json").
|
||
SetBody(requestBodyBytes).
|
||
Post(url)
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.RawBody().Close() // 确保关闭流
|
||
|
||
if resp.StatusCode() == http.StatusConflict {
|
||
output := &domain.CreateRepoOutput{
|
||
URL: domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name,
|
||
}
|
||
return output, fmt.Errorf(domain.HUGGINGFACE_HEADER_X_Error_Code_REPO_EXIST)
|
||
}
|
||
|
||
if resp.StatusCode() != http.StatusCreated {
|
||
return nil, fmt.Errorf("failed to create repository, status code: %d", resp.StatusCode())
|
||
}
|
||
// 解析响应体
|
||
var responseMap map[string]interface{}
|
||
if err := json.Unmarshal(resp.Body(), &responseMap); err != nil {
|
||
return nil, fmt.Errorf("failed to parse response body: %w", err)
|
||
}
|
||
|
||
// 提取 id 字段
|
||
id, ok := responseMap["id"].(float64)
|
||
if !ok {
|
||
return nil, fmt.Errorf("id field not found or not a float64 in response")
|
||
}
|
||
output := &domain.CreateRepoOutput{
|
||
URL: domain.HUGGINGFACE_ENDPOINT + "/" + repo.Organization + "/" + repo.Name,
|
||
Name: repo.Organization + "/" + repo.Name,
|
||
//从 resp id中获取
|
||
ID: fmt.Sprintf("%f", id),
|
||
}
|
||
return output, nil
|
||
}
|
||
|
||
// Regexp compiles the pattern reg and returns the leftmost match in input
// followed by its capture groups, or nil when nothing matches.
//
// The pattern is compiled anew on each call, and an invalid pattern panics
// (regexp.MustCompile), so reg should be a trusted literal.
func Regexp(reg string, input string) []string {
	return regexp.MustCompile(reg).FindStringSubmatch(input)
}
|
||
|
||
type FileMetaResp struct {
|
||
Name string `json:"name"`
|
||
Path string `json:"path"`
|
||
SHA string `json:"sha"`
|
||
LastCommitSHA string `json:"last_commit_sha"`
|
||
Type string `json:"type"`
|
||
Size int64 `json:"size"`
|
||
Encoding string `json:"encoding"`
|
||
Content string `json:"content"`
|
||
Target string `json:"target"`
|
||
URL string `json:"url"`
|
||
HTMLURL string `json:"html_url"`
|
||
GitURL string `json:"git_url"`
|
||
DownloadURL string `json:"download_url"`
|
||
SubmoduleGitURL string `json:"submodule_git_url"`
|
||
Links Links `json:"_links"`
|
||
}
|
||
|
||
// Links is the "_links" object embedded in a FileMetaResp: URLs for the
// entry itself, its git object, and its HTML page.
type Links struct {
	Self string `json:"self"`
	Git  string `json:"git"`
	HTML string `json:"html"`
}
|
||
|
||
type ModelInfoResp struct {
|
||
SHA string `json:"sha"`
|
||
URL string `json:"url"`
|
||
Tree []Tree `json:"tree"`
|
||
Truncated bool `json:"truncated"`
|
||
Page int64 `json:"page"`
|
||
TotalCount int64 `json:"total_count"`
|
||
}
|
||
|
||
type Tree struct {
|
||
Path string `json:"path"`
|
||
Mode string `json:"mode"`
|
||
Type Type `json:"type"`
|
||
Size int64 `json:"size"`
|
||
SHA string `json:"sha"`
|
||
URL string `json:"url"`
|
||
}
|
||
|
||
// Type is the kind of a git tree entry as reported in Gitea's tree listing.
type Type string

// Blob is the Type of a regular file entry.
const Blob Type = "blob"
|
||
|
||
// Metadata holds the fields GetPipelineTag decodes from a model README as
// YAML.
//
// PipLinetag is a misspelling of "pipeline tag". It is kept as-is because
// renaming an exported field could break callers; its YAML key is
// pipeline_tag.
type Metadata struct {
	License     string   `yaml:"license"`
	LicenseName string   `yaml:"license_name"`
	LicenseLink string   `yaml:"license_link"`
	Tags        []string `yaml:"tags"`
	Datasets    []string `yaml:"datasets"`
	PipLinetag  string   `yaml:"pipeline_tag"`
	Language    []string `yaml:"language"`
}
|