package companydb import ( "database/sql" "errors" "fmt" "io/ioutil" tokenTools "multitenantStack/services/tokenTools" "os" "github.com/BurntSushi/toml" jwt "github.com/dgrijalva/jwt-go" ) type DBConfig struct { User string Password string Host string Port int Db string Ssl string } var Conf DBConfig var dbs map[string]*sql.DB // InitCompanyDBService Init companydb service and open system db connection func InitCompanyDBService() { tomlData, err := ioutil.ReadFile("conf/dbconfig.toml") if err != nil { // handle Read error panic(err.Error()) os.Exit(1) } if _, err := toml.Decode(string(tomlData), &Conf); err != nil { // handle Parse error panic(err.Error()) os.Exit(1) } dbs = make(map[string]*sql.DB) conStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", Conf.Host, Conf.Port, Conf.User, Conf.Password, Conf.Db, Conf.Ssl) systemDB, err := sql.Open("postgres", conStr) if err != nil { fmt.Println("Fatal: could not connect to db, exiting... Error:", err) os.Exit(1) } dbs["system"] = systemDB } // GetSystemDatabase returns system db func GetSystemDatabase() *sql.DB { return dbs["system"] } // GetDatabaseWithName Get orm and user information func GetDatabaseWithName(companyName string) (*sql.DB, error) { if dbs[companyName] != nil { fmt.Println("DB Already open") return dbs[companyName], nil } conStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", Conf.Host, Conf.Port, Conf.User, Conf.Password, companyName, Conf.Ssl) db, err := sql.Open("postgres", conStr) dbs[companyName] = db if err != nil { return nil, err } return db, nil } // GetDatabase Get orm and user information func GetDatabase(tokenString string) (jwt.MapClaims, *sql.DB, error) { // validate token valid, token := tokenTools.Validate(tokenString) if !valid { return nil, nil, errors.New("Token is invalid") } tokenMap := token.Claims.(jwt.MapClaims) companyName := tokenMap["companyName"].(string) if dbs[companyName] != nil { fmt.Println("DB Already open") return tokenMap, dbs[companyName], nil } db, err := GetDatabaseWithName(companyName) if err != nil { return nil, nil, err } // return db with orm or error return tokenMap, db, nil } // HasDatabase Check if DB exists func HasDatabase(dbname string) bool { systemDB := GetSystemDatabase() result, err := systemDB.Query("SELECT datname FROM pg_database WHERE datistemplate = false;") if err != nil { return false } for result.Next() { var aDbName string result.Scan(&aDbName) if aDbName == dbname { return true } } return false } // CreateDatabase Create a database by copying the template func CreateDatabase(companyName string) (*sql.DB, error) { if HasDatabase(companyName) { return nil, errors.New("DB already exists") } systemDB := GetSystemDatabase() // Takes about 1.2 seconds and we trust companyName to be sanitized in register queryString := fmt.Sprintf("CREATE DATABASE %s TEMPLATE company_template;", companyName) _, err := systemDB.Exec(queryString) if err != nil { return nil, err } db, err := GetDatabaseWithName(companyName) if err != nil { return nil, err } return db, nil } // DeleteDatabase Delete an entire database, this is very very dangerous :-) func DeleteDatabase(companyName string) error { if !HasDatabase(companyName) { return errors.New("DB does not exist") } systemDB := GetSystemDatabase() db, err := GetDatabaseWithName(companyName) db.Close() delete(dbs, companyName) fmt.Println("Closed %s", companyName) queryString := fmt.Sprintf("DROP DATABASE %s;", companyName) _, err = systemDB.Exec(queryString) if err != nil { return err } return nil }