schema.go 5.17 KB
package models

import (
	"github.com/golang/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"pro2d/common/components"
	"pro2d/common/db/mongoproxy"
	"pro2d/common/logger"
	"reflect"
	"strings"
)

// SchemaOption is a functional option that mutates a *Schema.
type SchemaOption func(*Schema)

// WithSchemaDB returns a SchemaOption that makes the schema use idb as its
// storage backend.
func WithSchemaDB(idb components.IDB) SchemaOption {
	return func(sch *Schema) { sch.db = idb }
}

// SchemaMap is a string-keyed collection of components.ISchema values.
type SchemaMap map[string]components.ISchema

// Schema binds a protobuf-generated struct to a storage backend. It keeps
// reflection indices for case-insensitive field access and caches modified
// fields until Update writes them out.
type Schema struct {
	conn components.IConnection
	db   components.IDB

	// reflectValue is the struct held by schema (dereferenced when schema is
	// a pointer).
	reflectValue *reflect.Value

	// reflectIndex maps lower-cased exported Go field names to struct field
	// indices; protoIndex maps lower-cased protobuf JSON names to descriptor
	// field indices.
	reflectIndex, protoIndex map[string]int

	// cacheFields collects properties set in memory that Update has not yet
	// written to the database.
	cacheFields map[string]interface{}

	// pri is the primary-key value built by NewSchema; schema is the value
	// originally passed to NewSchema.
	pri, schema interface{}
}

func NewSchema(key string, schema interface{}) *Schema {
	s := reflect.ValueOf(schema)
	if s.Kind() == reflect.Ptr {
		s = reflect.ValueOf(schema).Elem()
	}
	sch := &Schema{
		reflectValue: &s,
		cacheFields:  make(map[string]interface{}),
		schema:       schema,
		reflectIndex: make(map[string]int),
		protoIndex:   make(map[string]int),
	}

	p := proto.MessageReflect(schema.(proto.Message)).Descriptor()
	for i := 0; i < p.Fields().Len(); i++ {
		sch.protoIndex[strings.ToLower(p.Fields().Get(i).JSONName())] = i
	}

	for i := 0; i < s.Type().NumField(); i++ {
		name := s.Type().Field(i).Name
		if strings.Compare(name[0:1], strings.ToLower(name[0:1])) == 0 {
			continue
		}
		sch.reflectIndex[strings.ToLower(name)] = i
	}

	sch.db = mongoproxy.NewMongoColl(sch.GetSchemaName(), sch)
	sch.pri = mongoproxy.GetBsonD(sch.getPriTag(), key)
	return sch
}

// getPriTag returns the lower-cased name of the first struct field tagged
// `pri:"1"`, or "" when no field carries that tag.
func (s *Schema) getPriTag() string {
	typ := s.reflectValue.Type()
	for n := 0; n < typ.NumField(); n++ {
		if f := typ.Field(n); f.Tag.Get("pri") == "1" {
			return strings.ToLower(f.Name)
		}
	}
	return ""
}

// FindIndex returns the collection name (the lower-cased struct type name)
// and the JSON names of all fields that carry a non-empty `index` tag.
//
// Fields whose json tag yields no usable name are skipped. This covers a
// missing tag, an empty name such as `json:",omitempty"`, and `json:"-"`.
// Callers therefore never receive an empty or "-" index key.
func (s *Schema) FindIndex() (string, []string) {
	typ := s.reflectValue.Type()
	var index []string
	for i := 0; i < typ.NumField(); i++ {
		tag := typ.Field(i).Tag
		if tag.Get("index") == "" {
			continue
		}
		js := tag.Get("json")
		// `json:"-"` means the field has no JSON name at all.
		if js == "-" {
			continue
		}
		// strings.Split always yields at least one element, so test the name
		// itself rather than the slice length.
		if name := strings.Split(js, ",")[0]; name != "" {
			index = append(index, name)
		}
	}
	return strings.ToLower(typ.Name()), index
}

// Init prepares the backing table for this schema. When at least one struct
// field is tagged `index` (see FindIndex), the table is created once and a
// unique index is requested for each indexed field. A failed index is logged
// and does not stop the remaining ones.
func (s *Schema) Init() {
	coll, keys := s.FindIndex()
	if len(keys) == 0 {
		// No indexed fields: there is nothing to attach, so skip table creation.
		return
	}

	// A single CreateTable call serves every index. Any result it returns is
	// ignored.
	s.db.CreateTable()

	for _, index := range keys {
		logger.Debug("InitDoc collect: %v, createIndex: %s", coll, index)
		if res, err := s.db.SetUnique(index); err != nil {
			logger.Error("InitDoc unique: %s, err: %v", res, err)
		}
	}
}

// GetDB returns the storage backend used by Load, Create and Update.
func (s *Schema) GetDB() components.IDB { return s.db }

// GetPri returns the primary-key value built by NewSchema.
func (s *Schema) GetPri() interface{} { return s.pri }

// GetSchema returns the value originally passed to NewSchema.
func (s *Schema) GetSchema() interface{} { return s.schema }

// GetSchemaName returns the lower-cased name of the wrapped struct type.
func (s *Schema) GetSchemaName() string {
	typeName := s.reflectValue.Type().Name()
	return strings.ToLower(typeName)
}

