jwt validation, getting the correct database

This commit is contained in:
Lukas Bachschwell 2018-11-07 20:13:26 +01:00
parent 549d91fbb4
commit 3347161ae9
13 changed files with 125 additions and 62 deletions

View File

@ -14,3 +14,4 @@ To regenerate docs simply run `bee generate docs`
- company controller, create databases and so on - company controller, create databases and so on
- Update not found to json - Update not found to json
- load db connections from config

View File

@ -1,7 +1,7 @@
package controllers package controllers
import ( import (
auth "multitenantStack/services/authentication" auth "multitenantStack/services"
"time" "time"
jwt "github.com/dgrijalva/jwt-go" jwt "github.com/dgrijalva/jwt-go"
@ -37,7 +37,7 @@ func (c *AuthController) Login() {
} }
if c.Ctx.Input.Method() != "POST" { if c.Ctx.Input.Method() != "POST" {
c.ServeJsonError("Method not allowed") c.ServeJSONError("Method not allowed")
return return
} }
@ -46,13 +46,13 @@ func (c *AuthController) Login() {
email := c.GetString("email") email := c.GetString("email")
password := c.GetString("password") password := c.GetString("password")
//TODO: check against main database, get company id and veryfy password //TODO: check against main database, get company id and verify password
companyName := "" companyName := "company_1"
companyUserId := 5 companyUserID := 5
//TODO: if found query the company database to get roleid, and name //TODO: if found query the company database to get roleID, and name
name := "Lukas" name := "Lukas"
roleId := 5 roleID := 5
tokenString := "" tokenString := ""
if email == "admin@admin.at" && password == "my password" { if email == "admin@admin.at" && password == "my password" {
@ -60,13 +60,13 @@ func (c *AuthController) Login() {
tokenString = auth.CreateToken(jwt.MapClaims{ tokenString = auth.CreateToken(jwt.MapClaims{
"email": email, "email": email,
"companyName": companyName, "companyName": companyName,
"companyUserId": companyUserId, "companyUserID": companyUserID,
"name": name, "name": name,
"roleId": roleId, "roleID": roleID,
"expires": time.Now().Unix() + 3600, "exp": time.Now().Unix() + 3600,
}) })
} else { } else {
c.ServeJsonError("Invalid user/password") c.ServeJSONError("Invalid user/password")
return return
} }

View File

@ -4,35 +4,44 @@ import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
) )
type JsonBasicResponse struct { // JSONBasicResponse The minimal JSON response
type JSONBasicResponse struct {
Status int Status int
Message string Message string
} }
const JSON_ERROR int = 400 // JSONError code for a input error
const JSON_INT_ERROR int = 500 const JSONError int = 400
const JSON_SUCCESS int = 200
// JSONInternalError code for an internal error
const JSONInternalError int = 500
// JSONSuccess code for a success
const JSONSuccess int = 200
// BaseController operations for BaseController // BaseController operations for BaseController
type BaseController struct { type BaseController struct {
beego.Controller beego.Controller
} }
func (c *BaseController) ServeJsonError(message string) { // ServeJSONError respond with a JSON error
json := JsonBasicResponse{JSON_ERROR, message} func (c *BaseController) ServeJSONError(message string) {
json := JSONBasicResponse{JSONError, message}
c.Data["json"] = &json c.Data["json"] = &json
///c.Ctx.ResponseWriter.WriteHeader(400) ///c.Ctx.ResponseWriter.WriteHeader(400)
c.ServeJSON() c.ServeJSON()
} }
func (c *BaseController) ServeJsonErrorWithCode(errorcode int, message string) { // ServeJSONErrorWithCode respond with a JSON error and specify code
json := JsonBasicResponse{errorcode, message} func (c *BaseController) ServeJSONErrorWithCode(errorcode int, message string) {
json := JSONBasicResponse{errorcode, message}
c.Data["json"] = &json c.Data["json"] = &json
c.ServeJSON() c.ServeJSON()
} }
func (c *BaseController) ServeJsonSuccess(message string) { // ServeJSONSuccess respond with a JSON success message
json := JsonBasicResponse{JSON_SUCCESS, message} func (c *BaseController) ServeJSONSuccess(message string) {
json := JSONBasicResponse{JSONSuccess, message}
c.Data["json"] = &json c.Data["json"] = &json
c.ServeJSON() c.ServeJSON()
} }

View File

@ -1,29 +1,47 @@
package controllers package controllers
// BaseController operations for APIs import (
"database/sql"
"fmt"
companydb "multitenantStack/services"
"github.com/astaxie/beego/orm"
jwt "github.com/dgrijalva/jwt-go"
)
// BaseAPIController operations for APIs
type BaseAPIController struct { type BaseAPIController struct {
BaseController BaseController
} }
func (this *BaseAPIController) Prepare() { var jwtSession jwt.MapClaims
var companyDB *sql.DB
var o orm.Ormer
/* //var database sql.database
//Lo que quieras hacer en todos los controladores
// O puede ser leído de una cabecera HEADER!!
tokenString := this.Ctx.Request.Header.Get("X-JWTtoken")
et := jwtbeego.EasyToken{}
valid, issuer, _ := et.ValidateToken(tokenString)
if !valid {
this.Ctx.Output.SetStatus(401)
this.ServeJsonError("Invalid Token")
}
/*
userSession := this.GetSession("username")
if userSession == nil || userSession != issuer { // Prepare parse all requests that come after this controller for valid auth
this.Ctx.Output.SetStatus(401) func (c *BaseAPIController) Prepare() {
this.ServeJsonError("Invalid Session")
} tokenString := c.Ctx.Request.Header.Get("X-JWTtoken")
*/
//return if tokenString == "" {
c.ServeJSONError("No Token provided")
return
}
token, db, err := companydb.GetDatabase(tokenString)
if err != nil {
c.ServeJSONError("Token invalid")
return
}
jwtSession = token
companyDB = db
o, err = orm.NewOrmWithDB("postgres", "company", companyDB)
if err != nil {
fmt.Println(err.Error())
c.ServeJSONError("internal")
return
}
} }

