|
package downloader |
|
|
|
import ( |
|
"crypto/sha256" |
|
"fmt" |
|
"io" |
|
"net/http" |
|
"net/url" |
|
"os" |
|
"path/filepath" |
|
"strconv" |
|
"strings" |
|
|
|
ocispec "github.com/opencontainers/image-spec/specs-go/v1" |
|
|
|
"github.com/mudler/LocalAI/pkg/oci" |
|
"github.com/mudler/LocalAI/pkg/utils" |
|
"github.com/rs/zerolog/log" |
|
) |
|
|
|
// URI prefixes understood by this package.
const (
	// HuggingFacePrefix marks a Hugging Face reference that ResolveURL
	// expands into an https://huggingface.co/... URL.
	HuggingFacePrefix = "huggingface://"
	// OCIPrefix marks an OCI image reference (see LooksLikeOCI).
	OCIPrefix = "oci://"
	// OllamaPrefix marks an Ollama model reference; LooksLikeOCI treats it
	// as an OCI-style reference too.
	OllamaPrefix = "ollama://"
	// HTTPPrefix is the plain HTTP scheme prefix.
	HTTPPrefix = "http://"
	// HTTPSPrefix is the HTTPS scheme prefix.
	HTTPSPrefix = "https://"
	// GithubURI is the short GitHub form that ResolveURL expands into a
	// raw.githubusercontent.com URL.
	GithubURI = "github:"
	// GithubURI2 is the "//" variant of GithubURI.
	GithubURI2 = "github://"
	// LocalPrefix marks a file on the local filesystem.
	LocalPrefix = "file://"
)

