Streamline database support code

This commit is contained in:
Drew DeVault 2020-09-24 14:42:15 -04:00
parent 646e9d90c4
commit 6a1a8f1031
2 changed files with 157 additions and 46 deletions

View File

@ -29,47 +29,13 @@ func collectFields(ctx context.Context) []graphql.CollectedField {
return fields
}
func ColumnsFor(ctx context.Context, alias string,
colMap map[string]string) []string {
fields := collectFields(ctx)
if len(fields) == 0 {
// Collect all fields if we are not in an active graphql context
for qlCol, _ := range colMap {
fields = append(fields, graphql.CollectedField{
&ast.Field{Name: qlCol}, nil,
})
}
}
sort.Slice(fields, func(a, b int) bool {
return fields[a].Name < fields[b].Name
})
var columns []string
for _, qlCol := range fields {
if sqlCol, ok := colMap[qlCol.Name]; ok {
if alias != "" {
columns = append(columns, pq.QuoteIdentifier(alias)+
"."+pq.QuoteIdentifier(sqlCol))
} else {
columns = append(columns, pq.QuoteIdentifier(sqlCol))
}
}
}
return columns
}
func FieldsFor(ctx context.Context,
colMap map[string]interface{}) []interface{} {
func Scan(ctx context.Context, m Model) []interface{} {
qlFields := collectFields(ctx)
if len(qlFields) == 0 {
// Collect all fields if we are not in an active graphql context
for qlCol, _ := range colMap {
for _, field := range m.Fields().All() {
qlFields = append(qlFields, graphql.CollectedField{
&ast.Field{Name: qlCol}, nil,
&ast.Field{Name: field.GQL}, nil,
})
}
}
@ -80,18 +46,51 @@ func FieldsFor(ctx context.Context,
var fields []interface{}
for _, qlField := range qlFields {
if field, ok := colMap[qlField.Name]; ok {
fields = append(fields, field)
if field, ok := m.Fields().GQL(qlField.Name); ok {
fields = append(fields, field.Ptr)
}
}
for _, field := range m.Fields().Anonymous() {
fields = append(fields, field.Ptr)
}
return fields
}
func Columns(ctx context.Context, m Model) []string {
fields := collectFields(ctx)
if len(fields) == 0 {
// Collect all fields if we are not in an active graphql context
for _, field := range m.Fields().All() {
fields = append(fields, graphql.CollectedField{
&ast.Field{Name: field.GQL}, nil,
})
}
}
sort.Slice(fields, func(a, b int) bool {
return fields[a].Name < fields[b].Name
})
var columns []string
for _, gql := range fields {
if field, ok := m.Fields().GQL(gql.Name); ok {
columns = append(columns, WithAlias(m.Alias(), field.SQL))
}
}
for _, field := range m.Fields().Anonymous() {
columns = append(columns, WithAlias(m.Alias(), field.SQL))
}
return columns
}
func WithAlias(alias, col string) string {
if alias != "" {
return alias + "." + col
return pq.QuoteIdentifier(alias) + "." + pq.QuoteIdentifier(col)
} else {
return col
return pq.QuoteIdentifier(col)
}
}

View File

@ -3,13 +3,75 @@ package database
import (
"context"
"fmt"
"reflect"
sq "github.com/Masterminds/squirrel"
)
type Selectable interface {
Select(ctx context.Context) []string
Fields(ctx context.Context) []interface{}
// Provides a mapping between PostgreSQL columns, GQL fields, and Go struct
// fields for all of the data associated with a model.
type FieldMap struct {
SQL string
GQL string
Ptr interface{}
}
type ModelFields struct {
Fields []*FieldMap
byGQL map[string]*FieldMap
bySQL map[string]*FieldMap
anon []*FieldMap
}
func (mf *ModelFields) buildCache() {
if mf.byGQL != nil && mf.bySQL != nil {
return
}
mf.byGQL = make(map[string]*FieldMap)
mf.bySQL = make(map[string]*FieldMap)
for _, f := range mf.Fields {
if f.GQL != "" {
mf.byGQL[f.GQL] = f
} else {
mf.anon = append(mf.anon, f)
}
mf.bySQL[f.SQL] = f
}
}
func (mf *ModelFields) GQL(name string) (*FieldMap, bool) {
mf.buildCache()
if f, ok := mf.byGQL[name]; !ok {
return nil, false
} else {
return f, true
}
}
func (mf *ModelFields) SQL(name string) (*FieldMap, bool) {
mf.buildCache()
if f, ok := mf.bySQL[name]; !ok {
return nil, false
} else {
return f, true
}
}
func (mf *ModelFields) All() []*FieldMap {
return mf.Fields
}
func (mf *ModelFields) Anonymous() []*FieldMap {
mf.buildCache()
return mf.anon
}
type Model interface {
Alias() string
Fields() *ModelFields
Table() string
}
func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
@ -20,11 +82,61 @@ func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
q = q.Columns(col)
case []string:
q = q.Columns(col...)
case Selectable:
q = q.Columns(col.Select(ctx)...)
case Model:
q = q.Columns(Columns(ctx, col)...)
default:
panic(fmt.Errorf("Unknown selectable type %T", col))
}
}
return q
}
// Prepares an UPDATE statement which applies the changes in the input map to
// the given model.
func Apply(m Model, input map[string]interface{}) sq.UpdateBuilder {
// XXX: This relies on the GraphQL validator to prevent the user from
// updating columns they're not supposed to. Risky?
table := m.Table()
if m.Alias() != "" {
table += " " + m.Alias()
}
update := sq.Update(table).PlaceholderFormat(sq.Dollar)
defer func() {
// Some weird reflection errors don't get properly logged if they're
// caught at a higher level.
if err := recover(); err != nil {
fmt.Printf("%v\n", err)
panic(err)
}
}()
for field, value := range input {
f, ok := m.Fields().GQL(field)
if !ok {
continue
}
var (
pv reflect.Value = reflect.Indirect(reflect.ValueOf(f.Ptr))
rv reflect.Value = reflect.ValueOf(value)
)
if pv.Type().Kind() == reflect.Ptr {
if !rv.IsValid() {
pv.Set(reflect.Zero(pv.Type()))
update = update.Set(WithAlias(m.Alias(), f.SQL), nil)
} else {
if !pv.Elem().IsValid() {
pv.Set(reflect.New(pv.Type().Elem()))
}
reflect.Indirect(pv).Set(reflect.Indirect(rv))
update = update.Set(WithAlias(m.Alias(), f.SQL),
reflect.Indirect(rv).Interface())
}
} else {
panic(fmt.Errorf("TODO"))
}
}
return update
}