View File

@ -7,12 +7,12 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/astaxie/beego" "github.com/astaxie/beego/orm"
) )
// ContactController operations for Contact // ContactController operations for Contact
type ContactController struct { type ContactController struct {
beego.Controller BaseAPIController
} }
// URLMapping ... // URLMapping ...
@ -56,7 +56,7 @@ func (c *ContactController) Post() {
func (c *ContactController) GetOne() { func (c *ContactController) GetOne() {
idStr := c.Ctx.Input.Param(":id") idStr := c.Ctx.Input.Param(":id")
id, _ := strconv.Atoi(idStr) id, _ := strconv.Atoi(idStr)
v, err := models.GetContactById(id) v, err := models.GetContactById(orm.NewOrm(), id)
if err != nil { if err != nil {
c.Data["json"] = err.Error() c.Data["json"] = err.Error()
} else { } else {
@ -119,12 +119,14 @@ func (c *ContactController) GetAll() {
} }
} }
l, err := models.GetAllContact(query, fields, sortby, order, offset, limit) ob, _ := orm.NewOrmWithDB("postgres", "default", companyDB)
l, err := models.GetAllContact(ob, query, fields, sortby, order, offset, limit)
if err != nil { if err != nil {
c.Data["json"] = err.Error() c.Data["json"] = err.Error()
} else { } else {
c.Data["json"] = l c.Data["json"] = l
} }
c.ServeJSON() c.ServeJSON()
} }

View File