// URI is a string-typed location of a downloadable resource. The methods
// declared on it recognise the prefixes above.
type URI string
|
|
|
func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error { |
|
url := uri.ResolveURL() |
|
|
|
if strings.HasPrefix(url, LocalPrefix) { |
|
rawURL := strings.TrimPrefix(url, LocalPrefix) |
|
|
|
resolvedFile, err := filepath.EvalSymlinks(rawURL) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
resolvedBasePath, err := filepath.EvalSymlinks(basePath) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
err = utils.InTrustedRoot(resolvedFile, resolvedBasePath) |
|
if err != nil { |
|
log.Debug().Str("resolvedFile", resolvedFile).Str("basePath", basePath).Msg("downloader.GetURI blocked an attempt to ready a file url outside of basePath") |
|
return err |
|
} |
|
|
|
body, err := os.ReadFile(resolvedFile) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
|
|
return f(url, body) |
|
} |
|
|
|
|
|
response, err := http.Get(url) |
|
if err != nil { |
|
return err |
|
} |
|
defer response.Body.Close() |
|
|
|
|
|
body, err := io.ReadAll(response.Body) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
|
|
return f(url, body) |
|
} |
|
|
|
func (u URI) FilenameFromUrl() (string, error) { |
|
f, err := filenameFromUrl(string(u)) |
|
if err != nil || f == "" { |
|
f = utils.MD5(string(u)) |
|
if strings.HasSuffix(string(u), ".yaml") || strings.HasSuffix(string(u), ".yml") { |
|
f = f + ".yaml" |
|
} |
|
err = nil |
|
} |
|
|
|
return f, err |
|
} |
|
|
|
// filenameFromUrl extracts the last path element of urlstr.
//
// Everything from the first "@" onwards is discarded before parsing. The
// escaped path is then decoded with url.QueryUnescape and reduced with
// filepath.Base.
//
// It returns "" with a nil error when the URL has no usable final path
// element (empty path, root path, or a path ending in "." or ".."), so that
// callers can fall back to another naming scheme. It returns an error when
// the URL cannot be parsed or its path cannot be unescaped.
func filenameFromUrl(urlstr string) (string, error) {
	if i := strings.Index(urlstr, "@"); i >= 0 {
		urlstr = urlstr[:i]
	}

	u, err := url.Parse(urlstr)
	if err != nil {
		return "", fmt.Errorf("error due to parsing url: %w", err)
	}

	x, err := url.QueryUnescape(u.EscapedPath())
	if err != nil {
		return "", fmt.Errorf("error due to escaping: %w", err)
	}

	// filepath.Base never returns "": it yields "." for an empty path and the
	// separator for a root path. None of those (nor "..") is a file name.
	name := filepath.Base(x)
	switch name {
	case ".", "..", string(filepath.Separator):
		return "", nil
	}
	return name, nil
}
|
|
|
func (u URI) LooksLikeURL() bool { |
|
return strings.HasPrefix(string(u), HTTPPrefix) || |
|
strings.HasPrefix(string(u), HTTPSPrefix) || |
|
strings.HasPrefix(string(u), HuggingFacePrefix) || |
|
strings.HasPrefix(string(u), GithubURI) || |
|
strings.HasPrefix(string(u), OllamaPrefix) || |
|
strings.HasPrefix(string(u), OCIPrefix) || |
|
strings.HasPrefix(string(u), GithubURI2) |
|
} |
|
|
|
func (s URI) LooksLikeOCI() bool { |
|
return strings.HasPrefix(string(s), OCIPrefix) || strings.HasPrefix(string(s), OllamaPrefix) |
|
} |
|
|
|
func (s URI) ResolveURL() string { |
|
switch { |
|
case strings.HasPrefix(string(s), GithubURI2): |
|
repository := strings.Replace(string(s), GithubURI2, "", 1) |
|
|
|
repoParts := strings.Split(repository, "@") |
|
branch := "main" |
|
|
|
if len(repoParts) > 1 { |
|
branch = repoParts[1] |
|
} |
|
|
|
repoPath := strings.Split(repoParts[0], "/") |
|
org := repoPath[0] |
|
project := repoPath[1] |
|
projectPath := strings.Join(repoPath[2:], "/") |
|
|
|
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) |
|
case strings.HasPrefix(string(s), GithubURI): |
|
parts := strings.Split(string(s), ":") |
|
repoParts := strings.Split(parts[1], "@") |
|
branch := "main" |
|
|
|
if len(repoParts) > 1 { |
|
branch = repoParts[1] |
|
} |
|
|
|
repoPath := strings.Split(repoParts[0], "/") |
|
org := repoPath[0] |
|
project := repoPath[1] |
|
projectPath := strings.Join(repoPath[2:], "/") |
|
|
|
return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) |
|
case strings.HasPrefix(string(s), HuggingFacePrefix): |
|
repository := strings.Replace(string(s), HuggingFacePrefix, "", 1) |
|
|
|
|
|
owner := strings.Split(repository, "/")[0] |
|
repo := strings.Split(repository, "/")[1] |
|
|
|
branch := "main" |
|
if strings.Contains(repo, "@") { |
|
branch = strings.Split(repository, "@")[1] |
|
} |
|
filepath := strings.Split(repository, "/")[2] |
|
if strings.Contains(filepath, "@") { |
|
filepath = strings.Split(filepath, "@")[0] |
|
} |
|
|
|
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) |
|
} |
|
|
|
return string(s) |
|
} |
|
|
|
func removePartialFile(tmpFilePath string) error { |
|
_, err := os.Stat(tmpFilePath) |
|
if err == nil { |
|
log.Debug().Msgf("Removing temporary file %s", tmpFilePath) |
|
err = os.Remove(tmpFilePath) |
|
if err != nil { |
|
err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err) |
|
log.Warn().Msg(err1.Error()) |
|
return err1 |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { |
|
url := uri.ResolveURL() |
|
if uri.LooksLikeOCI() { |
|
progressStatus := func(desc ocispec.Descriptor) io.Writer { |
|
return &progressWriter{ |
|
fileName: filePath, |
|
total: desc.Size, |
|
hash: sha256.New(), |
|
fileNo: fileN, |
|
totalFiles: total, |
|
downloadStatus: downloadStatus, |
|
} |
|
} |
|
|
|
if strings.HasPrefix(url, OllamaPrefix) { |
|
url = strings.TrimPrefix(url, OllamaPrefix) |
|
return oci.OllamaFetchModel(url, filePath, progressStatus) |
|
} |
|
|
|
url = strings.TrimPrefix(url, OCIPrefix) |
|
img, err := oci.GetImage(url, "", nil, nil) |
|
if err != nil { |
|
return fmt.Errorf("failed to get image %q: %v", url, err) |
|
} |
|
|
|
return oci.ExtractOCIImage(img, filepath.Dir(filePath)) |
|
} |
|
|
|
|
|
_, err := os.Stat(filePath) |
|
if err == nil { |
|
|
|
if sha != "" { |
|
|
|
calculatedSHA, err := calculateSHA(filePath) |
|
if err != nil { |
|
return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err) |
|
} |
|
if calculatedSHA == sha { |
|
|
|
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", filePath) |
|
return nil |
|
} |
|
|
|
err = os.Remove(filePath) |
|
if err != nil { |
|
return fmt.Errorf("failed to remove existing file %q: %v", filePath, err) |
|
} |
|
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath) |
|
|
|
} else { |
|
|
|
log.Debug().Msgf("File %q already exists. Skipping download", filePath) |
|
return nil |
|
} |
|
} else if !os.IsNotExist(err) { |
|
|
|
return fmt.Errorf("failed to check file %q existence: %v", filePath, err) |
|
} |
|
|
|
log.Info().Msgf("Downloading %q", url) |
|
|
|
|
|
resp, err := http.Get(url) |
|
if err != nil { |
|
return fmt.Errorf("failed to download file %q: %v", filePath, err) |
|
} |
|
defer resp.Body.Close() |
|
|
|
if resp.StatusCode >= 400 { |
|
return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) |
|
} |
|
|
|
|
|
err = os.MkdirAll(filepath.Dir(filePath), 0750) |
|
if err != nil { |
|
return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err) |
|
} |
|
|
|
|
|
tmpFilePath := filePath + ".partial" |
|
|
|
|
|
err = removePartialFile(tmpFilePath) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
|
|
outFile, err := os.Create(tmpFilePath) |
|
if err != nil { |
|
return fmt.Errorf("failed to create file %q: %v", tmpFilePath, err) |
|
} |
|
defer outFile.Close() |
|
|
|
progress := &progressWriter{ |
|
fileName: tmpFilePath, |
|
total: resp.ContentLength, |
|
hash: sha256.New(), |
|
fileNo: fileN, |
|
totalFiles: total, |
|
downloadStatus: downloadStatus, |
|
} |
|
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) |
|
if err != nil { |
|
return fmt.Errorf("failed to write file %q: %v", filePath, err) |
|
} |
|
|
|
err = os.Rename(tmpFilePath, filePath) |
|
if err != nil { |
|
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) |
|
} |
|
|
|
if sha != "" { |
|
|
|
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) |
|
if calculatedSHA != sha { |
|
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) |
|
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) |
|
} |
|
} else { |
|
log.Debug().Msgf("SHA missing for %q. Skipping validation", filePath) |
|
} |
|
|
|
log.Info().Msgf("File %q downloaded and verified", filePath) |
|
if utils.IsArchive(filePath) { |
|
basePath := filepath.Dir(filePath) |
|
log.Info().Msgf("File %q is an archive, uncompressing to %s", filePath, basePath) |
|
if err := utils.ExtractArchive(filePath, basePath); err != nil { |
|
log.Debug().Msgf("Failed decompressing %q: %s", filePath, err.Error()) |
|
return err |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// formatBytes renders a byte count using binary (1024-based) units.
//
// Values below 1024 are printed as a plain integer followed by " B". Larger
// values are printed with one decimal and the matching KiB..EiB suffix.
func formatBytes(bytes int64) string {
	const (
		unit     = 1024
		prefixes = "KMGTPE"
	)

	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	// Find the largest power of 1024 that does not exceed bytes.
	idx := 0
	divisor := int64(unit)
	remaining := bytes / unit
	for remaining >= unit {
		remaining /= unit
		divisor *= unit
		idx++
	}

	scaled := float64(bytes) / float64(divisor)
	return fmt.Sprintf("%.1f %ciB", scaled, prefixes[idx])
}
|
|
|
// calculateSHA returns the lowercase hexadecimal SHA-256 digest of the file at
// filePath. It returns an error when the file cannot be opened or read.
func calculateSHA(filePath string) (string, error) {
	src, openErr := os.Open(filePath)
	if openErr != nil {
		return "", openErr
	}
	defer src.Close()

	// Stream the file through the hasher instead of loading it into memory.
	digest := sha256.New()
	_, copyErr := io.Copy(digest, src)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}
|
|