// UpdateSchema copies fields from schema (a struct or pointer to struct) into
// this Schema. Every exported field whose lower-cased name is also an indexed
// field here is written through SetProperty, which both updates memory and
// caches the value for Update.
//
// A nil interface, a nil pointer or a non-struct value is ignored. Unexported
// source fields are skipped.
func (s *Schema) UpdateSchema(schema interface{}) {
	src := reflect.ValueOf(schema)
	if src.Kind() == reflect.Ptr {
		src = src.Elem()
	}
	// Covers nil (Invalid kind) and non-struct values, which would otherwise
	// panic in NumField.
	if src.Kind() != reflect.Struct {
		return
	}

	typ := src.Type()
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		// Interface() panics on unexported fields. Without this check, a
		// source field like `uid` would match the exported `Uid` key.
		if f.PkgPath != "" {
			continue
		}
		if _, ok := s.reflectIndex[strings.ToLower(f.Name)]; !ok {
			continue
		}
		s.SetProperty(f.Name, src.Field(i).Interface())
	}
}

// SetConn attaches a connection to this schema.
func (s *Schema) SetConn(conn components.IConnection) { s.conn = conn }

// GetConn returns the connection attached with SetConn, or nil if none was set.
func (s *Schema) GetConn() components.IConnection { return s.conn }

// Load delegates to the backing IDB's Load and returns its error.
func (s *Schema) Load() error { return s.db.Load() }

// Create delegates to the backing IDB's Create. It discards Create's first
// result and returns only the error.
func (s *Schema) Create() error {
	_, createErr := s.db.Create()
	return createErr
}

// Update writes the cached dirty fields to the database via UpdateProperties.
// The cache is cleared only when the write succeeds. On failure the error is
// logged and the pending fields stay cached for a later call.
func (s *Schema) Update() {
	if len(s.cacheFields) == 0 {
		return
	}
	err := s.db.UpdateProperties(s.cacheFields)
	if err != nil {
		logger.Error("%s, UpdateErr: %s", s.GetSchemaName(), err.Error())
		return
	}
	s.cacheFields = make(map[string]interface{})
}

// SetProperty updates the in-memory field named key (case-insensitive) and
// records val in the cache that Update flushes to the database. Unknown keys
// are ignored.
//
// val's dynamic type must be assignable to the field's type, otherwise
// reflect.Value.Set panics.
func (s *Schema) SetProperty(key string, val interface{}) {
	name := strings.ToLower(key)
	if idx, found := s.reflectIndex[name]; found {
		s.reflectValue.Field(idx).Set(reflect.ValueOf(val))
		s.cacheFields[name] = val
	}
}

// GetProperty returns the current value of the field named key
// (case-insensitive), or nil when no such field is indexed.
func (s *Schema) GetProperty(key string) interface{} {
	if idx, found := s.reflectIndex[strings.ToLower(key)]; found {
		return s.reflectValue.Field(idx).Interface()
	}
	return nil
}

// SetProperties applies every key/value pair in properties to the in-memory
// struct (keys matched case-insensitively) and records each applied value in
// the cache that Update flushes. Unknown keys are ignored.
func (s *Schema) SetProperties(properties map[string]interface{}) {
	for key, val := range properties {
		name := strings.ToLower(key)
		if idx, known := s.reflectIndex[name]; known {
			s.reflectValue.Field(idx).Set(reflect.ValueOf(val))
			s.cacheFields[name] = val
		}
	}
}

// IncrProperty adds val to the integer field named key (case-insensitive),
// stores the sum through SetProperty and returns it as an int64.
//
// Every signed and unsigned integer kind is supported, so plain int and
// protobuf int32/uint32/uint64 fields work, not just int64. The sum is
// converted to the field's own type before storing. For narrow or unsigned
// fields, an out-of-range sum therefore wraps as a Go conversion would, and the
// returned int64 is the unwrapped sum.
//
// Unknown keys and non-integer fields are left untouched and yield 0.
func (s *Schema) IncrProperty(key string, val int64) int64 {
	idx, ok := s.reflectIndex[strings.ToLower(key)]
	if !ok {
		return 0
	}
	field := s.reflectValue.Field(idx)

	var v int64
	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		v = field.Int() + val
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v = int64(field.Uint()) + val
	default:
		return 0
	}

	// SetProperty assigns via reflect.Value.Set, which requires a value of the
	// field's exact type. A bare int64 would panic on an int or int32 field.
	s.SetProperty(key, reflect.ValueOf(v).Convert(field.Type()).Interface())
	return v
}

// ParseFields writes each entry of properties into both the protobuf message
// and this schema's struct (via SetProperty). It returns the descriptor indices
// of the fields that were written.
//
// Keys are matched case-insensitively against the lower-cased protobuf JSON
// names indexed for this schema. A key is skipped when it is unknown, or when
// its index does not resolve to the same-named field in message's descriptor
// (for example, message is of a different type). Such keys are neither written
// to the wrong field nor allowed to trigger an out-of-range panic. The order of
// the returned ids follows map iteration and is therefore unspecified.
//
// NOTE(review): protoreflect.ValueOf panics for Go types it does not support
// (e.g. plain int), and message.Set panics on a kind mismatch. Callers must
// pass exactly-typed values.
func (s *Schema) ParseFields(message protoreflect.Message, properties map[string]interface{}) []int32 {
	ids := make([]int32, 0, len(properties))
	fields := message.Descriptor().Fields()

	for k, v := range properties {
		key := strings.ToLower(k)
		idx, ok := s.protoIndex[key]
		// Fields().Get panics when out of range; it never returns nil.
		if !ok || idx >= fields.Len() {
			continue
		}
		field := fields.Get(idx)
		if strings.ToLower(field.JSONName()) != key {
			continue
		}

		ids = append(ids, int32(field.Index()))
		message.Set(field, protoreflect.ValueOf(v))
		s.SetProperty(k, v)
	}

	return ids
}