@ -1,13 +1,16 @@
package controllers package controllers
// ErrorController Handle all errors
type ErrorController struct { type ErrorController struct {
BaseController BaseController
} }
// Error404 handle a 404
func (c *ErrorController) Error404() { func (c *ErrorController) Error404() {
c.ServeJsonErrorWithCode(404, "Not Found") c.ServeJSONErrorWithCode(404, "Not Found")
} }
// Error500 handle a 500
func (c *ErrorController) Error500() { func (c *ErrorController) Error500() {
c.ServeJsonErrorWithCode(500, "Internal Server Error") c.ServeJSONErrorWithCode(500, "Internal Server Error")
} }

View File

@ -5,10 +5,12 @@ type IndexController struct {
BaseController BaseController
} }
// Get Index response for get
func (c *IndexController) Get() { func (c *IndexController) Get() {
c.ServeJsonSuccess("multitenant API") c.ServeJSONSuccess("multitenant API")
} }
// Post Index response for post
func (c *IndexController) Post() { func (c *IndexController) Post() {
c.ServeJsonSuccess("multitenant API") c.ServeJSONSuccess("multitenant API")
} }

View File

@ -1 +1 @@
{"/Users/LB/go/src/multitenantStack/controllers":1541598684943144901} {"/Users/LB/go/src/multitenantStack/controllers":1541617005449208486}

View File

