mirror of https://git.sr.ht/~sircmpwn/gql.sr.ht
Streamline database support code
This commit is contained in:
parent
646e9d90c4
commit
6a1a8f1031
|
@ -29,47 +29,13 @@ func collectFields(ctx context.Context) []graphql.CollectedField {
|
|||
return fields
|
||||
}
|
||||
|
||||
func ColumnsFor(ctx context.Context, alias string,
|
||||
colMap map[string]string) []string {
|
||||
|
||||
fields := collectFields(ctx)
|
||||
if len(fields) == 0 {
|
||||
// Collect all fields if we are not in an active graphql context
|
||||
for qlCol, _ := range colMap {
|
||||
fields = append(fields, graphql.CollectedField{
|
||||
&ast.Field{Name: qlCol}, nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(fields, func(a, b int) bool {
|
||||
return fields[a].Name < fields[b].Name
|
||||
})
|
||||
|
||||
var columns []string
|
||||
for _, qlCol := range fields {
|
||||
if sqlCol, ok := colMap[qlCol.Name]; ok {
|
||||
if alias != "" {
|
||||
columns = append(columns, pq.QuoteIdentifier(alias)+
|
||||
"."+pq.QuoteIdentifier(sqlCol))
|
||||
} else {
|
||||
columns = append(columns, pq.QuoteIdentifier(sqlCol))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
func FieldsFor(ctx context.Context,
|
||||
colMap map[string]interface{}) []interface{} {
|
||||
|
||||
func Scan(ctx context.Context, m Model) []interface{} {
|
||||
qlFields := collectFields(ctx)
|
||||
if len(qlFields) == 0 {
|
||||
// Collect all fields if we are not in an active graphql context
|
||||
for qlCol, _ := range colMap {
|
||||
for _, field := range m.Fields().All() {
|
||||
qlFields = append(qlFields, graphql.CollectedField{
|
||||
&ast.Field{Name: qlCol}, nil,
|
||||
&ast.Field{Name: field.GQL}, nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -80,18 +46,51 @@ func FieldsFor(ctx context.Context,
|
|||
|
||||
var fields []interface{}
|
||||
for _, qlField := range qlFields {
|
||||
if field, ok := colMap[qlField.Name]; ok {
|
||||
fields = append(fields, field)
|
||||
if field, ok := m.Fields().GQL(qlField.Name); ok {
|
||||
fields = append(fields, field.Ptr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range m.Fields().Anonymous() {
|
||||
fields = append(fields, field.Ptr)
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func Columns(ctx context.Context, m Model) []string {
|
||||
fields := collectFields(ctx)
|
||||
if len(fields) == 0 {
|
||||
// Collect all fields if we are not in an active graphql context
|
||||
for _, field := range m.Fields().All() {
|
||||
fields = append(fields, graphql.CollectedField{
|
||||
&ast.Field{Name: field.GQL}, nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(fields, func(a, b int) bool {
|
||||
return fields[a].Name < fields[b].Name
|
||||
})
|
||||
|
||||
var columns []string
|
||||
for _, gql := range fields {
|
||||
if field, ok := m.Fields().GQL(gql.Name); ok {
|
||||
columns = append(columns, WithAlias(m.Alias(), field.SQL))
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range m.Fields().Anonymous() {
|
||||
columns = append(columns, WithAlias(m.Alias(), field.SQL))
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
func WithAlias(alias, col string) string {
|
||||
if alias != "" {
|
||||
return alias + "." + col
|
||||
return pq.QuoteIdentifier(alias) + "." + pq.QuoteIdentifier(col)
|
||||
} else {
|
||||
return col
|
||||
return pq.QuoteIdentifier(col)
|
||||
}
|
||||
}
|
||||
|
|
122
database/sq.go
122
database/sq.go
|
@ -3,13 +3,75 @@ package database
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
)
|
||||
|
||||
type Selectable interface {
|
||||
Select(ctx context.Context) []string
|
||||
Fields(ctx context.Context) []interface{}
|
||||
// Provides a mapping between PostgreSQL columns, GQL fields, and Go struct
|
||||
// fields for all of the data associated with a model.
|
||||
type FieldMap struct {
|
||||
SQL string
|
||||
GQL string
|
||||
Ptr interface{}
|
||||
}
|
||||
|
||||
type ModelFields struct {
|
||||
Fields []*FieldMap
|
||||
|
||||
byGQL map[string]*FieldMap
|
||||
bySQL map[string]*FieldMap
|
||||
anon []*FieldMap
|
||||
}
|
||||
|
||||
func (mf *ModelFields) buildCache() {
|
||||
if mf.byGQL != nil && mf.bySQL != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mf.byGQL = make(map[string]*FieldMap)
|
||||
mf.bySQL = make(map[string]*FieldMap)
|
||||
for _, f := range mf.Fields {
|
||||
if f.GQL != "" {
|
||||
mf.byGQL[f.GQL] = f
|
||||
} else {
|
||||
mf.anon = append(mf.anon, f)
|
||||
}
|
||||
mf.bySQL[f.SQL] = f
|
||||
}
|
||||
}
|
||||
|
||||
func (mf *ModelFields) GQL(name string) (*FieldMap, bool) {
|
||||
mf.buildCache()
|
||||
if f, ok := mf.byGQL[name]; !ok {
|
||||
return nil, false
|
||||
} else {
|
||||
return f, true
|
||||
}
|
||||
}
|
||||
|
||||
func (mf *ModelFields) SQL(name string) (*FieldMap, bool) {
|
||||
mf.buildCache()
|
||||
if f, ok := mf.bySQL[name]; !ok {
|
||||
return nil, false
|
||||
} else {
|
||||
return f, true
|
||||
}
|
||||
}
|
||||
|
||||
func (mf *ModelFields) All() []*FieldMap {
|
||||
return mf.Fields
|
||||
}
|
||||
|
||||
func (mf *ModelFields) Anonymous() []*FieldMap {
|
||||
mf.buildCache()
|
||||
return mf.anon
|
||||
}
|
||||
|
||||
type Model interface {
|
||||
Alias() string
|
||||
Fields() *ModelFields
|
||||
Table() string
|
||||
}
|
||||
|
||||
func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
|
||||
|
@ -20,11 +82,61 @@ func Select(ctx context.Context, cols ...interface{}) sq.SelectBuilder {
|
|||
q = q.Columns(col)
|
||||
case []string:
|
||||
q = q.Columns(col...)
|
||||
case Selectable:
|
||||
q = q.Columns(col.Select(ctx)...)
|
||||
case Model:
|
||||
q = q.Columns(Columns(ctx, col)...)
|
||||
default:
|
||||
panic(fmt.Errorf("Unknown selectable type %T", col))
|
||||
}
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Prepares an UPDATE statement which applies the changes in the input map to
|
||||
// the given model.
|
||||
func Apply(m Model, input map[string]interface{}) sq.UpdateBuilder {
|
||||
// XXX: This relies on the GraphQL validator to prevent the user from
|
||||
// updating columns they're not supposed to. Risky?
|
||||
table := m.Table()
|
||||
if m.Alias() != "" {
|
||||
table += " " + m.Alias()
|
||||
}
|
||||
update := sq.Update(table).PlaceholderFormat(sq.Dollar)
|
||||
|
||||
defer func() {
|
||||
// Some weird reflection errors don't get properly logged if they're
|
||||
// caught at a higher level.
|
||||
if err := recover(); err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
for field, value := range input {
|
||||
f, ok := m.Fields().GQL(field)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
var (
|
||||
pv reflect.Value = reflect.Indirect(reflect.ValueOf(f.Ptr))
|
||||
rv reflect.Value = reflect.ValueOf(value)
|
||||
)
|
||||
if pv.Type().Kind() == reflect.Ptr {
|
||||
if !rv.IsValid() {
|
||||
pv.Set(reflect.Zero(pv.Type()))
|
||||
update = update.Set(WithAlias(m.Alias(), f.SQL), nil)
|
||||
} else {
|
||||
if !pv.Elem().IsValid() {
|
||||
pv.Set(reflect.New(pv.Type().Elem()))
|
||||
}
|
||||
reflect.Indirect(pv).Set(reflect.Indirect(rv))
|
||||
update = update.Set(WithAlias(m.Alias(), f.SQL),
|
||||
reflect.Indirect(rv).Interface())
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Errorf("TODO"))
|
||||
}
|
||||
}
|
||||
|
||||
return update
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue