291 lines
6.3 KiB
Go
291 lines
6.3 KiB
Go
package vars
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
"gopkg.in/yaml.v2"
|
|
)
|
|
|
|
// Template wraps the raw bytes of a document that may contain ((var))
// placeholders. Evaluate parses those bytes as YAML.
type Template struct {
	// bytes is the unparsed template content exactly as handed to NewTemplate.
	bytes []byte
}
|
|
|
|
// EvaluateOpts controls the strictness checks applied by Template.Evaluate.
type EvaluateOpts struct {
	// ExpectAllKeys, when true, makes evaluation fail if a referenced
	// variable could not be found.
	ExpectAllKeys bool

	// ExpectAllVarsUsed, when true, makes evaluation fail if a defined
	// variable was never referenced by the template.
	ExpectAllVarsUsed bool
}
|
|
|
|
func NewTemplate(bytes []byte) Template {
|
|
return Template{bytes: bytes}
|
|
}
|
|
|
|
func (t Template) ExtraVarNames() []string {
|
|
return interpolator{}.extractVarNames(string(t.bytes))
|
|
}
|
|
|
|
func (t Template) Evaluate(vars Variables, opts EvaluateOpts) ([]byte, error) {
|
|
var obj interface{}
|
|
|
|
err := yaml.Unmarshal(t.bytes, &obj)
|
|
if err != nil {
|
|
return []byte{}, err
|
|
}
|
|
|
|
obj, err = t.interpolateRoot(obj, newVarsTracker(vars, opts.ExpectAllKeys, opts.ExpectAllVarsUsed))
|
|
if err != nil {
|
|
return []byte{}, err
|
|
}
|
|
|
|
bytes, err := yaml.Marshal(obj)
|
|
if err != nil {
|
|
return []byte{}, err
|
|
}
|
|
|
|
return bytes, nil
|
|
}
|
|
|
|
func (t Template) interpolateRoot(obj interface{}, tracker varsTracker) (interface{}, error) {
|
|
var err error
|
|
obj, err = interpolator{}.Interpolate(obj, varsLookup{tracker})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return obj, tracker.Error()
|
|
}
|
|
|
|
// interpolator substitutes ((var)) placeholders inside decoded document
// nodes. It carries no state of its own.
type interpolator struct{}
|
|
|
|
// interpolationRegex matches one ((var)) placeholder anywhere in a string.
// Capture group 1 holds everything between the double parentheses: an
// optional leading '!', an optional "source:" prefix, and the var name. Source
// and name consist of word characters, Unicode letters, '-', '/' and '.'.
var interpolationRegex = regexp.MustCompile(`\(\((!?([-/\.\w\pL]+\:)?[-/\.\w\pL]+)\)\)`)

