From 6a1a8f1031f3c2c7ea60a66c62eea8a7c8af7237 Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Thu, 24 Sep 2020 14:42:15 -0400 Subject: [PATCH] Streamline database support code --- database/ql.go | 81 ++++++++++++++++---------------- database/sq.go | 122 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 157 insertions(+), 46 deletions(-) diff --git a/database/ql.go b/database/ql.go index 771633d..bb8fdda 100644 --- a/database/ql.go +++ b/database/ql.go @@ -29,47 +29,13 @@ func collectFields(ctx context.Context) []graphql.CollectedField { return fields } -func ColumnsFor(ctx context.Context, alias string, - colMap map[string]string) []string { - - fields := collectFields(ctx) - if len(fields) == 0 { - // Collect all fields if we are not in an active graphql context - for qlCol, _ := range colMap { - fields = append(fields, graphql.CollectedField{ - &ast.Field{Name: qlCol}, nil, - }) - } - } - - sort.Slice(fields, func(a, b int) bool { - return fields[a].Name < fields[b].Name - }) - - var columns []string - for _, qlCol := range fields { - if sqlCol, ok := colMap[qlCol.Name]; ok { - if alias != "" { - columns = append(columns, pq.QuoteIdentifier(alias)+ - "."+pq.QuoteIdentifier(sqlCol)) - } else { - columns = append(columns, pq.QuoteIdentifier(sqlCol)) - } - } - } - - return columns -} - -func FieldsFor(ctx context.Context, - colMap map[string]interface{}) []interface{} { - +func Scan(ctx context.Context, m Model) []interface{} { qlFields := collectFields(ctx) if len(qlFields) == 0 { // Collect all fields if we are not in an active graphql context - for qlCol, _ := range colMap { + for _, field := range m.Fields().All() { qlFields = append(qlFields, graphql.CollectedField{ - &ast.Field{Name: qlCol}, nil, + &ast.Field{Name: field.GQL}, nil, }) } } @@ -80,18 +46,51 @@ func FieldsFor(ctx context.Context, var fields []interface{} for _, qlField := range qlFields { - if field, ok := colMap[qlField.Name]; ok { - fields = append(fields, field) + if field, ok := m.Fields().GQL(qlField.Name); ok { + fields = append(fields, field.Ptr) } } + for _, field := range m.Fields().Anonymous() { + fields = append(fields, field.Ptr) + } + return fields } +func Columns(ctx context.Context, m Model) []string { + fields := collectFields(ctx) + if len(fields) == 0 { + // Collect all fields if we are not in an active graphql context + for _, field := range m.Fields().All() { + fields = append(fields, graphql.CollectedField{ + &ast.Field{Name: field.GQL}, nil, + }) + } + } + + sort.Slice(fields, func(a, b int) bool { + return fields[a].Name < fields[b].Name + }) + + var columns []string + for _, gql := range fields { + if field, ok := m.Fields().GQL(gql.Name); ok { + columns = append(columns, WithAlias(m.Alias(), field.SQL)) + } + } + + for _, field := range m.Fields().Anonymous() { + columns = append(columns, WithAlias(m.Alias(), field.SQL)) + } + + return columns +} + func WithAlias(alias, col string) string { if alias != "" { - return alias + "." + col + return pq.QuoteIdentifier(alias) + "." + pq.QuoteIdentifier(col) } else { - return col + return pq.QuoteIdentifier(col) } } diff --git a/database/sq.go b/database/sq.go index 22b2558..58b558e 100644 --- a/database/sq.go +++ b/database/sq.go @@ -3,13 +3,75 @@ package database import ( "context" "fmt" + "reflect" sq "github.com/Masterminds/squirrel" ) -type Selectable interface { - Select(ctx context.Context) []string - Fields(ctx context.Context) []interface{} +// Provides a mapping between PostgreSQL columns, GQL fields, and Go struct +// fields for all of the data associated with a model. +type FieldMap struct { + SQL string + GQL string + Ptr interface{} +} + +type ModelFields struct { + Fields []*FieldMap + + byGQL map[string]*FieldMap + bySQL map[string]*FieldMap + anon []*FieldMap +} + +func (mf *ModelFields) buildCache() { + if mf.byGQL != nil && mf.bySQL != nil { + return + } + + mf.byGQL = make(map[string]*FieldMap) + mf.bySQL = make(map[string]*FieldMap) + for _, f := range mf.Fields { + if f.GQL != "" { + mf.byGQL[f.GQL] = f + } else { + mf.anon = append(mf.anon, f) + } + mf.bySQL[f.SQL] = f + } +} + +func (mf *ModelFields) GQL(name string) (*FieldMap, bool) { + mf.buildCache() + if f, ok := mf.byGQL[name]; !ok { + return nil, false + } else { + return f, true + } +} + +func (mf *ModelFields) SQL(name string) (*FieldMap, bool) { + mf.buildCache() + if f, ok := mf.bySQL[name]; !ok { + return nil, false + } else { + return f, true + } +} + +func (mf *ModelFields) All() []*FieldMap { + return mf.Fields +} + +func (mf *ModelFields) Anonymous() []*FieldMap { + mf.buildCache() + return mf.anon +} + +type Model interface { + Alias() string + Fields() *ModelFields + Table() string } func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder { @@ -20,11 +82,61 @@ func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder { q = q.Columns(col) case []string: q = q.Columns(col...) - case Selectable: - q = q.Columns(col.Select(ctx)...) + case Model: + q = q.Columns(Columns(ctx, col)...) default: panic(fmt.Errorf("Unknown selectable type %T", col)) } } return q } + +// Prepares an UPDATE statement which applies the changes in the input map to +// the given model. +func Apply(m Model, input map[string]interface{}) sq.UpdateBuilder { + // XXX: This relies on the GraphQL validator to prevent the user from + // updating columns they're not supposed to. Risky? + table := m.Table() + if m.Alias() != "" { + table += " " + m.Alias() + } + update := sq.Update(table).PlaceholderFormat(sq.Dollar) + + defer func() { + // Some weird reflection errors don't get properly logged if they're + // caught at a higher level. + if err := recover(); err != nil { + fmt.Printf("%v\n", err) + panic(err) + } + }() + + for field, value := range input { + f, ok := m.Fields().GQL(field) + if !ok { + continue + } + + var ( + pv reflect.Value = reflect.Indirect(reflect.ValueOf(f.Ptr)) + rv reflect.Value = reflect.ValueOf(value) + ) + if pv.Type().Kind() == reflect.Ptr { + if !rv.IsValid() { + pv.Set(reflect.Zero(pv.Type())) + update = update.Set(WithAlias(m.Alias(), f.SQL), nil) + } else { + if !pv.Elem().IsValid() { + pv.Set(reflect.New(pv.Type().Elem())) + } + reflect.Indirect(pv).Set(reflect.Indirect(rv)) + update = update.Set(WithAlias(m.Alias(), f.SQL), + reflect.Indirect(rv).Interface()) + } + } else { + panic(fmt.Errorf("TODO")) + } + } + + return update +}