diff --git a/config.go b/config.go index d9ff624c..85da1c30 100644 --- a/config.go +++ b/config.go @@ -354,10 +354,12 @@ func init() { SetLogFuncCall(true) err = ParseConfig() - if err != nil && os.IsNotExist(err) { - // for init if doesn't have app.conf will not panic - ac := config.NewFakeConfig() - AppConfig = &beegoAppConfig{ac} + if err != nil { + if os.IsNotExist(err) { + // for init if doesn't have app.conf will not panic + ac := config.NewFakeConfig() + AppConfig = &beegoAppConfig{ac} + } Warning(err) } } diff --git a/context/acceptencoder.go b/context/acceptencoder.go new file mode 100644 index 00000000..1bd2cc3d --- /dev/null +++ b/context/acceptencoder.go @@ -0,0 +1,133 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "io" + "net/http" + "os" + "strconv" + "strings" +) + +type acceptEncoder struct { + name string + encode func(io.Writer, int) (io.Writer, error) +} + +var ( + noneCompressEncoder = acceptEncoder{"", func(wr io.Writer, level int) (io.Writer, error) { return wr, nil }} + gzipCompressEncoder = acceptEncoder{"gzip", func(wr io.Writer, level int) (io.Writer, error) { return gzip.NewWriterLevel(wr, level) }} + deflateCompressEncoder = acceptEncoder{"deflate", func(wr io.Writer, level int) (io.Writer, error) { return flate.NewWriter(wr, level) }} +) + +var ( + encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress + } +) + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return writeLevel(encoding, writer, file, flate.BestCompression) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed) +} + +// writeLevel reads from reader,writes to writer by specific encoding and compress level +// the compress level is defined by deflate package +func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { + var outputWriter io.Writer + var err error + var ce = noneCompressEncoder + + if cf, ok := encoderMap[encoding]; ok { + ce = cf + } + encoding = ce.name + outputWriter, err = ce.encode(writer, level) + + if err != nil { + return false, "", err + } + + _, err = io.Copy(outputWriter, reader) + if err != nil { + return false, "", err + } + + switch outputWriter.(type) { + case io.WriteCloser: + outputWriter.(io.WriteCloser).Close() + } + return encoding != "", encoding, nil +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + if r == nil { + return "" + } + return parseEncoding(r) +} + +type q struct { + name string + value float64 +} + +func parseEncoding(r *http.Request) string { + acceptEncoding := r.Header.Get("Accept-Encoding") + if acceptEncoding == "" { + return "" + } + var lastQ q + for _, v := range strings.Split(acceptEncoding, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } + vs := strings.Split(v, ";") + if len(vs) == 1 { + lastQ = q{vs[0], 1} + break + } + if len(vs) == 2 { + f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) + if f == 0 { + continue + } + if f > lastQ.value { + lastQ = q{vs[0], f} + } + } + } + if cf, ok := encoderMap[lastQ.name]; ok { + return cf.name + } else { + return "" + } +} diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go new file mode 100644 index 00000000..147313c5 --- /dev/null +++ b/context/acceptencoder_test.go @@ -0,0 +1,45 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "testing" +) + +func Test_ExtractEncoding(t *testing.T) { + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate,gzip"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { + t.Fail() + } + + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=0,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"*"}}}) != "gzip" { + t.Fail() + } +} diff --git a/context/output.go b/context/output.go index 0a194753..f0d66f36 100644 --- a/context/output.go +++ b/context/output.go @@ -16,8 +16,6 @@ package context import ( "bytes" - "compress/flate" - "compress/gzip" "encoding/json" "encoding/xml" "errors" @@ -54,29 +52,16 @@ func (output *BeegoOutput) Header(key, val string) { // if EnableGzip, compress content string. // it sends out response body directly. func (output *BeegoOutput) Body(content []byte) { - outputWriter := output.Context.ResponseWriter.(io.Writer) - if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" { - splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1) - encodings := make([]string, len(splitted)) - - for i, val := range splitted { - encodings[i] = strings.TrimSpace(val) - } - for _, val := range encodings { - if val == "gzip" { - output.Header("Content-Encoding", "gzip") - outputWriter, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed) - break - } else if val == "deflate" { - output.Header("Content-Encoding", "deflate") - outputWriter, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed) - break - } - } + var encoding string + var buf = &bytes.Buffer{} + if output.EnableGzip { + encoding = ParseEncoding(output.Context.Input.Request) + } + if b, n, _ := WriteBody(encoding, buf, content); b { + output.Header("Content-Encoding", n) } else { output.Header("Content-Length", strconv.Itoa(len(content))) } - // Write status code if it has been set manually // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" if output.Status != 0 { @@ -84,10 +69,7 @@ func (output *BeegoOutput) Body(content []byte) { output.Status = 0 } - outputWriter.Write(content) - if c, ok := outputWriter.(io.Closer); ok { - c.Close() - } + io.Copy(output.Context.ResponseWriter, buf) } // Cookie sets cookie value via given key. @@ -98,25 +80,21 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface //fix cookie not work in IE if len(others) > 0 { + var maxAge int64 + switch v := others[0].(type) { case int: - if v > 0 { - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v)*time.Second).UTC().Format(time.RFC1123), v) - } else if v <= 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } - case int64: - if v > 0 { - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v)*time.Second).UTC().Format(time.RFC1123), v) - } else if v <= 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } + maxAge = int64(v) case int32: - if v > 0 { - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v)*time.Second).UTC().Format(time.RFC1123), v) - } else if v <= 0 { - fmt.Fprintf(&b, "; Max-Age=0") - } + maxAge = int64(v) + case int64: + maxAge = v + } + + if maxAge > 0 { + fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) + } else { + fmt.Fprintf(&b, "; Max-Age=0") } } diff --git a/controller_test.go b/controller_test.go index 15938cdc..4156bd29 100644 --- a/controller_test.go +++ b/controller_test.go @@ -17,42 +17,48 @@ package beego import ( "fmt" "github.com/astaxie/beego/context" + "testing" ) -func ExampleGetInt() { +func TestGetInt(t *testing.T) { i := &context.BeegoInput{Params: map[string]string{"age": "40"}} ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} val, _ := ctrlr.GetInt("age") - fmt.Printf("%T", val) - //Output: int + + if (val != 40) { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } } -func ExampleGetInt8() { +func TestGetInt8(t *testing.T) { i := &context.BeegoInput{Params: map[string]string{"age": "40"}} ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} val, _ := ctrlr.GetInt8("age") - fmt.Printf("%T", val) + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } //Output: int8 } -func ExampleGetInt16() { +func TestGetInt16(t *testing.T) { i := &context.BeegoInput{Params: map[string]string{"age": "40"}} ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} val, _ := ctrlr.GetInt16("age") - fmt.Printf("%T", val) - //Output: int16 + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } } -func ExampleGetInt32() { +func TestGetInt32(t *testing.T) { i := &context.BeegoInput{Params: map[string]string{"age": "40"}} ctx := &context.Context{Input: i} @@ -60,16 +66,19 @@ func ExampleGetInt32() { val, _ := ctrlr.GetInt32("age") fmt.Printf("%T", val) - //Output: int32 + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } } -func ExampleGetInt64() { +func TestGetInt64(t *testing.T) { i := &context.BeegoInput{Params: map[string]string{"age": "40"}} ctx := &context.Context{Input: i} ctrlr := Controller{Ctx: ctx} val, _ := ctrlr.GetInt64("age") - fmt.Printf("%T", val) - //Output: int64 + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } } diff --git a/memzipfile.go b/memzipfile.go deleted file mode 100644 index b61e87f2..00000000 --- a/memzipfile.go +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "errors" - "io" - "io/ioutil" - "net/http" - "os" - "strings" - "sync" - "time" -) - -var ( - menFileInfoMap = make(map[string]*memFileInfo) - lock sync.RWMutex -) - -// openMemZipFile returns MemFile object with a compressed static file. -// it's used for serve static file if gzip enable. -func openMemZipFile(path string, zip string) (*memFile, error) { - osFile, e := os.Open(path) - if e != nil { - return nil, e - } - defer osFile.Close() - - osFileInfo, e := osFile.Stat() - if e != nil { - return nil, e - } - - modTime := osFileInfo.ModTime() - fileSize := osFileInfo.Size() - lock.RLock() - cfi, ok := menFileInfoMap[zip+":"+path] - lock.RUnlock() - if !(ok && cfi.ModTime() == modTime && cfi.fileSize == fileSize) { - var content []byte - if zip == "gzip" { - var zipBuf bytes.Buffer - gzipWriter, e := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) - if e != nil { - return nil, e - } - _, e = io.Copy(gzipWriter, osFile) - gzipWriter.Close() - if e != nil { - return nil, e - } - content, e = ioutil.ReadAll(&zipBuf) - if e != nil { - return nil, e - } - } else if zip == "deflate" { - var zipBuf bytes.Buffer - deflateWriter, e := flate.NewWriter(&zipBuf, flate.BestCompression) - if e != nil { - return nil, e - } - _, e = io.Copy(deflateWriter, osFile) - deflateWriter.Close() - if e != nil { - return nil, e - } - content, e = ioutil.ReadAll(&zipBuf) - if e != nil { - return nil, e - } - } else { - content, e = ioutil.ReadAll(osFile) - if e != nil { - return nil, e - } - } - - cfi = &memFileInfo{osFileInfo, modTime, content, int64(len(content)), fileSize} - lock.Lock() - defer lock.Unlock() - menFileInfoMap[zip+":"+path] = cfi - } - return &memFile{fi: cfi, offset: 0}, nil -} - -// MemFileInfo contains a compressed file bytes and file information. -// it implements os.FileInfo interface. -type memFileInfo struct { - os.FileInfo - modTime time.Time - content []byte - contentSize int64 - fileSize int64 -} - -// Name returns the compressed filename. -func (fi *memFileInfo) Name() string { - return fi.Name() -} - -// Size returns the raw file content size, not compressed size. -func (fi *memFileInfo) Size() int64 { - return fi.contentSize -} - -// Mode returns file mode. -func (fi *memFileInfo) Mode() os.FileMode { - return fi.Mode() -} - -// ModTime returns the last modified time of raw file. -func (fi *memFileInfo) ModTime() time.Time { - return fi.modTime -} - -// IsDir returns the compressing file is a directory or not. -func (fi *memFileInfo) IsDir() bool { - return fi.IsDir() -} - -// return nil. implement the os.FileInfo interface method. -func (fi *memFileInfo) Sys() interface{} { - return nil -} - -// memFile contains MemFileInfo and bytes offset when reading. -// it implements io.Reader,io.ReadCloser and io.Seeker. -type memFile struct { - fi *memFileInfo - offset int64 -} - -// Close memfile. -func (f *memFile) Close() error { - return nil -} - -// Get os.FileInfo of memfile. -func (f *memFile) Stat() (os.FileInfo, error) { - return f.fi, nil -} - -// read os.FileInfo of files in directory of memfile. -// it returns empty slice. -func (f *memFile) Readdir(count int) ([]os.FileInfo, error) { - infos := []os.FileInfo{} - - return infos, nil -} - -// Read bytes from the compressed file bytes. -func (f *memFile) Read(p []byte) (n int, err error) { - if len(f.fi.content)-int(f.offset) >= len(p) { - n = len(p) - } else { - n = len(f.fi.content) - int(f.offset) - err = io.EOF - } - copy(p, f.fi.content[f.offset:f.offset+int64(n)]) - f.offset += int64(n) - return -} - -var errWhence = errors.New("Seek: invalid whence") -var errOffset = errors.New("Seek: invalid offset") - -// Read bytes from the compressed file bytes by seeker. -func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) { - switch whence { - default: - return 0, errWhence - case os.SEEK_SET: - case os.SEEK_CUR: - offset += f.offset - case os.SEEK_END: - offset += int64(len(f.fi.content)) - } - if offset < 0 || int(offset) > len(f.fi.content) { - return 0, errOffset - } - f.offset = offset - return f.offset, nil -} - -// getAcceptEncodingZip returns accept encoding format in http header. -// zip is first, then deflate if both accepted. -// If no accepted, return empty string. -func getAcceptEncodingZip(r *http.Request) string { - ss := r.Header.Get("Accept-Encoding") - ss = strings.ToLower(ss) - if strings.Contains(ss, "gzip") { - return "gzip" - } else if strings.Contains(ss, "deflate") { - return "deflate" - } else { - return "" - } -} diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index 1eaccf72..36c5e95f 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -44,7 +44,20 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { dbase := orm.alias.DbBaser var models []interface{} + var other_values []interface{} + var other_names []string + for _, colname := range mi.fields.dbcols { + if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column { + other_names = append(other_names, colname) + } + } + for i, md := range mds { + if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { + other_values = append(other_values, md) + mds = append(mds[:i], mds[i+1:]...) + } + } for _, md := range mds { val := reflect.ValueOf(md) if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { @@ -67,11 +80,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { names := []string{mfi.column, rfi.column} values := make([]interface{}, 0, len(models)*2) - for _, md := range models { ind := reflect.Indirect(reflect.ValueOf(md)) - var v2 interface{} if ind.Kind() != reflect.Struct { v2 = ind.Interface() @@ -81,11 +92,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { panic(ErrMissPK) } } - values = append(values, v1, v2) } - + names = append(names, other_names...) + values = append(values, other_values...) return dbase.InsertValue(orm.db, mi, true, names, values) } diff --git a/router.go b/router.go index 5c6b7bf9..6d82c780 100644 --- a/router.go +++ b/router.go @@ -717,18 +717,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) if doFilter(BeforeExec) { goto Admin } - isRunable := false + isRunnable := false if routerInfo != nil { if routerInfo.routerType == routerTypeRESTFul { if _, ok := routerInfo.methods[r.Method]; ok { - isRunable = true + isRunnable = true routerInfo.runFunction(context) } else { exception("405", context) goto Admin } } else if routerInfo.routerType == routerTypeHandler { - isRunable = true + isRunnable = true routerInfo.handler.ServeHTTP(rw, r) } else { runrouter = routerInfo.controllerType @@ -750,7 +750,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // also defined runrouter & runMethod from filter - if !isRunable { + if !isRunnable { //Invoke the request handler vc := reflect.New(runrouter) execController, ok := vc.Interface().(ControllerInterface) diff --git a/staticfile.go b/staticfile.go index 56c5fc08..90d0be0d 100644 --- a/staticfile.go +++ b/staticfile.go @@ -15,122 +15,182 @@ package beego import ( + "bytes" "net/http" "os" "path" "strconv" "strings" + "sync" + + "errors" + + "time" + "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" ) -func serverStaticRouter(ctx *context.Context) { +var notStaticRequestErr = errors.New("request not a static file request") +func serverStaticRouter(ctx *context.Context) { if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { return } + + forbidden, filePath, fileInfo, err := lookupFile(ctx) + if err == notStaticRequestErr { + return + } + + if forbidden { + exception("403", ctx) + return + } + + if filePath == "" || fileInfo == nil { + if RunMode == "dev" { + Warn("Can't find/open the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + if fileInfo.IsDir() { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + return + } + + var enableCompress = EnableGzip && isStaticCompress(filePath) + var acceptEncoding string + if enableCompress { + acceptEncoding = context.ParseEncoding(ctx.Request) + } + b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) + if err != nil { + if RunMode == "dev" { + Warn("Can't compress the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + + if b { + ctx.Output.Header("Content-Encoding", n) + } else { + ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) + } + + http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch) + return + +} + +type serveContentHolder struct { + *bytes.Reader + modTime time.Time + size int64 + encoding string +} + +var ( + staticFileMap = make(map[string]*serveContentHolder) + mapLock sync.Mutex +) + +func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) { + mapKey := acceptEncoding + ":" + filePath + mapFile, _ := staticFileMap[mapKey] + if isOk(mapFile, fi) { + return mapFile.encoding != "", mapFile.encoding, mapFile, nil + } + mapLock.Lock() + defer mapLock.Unlock() + if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) { + file, err := os.Open(filePath) + if err != nil { + return false, "", nil, err + } + defer file.Close() + var bufferWriter bytes.Buffer + _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) + if err != nil { + return false, "", nil, err + } + mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} + staticFileMap[mapKey] = mapFile + } + + return mapFile.encoding != "", mapFile.encoding, mapFile, nil +} + +func isOk(s *serveContentHolder, fi os.FileInfo) bool { + if s == nil { + return false + } + return s.modTime == fi.ModTime() && s.size == fi.Size() +} + +// isStaticCompress detect static files +func isStaticCompress(filePath string) bool { + for _, statExtension := range StaticExtensionsToGzip { + if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { + return true + } + } + return false +} + +// searchFile search the file by url path +// if none the static file prefix matches ,return notStaticRequestErr +func searchFile(ctx *context.Context) (string, os.FileInfo, error) { requestPath := path.Clean(ctx.Input.Request.URL.Path) // special processing : favicon.ico/robots.txt can be in any static dir if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { file := path.Join(".", requestPath) - if utils.FileExists(file) { - http.ServeFile(ctx.ResponseWriter, ctx.Request, file) - return + if fi, _ := os.Stat(file); fi != nil { + return file, fi, nil } - for _, staticDir := range StaticDir { - file := path.Join(staticDir, requestPath) - if utils.FileExists(file) { - http.ServeFile(ctx.ResponseWriter, ctx.Request, file) - return + filePath := path.Join(staticDir, requestPath) + if fi, _ := os.Stat(filePath); fi != nil { + return filePath, fi, nil } } - - http.NotFound(ctx.ResponseWriter, ctx.Request) - return + return "", nil, errors.New(requestPath + " file not find") } for prefix, staticDir := range StaticDir { if len(prefix) == 0 { continue } - if strings.HasPrefix(requestPath, prefix) { - if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { - continue - } - filePath := path.Join(staticDir, requestPath[len(prefix):]) - fileInfo, err := os.Stat(filePath) - if err != nil { - if RunMode == "dev" { - Warn("Can't find the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - //if the request is dir and DirectoryIndex is false then - if fileInfo.IsDir() { - if !DirectoryIndex { - exception("403", ctx) - return - } - if ctx.Input.Request.URL.Path[len(ctx.Input.Request.URL.Path)-1] != '/' { - http.Redirect(ctx.ResponseWriter, ctx.Request, ctx.Input.Request.URL.Path+"/", 302) - return - } - } - - if strings.HasSuffix(requestPath, "/index.html") { - fileReader, err := os.Open(filePath) - if err != nil { - if RunMode == "dev" { - Warn("Can't open the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - defer fileReader.Close() - http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, fileInfo.ModTime(), fileReader) - return - } - - isStaticFileToCompress := false - for _, statExtension := range StaticExtensionsToGzip { - if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { - isStaticFileToCompress = true - break - } - } - - if !isStaticFileToCompress { - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) - return - } - - //to compress file - var contentEncoding string - if EnableGzip { - contentEncoding = getAcceptEncodingZip(ctx.Request) - } - - memZipFile, err := openMemZipFile(filePath, contentEncoding) - if err != nil { - if RunMode == "dev" { - Warn("Can't compress the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - - if contentEncoding == "gzip" { - ctx.Output.Header("Content-Encoding", "gzip") - } else if contentEncoding == "deflate" { - ctx.Output.Header("Content-Encoding", "deflate") - } else { - ctx.Output.Header("Content-Length", strconv.FormatInt(fileInfo.Size(), 10)) - } - - http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, fileInfo.ModTime(), memZipFile) - return + if !strings.Contains(requestPath, prefix) { + continue + } + if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + continue + } + filePath := path.Join(staticDir, requestPath[len(prefix):]) + if fi, err := os.Stat(filePath); fi != nil { + return filePath, fi, err } } + return "", nil, notStaticRequestErr +} + +// lookupFile find the file to serve +// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) +// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex +func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { + fp, fi, err := searchFile(ctx) + if fp == "" || fi == nil { + return false, "", nil, err + } + if !fi.IsDir() { + return false, fp, fi, err + } + ifp := path.Join(fp, "index.html") + if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { + return false, ifp, ifi, err + } + return !DirectoryIndex, fp, fi, err } diff --git a/staticfile_test.go b/staticfile_test.go new file mode 100644 index 00000000..e635fcc6 --- /dev/null +++ b/staticfile_test.go @@ -0,0 +1,71 @@ +package beego + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "io" + "io/ioutil" + "os" + "testing" +) + +const licenseFile = "./LICENSE" + +func testOpenFile(encoding string, content []byte, t *testing.T) { + fi, _ := os.Stat(licenseFile) + b, n, sch, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Log(err) + t.Fail() + } + + t.Log("open static file encoding "+n, b) + + assetOpenFileAndContent(sch, content, t) +} +func TestOpenStaticFile_1(t *testing.T) { + file, _ := os.Open(licenseFile) + content, _ := ioutil.ReadAll(file) + testOpenFile("", content, t) +} + +func TestOpenStaticFileGzip_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("gzip", content, t) +} +func TestOpenStaticFileDeflate_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := flate.NewWriter(&zipBuf, flate.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("deflate", content, t) +} + +func assetOpenFileAndContent(sch *serveContentHolder, content []byte, t *testing.T) { + t.Log(sch.size, len(content)) + if sch.size != int64(len(content)) { + t.Log("static content file size not same") + t.Fail() + } + bs, _ := ioutil.ReadAll(sch) + for i, v := range content { + if v != bs[i] { + t.Log("content not same") + t.Fail() + } + } + if len(staticFileMap) == 0 { + t.Log("men map is empty") + t.Fail() + } +}