// interpolationAnchoredRegex is interpolationRegex anchored to both ends of
// the text, i.e. it matches only when the whole string is a single placeholder.
var interpolationAnchoredRegex = regexp.MustCompile("\\A" + interpolationRegex.String() + "\\z")
|
|
|
|
func (i interpolator) Interpolate(node interface{}, varsLookup varsLookup) (interface{}, error) {
|
|
switch typedNode := node.(type) {
|
|
case map[interface{}]interface{}:
|
|
for k, v := range typedNode {
|
|
evaluatedValue, err := i.Interpolate(v, varsLookup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
evaluatedKey, err := i.Interpolate(k, varsLookup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
delete(typedNode, k) // delete in case key has changed
|
|
typedNode[evaluatedKey] = evaluatedValue
|
|
}
|
|
|
|
case []interface{}:
|
|
for idx, x := range typedNode {
|
|
var err error
|
|
typedNode[idx], err = i.Interpolate(x, varsLookup)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
case string:
|
|
for _, name := range i.extractVarNames(typedNode) {
|
|
foundVal, found, err := varsLookup.Get(name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("var lookup '%s': %w", name, err)
|
|
}
|
|
|
|
if found {
|
|
// ensure that value type is preserved when replacing the entire field
|
|
if interpolationAnchoredRegex.MatchString(typedNode) {
|
|
return foundVal, nil
|
|
}
|
|
|
|
switch foundVal.(type) {
|
|
case string, int, int16, int32, int64, uint, uint16, uint32, uint64:
|
|
foundValStr := fmt.Sprintf("%v", foundVal)
|
|
typedNode = strings.Replace(typedNode, fmt.Sprintf("((%s))", name), foundValStr, -1)
|
|
typedNode = strings.Replace(typedNode, fmt.Sprintf("((!%s))", name), foundValStr, -1)
|
|
default:
|
|
return nil, InvalidInterpolationError{
|
|
Path: name,
|
|
Value: foundVal,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return typedNode, nil
|
|
}
|
|
|
|
return node, nil
|
|
}
|
|
|
|
func (i interpolator) extractVarNames(value string) []string {
|
|
var names []string
|
|
|
|
for _, match := range interpolationRegex.FindAllSubmatch([]byte(value), -1) {
|
|
names = append(names, strings.TrimPrefix(string(match[1]), "!"))
|
|
}
|
|
|
|
return names
|
|
}
|
|
|
|
type varsLookup struct {
|
|
varsTracker
|
|
}
|
|
|
|
// ErrEmptyVar is the error varsLookup.Get reports for an empty var reference.
var ErrEmptyVar = errors.New("empty var")
|
|
|
|
// Get value of a var. Name can be the following formats: 1) 'foo', where foo
|
|
// is var name; 2) 'foo:bar', where foo is var source name, and bar is var name;
|
|
// 3) '.:foo', where . means a local var, foo is var name.
|
|
func (l varsLookup) Get(name string) (interface{}, bool, error) {
|
|
var splitName []string
|
|
if strings.Index(name, ":") > 0 {
|
|
parts := strings.Split(name, ":")
|
|
splitName = strings.Split(parts[1], ".")
|
|
splitName[0] = fmt.Sprintf("%s:%s", parts[0], splitName[0])
|
|
} else {
|
|
splitName = strings.Split(name, ".")
|
|
}
|
|
|
|
// this should be impossible since interpolationRegex only matches non-empty
|
|
// vars, but better to error than to panic
|
|
if len(splitName) == 0 {
|
|
return nil, false, ErrEmptyVar
|
|
}
|
|
|
|
val, found, err := l.varsTracker.Get(splitName[0])
|
|
if !found || err != nil {
|
|
return val, found, err
|
|
}
|
|
|
|
for _, seg := range splitName[1:] {
|
|
switch v := val.(type) {
|
|
case map[interface{}]interface{}:
|
|
var found bool
|
|
val, found = v[seg]
|
|
if !found {
|
|
return nil, false, MissingFieldError{
|
|
Path: name,
|
|
Field: seg,
|
|
}
|
|
}
|
|
case map[string]interface{}:
|
|
var found bool
|
|
val, found = v[seg]
|
|
if !found {
|
|
return nil, false, MissingFieldError{
|
|
Path: name,
|
|
Field: seg,
|
|
}
|
|
}
|
|
default:
|
|
return nil, false, InvalidFieldError{
|
|
Path: name,
|
|
Field: seg,
|
|
Value: val,
|
|
}
|
|
}
|
|
}
|
|
|
|
return val, true, err
|
|
}
|
|
|
|
type varsTracker struct {
|
|
vars Variables
|
|
|
|
expectAllFound bool
|
|
expectAllUsed bool
|
|
|
|
missing map[string]struct{} // track missing var names
|
|
visited map[string]struct{}
|
|
visitedAll map[string]struct{} // track all var names that were accessed
|
|
}
|
|
|
|
func newVarsTracker(vars Variables, expectAllFound, expectAllUsed bool) varsTracker {
|
|
return varsTracker{
|
|
vars: vars,
|
|
expectAllFound: expectAllFound,
|
|
expectAllUsed: expectAllUsed,
|
|
missing: map[string]struct{}{},
|
|
visited: map[string]struct{}{},
|
|
visitedAll: map[string]struct{}{},
|
|
}
|
|
}
|
|
|
|
func (t varsTracker) Get(name string) (interface{}, bool, error) {
|
|
t.visitedAll[name] = struct{}{}
|
|
|
|
val, found, err := t.vars.Get(VariableDefinition{Name: name})
|
|
if !found {
|
|
t.missing[name] = struct{}{}
|
|
}
|
|
|
|
return val, found, err
|
|
}
|
|
|
|
func (t varsTracker) Error() error {
|
|
missingErr := t.MissingError()
|
|
extraErr := t.ExtraError()
|
|
if missingErr != nil && extraErr != nil {
|
|
return multierror.Append(missingErr, extraErr)
|
|
} else if missingErr != nil {
|
|
return missingErr
|
|
} else if extraErr != nil {
|
|
return extraErr
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t varsTracker) MissingError() error {
|
|
if !t.expectAllFound || len(t.missing) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return UndefinedVarsError{Vars: names(t.missing)}
|
|
}
|
|
|
|
func (t varsTracker) ExtraError() error {
|
|
if !t.expectAllUsed {
|
|
return nil
|
|
}
|
|
|
|
allDefs, err := t.vars.List()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
unusedNames := map[string]struct{}{}
|
|
|
|
for _, def := range allDefs {
|
|
if _, found := t.visitedAll[def.Name]; !found {
|
|
unusedNames[def.Name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
if len(unusedNames) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return UnusedVarsError{Vars: names(unusedNames)}
|
|
}
|
|
|
|
// names returns the keys of mapWithNames as a slice sorted in ascending
// order. A nil or empty map yields a nil slice.
func names(mapWithNames map[string]struct{}) []string {
	if len(mapWithNames) == 0 {
		return nil
	}

	// The final length is known up front, so allocate once.
	sorted := make([]string, 0, len(mapWithNames))
	for name := range mapWithNames {
		sorted = append(sorted, name)
	}

	// Map iteration order is random; sort for deterministic output.
	sort.Strings(sorted)

	return sorted
}
|