File size: 2,781 Bytes
7def60a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package templates

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"text/template"

	"github.com/mudler/LocalAI/pkg/utils"

	"github.com/Masterminds/sprig/v3"
)

// TemplateType identifies a kind of template. It is the outer key of the
// TemplateCache map, so templates of different types that share a name are
// cached separately.
//
// Keep this in sync with config.TemplateConfig. Technically, order doesn't
// _really_ matter, but the count must stay in sync, see
// tests/integration/reflect_test.go.
// TODO: find a more idiomatic way in Go to keep these two definitions in sync.
type TemplateType int

// TemplateCache lazily loads, parses and caches text/template templates,
// keyed first by TemplateType and then by template name.
//
// Fields:
//   - mu serializes access to templates; EvaluateTemplate holds it while
//     looking up and loading entries.
//   - templatesPath is the directory searched for "<name>.tmpl" files.
//   - templates maps TemplateType -> template name -> parsed template.
type TemplateCache struct {
	mu            sync.Mutex
	templatesPath string
	templates     map[TemplateType]map[string]*template.Template
}

// NewTemplateCache returns an empty cache that resolves "<name>.tmpl" template
// files relative to templatesPath.
func NewTemplateCache(templatesPath string) *TemplateCache {
	return &TemplateCache{
		templatesPath: templatesPath,
		templates:     map[TemplateType]map[string]*template.Template{},
	}
}

// initializeTemplateMapKey ensures tc.templates has an entry for tt, creating an
// empty per-type name map when the key is absent. An existing entry is left as is.
func (tc *TemplateCache) initializeTemplateMapKey(tt TemplateType) {
	if _, exists := tc.templates[tt]; exists {
		return
	}
	tc.templates[tt] = map[string]*template.Template{}
}

// EvaluateTemplate renders the template identified by (templateType,
// templateName) with in as its data and returns the rendered text.
//
// On a cache miss the template is loaded by loadTemplateIfExists: the contents
// of "<templateName>.tmpl" under the cache's templatesPath if that file exists,
// otherwise templateName itself parsed as template text. The parsed template is
// cached, so later calls with the same type and name skip the load.
//
// The cache mutex is held only while looking up or loading the template. The
// template is executed after the lock is released, because
// (*template.Template).Execute is safe for concurrent use. A slow template
// therefore does not block other callers.
func (tc *TemplateCache) EvaluateTemplate(templateType TemplateType, templateName string, in interface{}) (string, error) {
	tmpl, err := func() (*template.Template, error) {
		tc.mu.Lock()
		defer tc.mu.Unlock()

		tc.initializeTemplateMapKey(templateType)
		if cached, ok := tc.templates[templateType][templateName]; ok {
			return cached, nil
		}
		if err := tc.loadTemplateIfExists(templateType, templateName); err != nil {
			return nil, err
		}
		// A successful load stores the parsed template under this key.
		return tc.templates[templateType][templateName], nil
	}()
	if err != nil {
		return "", err
	}
	// Defensive: never call Execute on a nil template if the map ever holds one.
	if tmpl == nil {
		return "", fmt.Errorf("failed loading a template for %s", templateName)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, in); err != nil {
		return "", err
	}
	return buf.String(), nil
}

// loadTemplateIfExists parses the template for (templateType, templateName) and
// stores it in tc.templates. It is a no-op if that entry is already cached.
//
// If "<templateName>.tmpl" exists under tc.templatesPath, that file's contents
// are parsed. Otherwise templateName itself is treated as inline template text.
// Templates are parsed with the sprig function map available.
//
// It does not lock tc.mu itself. The caller is expected to hold it, as
// EvaluateTemplate does.
//
// NOTE(review): the path check below runs before the file-vs-inline decision,
// so inline template text that VerifyPath considers outside templatesPath
// (e.g. text containing "../") is rejected. Confirm this is intended.
func (tc *TemplateCache) loadTemplateIfExists(templateType TemplateType, templateName string) error {
	// Ensure the per-type map exists so the store at the end cannot panic on a
	// nil map when this is reached without EvaluateTemplate's initialization.
	tc.initializeTemplateMapKey(templateType)

	// Already loaded: nothing to do.
	if _, ok := tc.templates[templateType][templateName]; ok {
		return nil
	}

	modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)
	file := filepath.Join(tc.templatesPath, modelTemplateFile)

	// Security check: refuse template names that VerifyPath rejects relative
	// to templatesPath. Keep the underlying cause available via %w.
	if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil {
		return fmt.Errorf("template file outside path: %s: %w", file, err)
	}

	// templateName is either the name of a .tmpl file on disk or the template
	// text itself. Default to inline text and override it when the file exists.
	dat := templateName
	if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) {
		d, err := os.ReadFile(file)
		if err != nil {
			return fmt.Errorf("reading template: %w", err)
		}
		dat = string(d)
	}

	tmpl, err := template.New("prompt").Funcs(sprig.FuncMap()).Parse(dat)
	if err != nil {
		return fmt.Errorf("parsing template: %w", err)
	}
	tc.templates[templateType][templateName] = tmpl

	return nil
}