package datastore import ( "crypto/aes" "crypto/cipher" "crypto/rand" "errors" "fmt" "os" "sort" "git.dayanhub.com/sagi/envoid/internal/common" "git.dayanhub.com/sagi/envoid/internal/config" intErrors "git.dayanhub.com/sagi/envoid/internal/errors" "git.dayanhub.com/sagi/envoid/internal/types" "git.dayanhub.com/sagi/envoid/internal/variables" "golang.org/x/crypto/scrypt" "golang.org/x/sync/errgroup" ) type datastore struct { db *db } func NewDataStore() (*datastore, error) { db, err := newDB() if err != nil { return nil, err } return &datastore{ db: db, }, nil } func (d *datastore) CreateEnv(name string) error { table_name := envNameToTableName(name) return d.db.createTableIfNotExists(table_name) } func (d *datastore) DoesFileExists() bool { pwd := config.GetConfig().PWD filePath := fmt.Sprintf("%s/%s", pwd, variables.DBFileName) if _, err := os.Stat(filePath); errors.Is(err, os.ErrNotExist) { return false } return true } func (d *datastore) ListEnvironments() ([]string, error) { return d.db.listTables() } func (d *datastore) CreateEnvOffExsisting(new_env string, base_env string) error { table_name_new := envNameToTableName(new_env) table_name_base := envNameToTableName(base_env) err := d.CreateEnv(new_env) if err != nil { return err } err = d.db.copyContentFromTo(table_name_base, table_name_new) if err != nil { return err } return nil } func (d *datastore) Close() error { return d.db.close() } func (d *datastore) SetValue(key string, value string, encrypted *bool, envs []*types.Environment) error { if encrypted == nil { encrypted = common.BoolP(false) } if *encrypted { v, err := enc(value) if err != nil { return err } value = *v } for _, env := range envs { table_name := envNameToTableName(env.Name) if err := d.db.setVar(table_name, key, value, *encrypted); err != nil { return err } } return nil } func (d *datastore) GetAll(envName string) ([]*types.EnvVar, error) { table_name := envNameToTableName(envName) vars, err := d.db.getAll(table_name) if err != nil { return vars, err } g := new(errgroup.Group) for _, v := range vars { g.Go(func() error { if v.Encrypted { if v.Value, err = dec(v.Value); err != nil { return &intErrors.InvalidPasswordError{} } } return nil }) } if err := g.Wait(); err != nil { return vars, err } sort.SliceStable(vars, func(i, j int) bool { return vars[i].Key < vars[j].Key }) return vars, nil } func (d *datastore) RemoveVar(key string, envs []*types.Environment) { for _, env := range envs { table_name := envNameToTableName(env.Name) d.db.rmVar(table_name, key) } } func (d *datastore) GetVar(envName string, key string) (*types.EnvVar, error) { table_name := envNameToTableName(envName) v, err := d.db.getVar(table_name, key) if err != nil { return v, err } if v.Encrypted { if v.Value, err = dec(v.Value); err != nil { return v, &intErrors.InvalidPasswordError{} } } return v, err } func (d *datastore) RemoveEnv(envName string) error { table_name := envNameToTableName(envName) return d.db.deleteTable(table_name) } func enc(s string) (*string, error) { conf := config.GetConfig() proj, _ := conf.GetProject(conf.PWD) key, salt, err := deriveKey([]byte(proj.Password), nil) data := []byte(s) if err != nil { return nil, err } blockCipher, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(blockCipher) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) if _, err = rand.Read(nonce); err != nil { return nil, err } ciphertext := gcm.Seal(nonce, nonce, data, nil) ciphertext = append(ciphertext, salt...) str := string(ciphertext) return &str, nil } func dec(s string) (string, error) { data := []byte(s) salt, data := data[len(data)-32:], data[:len(data)-32] conf := config.GetConfig() proj, _ := conf.GetProject(conf.PWD) key, _, err := deriveKey([]byte(proj.Password), salt) if err != nil { return "", err } blockCipher, err := aes.NewCipher(key) if err != nil { return "", err } gcm, err := cipher.NewGCM(blockCipher) if err != nil { return "", err } nonce, ciphertext := data[:gcm.NonceSize()], data[gcm.NonceSize():] plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return "", err } str := string(plaintext) return str, nil } func deriveKey(password, salt []byte) ([]byte, []byte, error) { if salt == nil { salt = make([]byte, 32) if _, err := rand.Read(salt); err != nil { return nil, nil, err } } key, err := scrypt.Key(password, salt, 1048576, 8, 1, 32) if err != nil { return nil, nil, err } return key, salt, nil }