@ -2,6 +2,8 @@ package main
import ( import (
_ "multitenantStack/routers" _ "multitenantStack/routers"
auth "multitenantStack/services"
"time"
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
@ -10,6 +12,8 @@ import (
func init() { func init() {
orm.RegisterDataBase("default", "postgres", "host=127.0.0.1 port=5435 user=postgres password=postgre dbname=company_template sslmode=disable") orm.RegisterDataBase("default", "postgres", "host=127.0.0.1 port=5435 user=postgres password=postgre dbname=company_template sslmode=disable")
auth.InitJWTService()
orm.DefaultTimeLoc = time.UTC
if beego.BConfig.RunMode == "dev" { if beego.BConfig.RunMode == "dev" {
beego.BConfig.WebConfig.DirectoryIndex = true beego.BConfig.WebConfig.DirectoryIndex = true
beego.BConfig.WebConfig.StaticDir["/swagger"] = "swagger" beego.BConfig.WebConfig.StaticDir["/swagger"] = "swagger"

View File

@ -40,8 +40,8 @@ func AddContact(m *Contact) (id int64, err error) {
// GetContactById retrieves Contact by Id. Returns error if // GetContactById retrieves Contact by Id. Returns error if
// Id doesn't exist // Id doesn't exist
func GetContactById(id int) (v *Contact, err error) { func GetContactById(o orm.Ormer, id int) (v *Contact, err error) {
o := orm.NewOrm() //o := orm.NewOrm()
v = &Contact{Id: id} v = &Contact{Id: id}
if err = o.Read(v); err == nil { if err = o.Read(v); err == nil {
return v, nil return v, nil
@ -51,9 +51,8 @@ func GetContactById(id int) (v *Contact, err error) {
// GetAllContact retrieves all Contact matches certain condition. Returns empty list if // GetAllContact retrieves all Contact matches certain condition. Returns empty list if
// no records exist // no records exist
func GetAllContact(query map[string]string, fields []string, sortby []string, order []string, func GetAllContact(o orm.Ormer, query map[string]string, fields []string, sortby []string, order []string,
offset int64, limit int64) (ml []interface{}, err error) { offset int64, limit int64) (ml []interface{}, err error) {
o := orm.NewOrm()
qs := o.QueryTable(new(Contact)) qs := o.QueryTable(new(Contact))
// query k=v // query k=v
for k, v := range query { for k, v := range query {

View File

@ -2,10 +2,12 @@ package services
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"os" "os"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
jwt "github.com/dgrijalva/jwt-go"
) )
var dbs map[string]*sql.DB var dbs map[string]*sql.DB
@ -27,13 +29,30 @@ func InitCompanyService() {
} }
// GetDatabase Get orm and user information // GetDatabase Get orm and user information
func GetDatabase(token string) { func GetDatabase(tokenString string) (jwt.MapClaims, *sql.DB, error) {
// validate token // validate token
// retrieve correct user/database valid, token := Validate(tokenString)
// check if open first if !valid {
// try to open second return nil, nil, errors.New("Token is invalid")
// return error otherwise }
tokenMap := token.Claims.(jwt.MapClaims)
companyName := tokenMap["companyName"].(string)
if dbs[companyName] != nil {
fmt.Println("DB Already open")
return tokenMap, dbs[companyName], nil
}
conStr := fmt.Sprintf("host=127.0.0.1 port=5435 user=postgres password=postgre dbname=%s sslmode=disable", tokenMap["companyName"])
fmt.Println(conStr)
db, err := sql.Open("postgres", conStr)
if err != nil {
return nil, nil, err
}
// return db with orm or error // return db with orm or error
return tokenMap, db, nil
} }
// CreateDatabase Create a database by copying the template // CreateDatabase Create a database by copying the template

View File

@ -9,18 +9,22 @@ import (
var hmacSecret []byte var hmacSecret []byte
// GenerateSecret generate the secret to verify JWTs
func GenerateSecret() []byte { func GenerateSecret() []byte {
b := make([]byte, 32) b := make([]byte, 32)
rand.Read(b) rand.Read(b)
return b return b
} }
func InitAuthService() { // InitJWTService generate the secret to verify JWTs and store it in memory
func InitJWTService() {
hmacSecret = GenerateSecret() hmacSecret = GenerateSecret()
fmt.Println("InitJWTService", hmacSecret)
// TODO: This needs to be replaced with reading rsa keys, there needs to be a automatic generation of these if they do not exist // TODO: This needs to be replaced with reading rsa keys, there needs to be a automatic generation of these if they do not exist
} }
// Validate a jwt tokenstring
func Validate(Token string) (bool, jwt.Token) { func Validate(Token string) (bool, jwt.Token) {
token, err := jwt.Parse(Token, func(token *jwt.Token) (interface{}, error) { token, err := jwt.Parse(Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
@ -31,6 +35,7 @@ func Validate(Token string) (bool, jwt.Token) {
}) })
if err == nil && token.Valid { if err == nil && token.Valid {
fmt.Println("Token is valid") fmt.Println("Token is valid")
return true, *token return true, *token
} }
@ -39,6 +44,7 @@ func Validate(Token string) (bool, jwt.Token) {
return false, *token return false, *token
} }
// CreateToken create a new jwt token with the provided claims
func CreateToken(Claims jwt.MapClaims) string { func CreateToken(Claims jwt.MapClaims) string {
// Create a new token object, specifying signing method and the claims // Create a new token object, specifying signing method and the claims

View File

@ -2,6 +2,7 @@ package test
import ( import (
_ "multitenantStack/routers" _ "multitenantStack/routers"
auth "multitenantStack/services/authentication"
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/orm" "github.com/astaxie/beego/orm"
@ -10,8 +11,7 @@ import (
func init() { func init() {
orm.RegisterDataBase("default", "postgres", "host=127.0.0.1 port=5435 user=postgres password=postgre sslmode=disable") orm.RegisterDataBase("default", "postgres", "host=127.0.0.1 port=5435 user=postgres password=postgre sslmode=disable")
orm.RegisterDataBase("company1", "postgres", "host=127.0.0.1 port=5435 user=postgres password=postgre dbname=company1 sslmode=disable") auth.InitAuthService()
orm.RegisterDataBase("company2", "postgres", "host=127.0.0.1 port=5435 user=postgres password=postgre dbname=company2 sslmode=disable")
orm.Debug = true orm.Debug = true
beego.BConfig.WebConfig.DirectoryIndex = true beego.BConfig.WebConfig.DirectoryIndex = true