diff --git a/httplib/httplib.go b/httplib/httplib.go index 18995283..f7d083f2 100644 --- a/httplib/httplib.go +++ b/httplib/httplib.go @@ -16,22 +16,45 @@ import ( "mime/multipart" "net" "net/http" + "net/http/cookiejar" "net/http/httputil" "net/url" "os" "strings" + "sync" "time" ) -var defaultUserAgent = "beegoServer" +var defaultSetting = BeegoHttpSettings{false, "beegoServer", 60 * time.Second, 60 * time.Second, nil, nil, nil, false} +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// createDefaultCookieJar creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// Overwrite default settings +func SetDefaultSetting(setting BeegoHttpSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting + if defaultSetting.ConnectTimeout == 0 { + defaultSetting.ConnectTimeout = 60 * time.Second + } + if defaultSetting.ReadWriteTimeout == 0 { + defaultSetting.ReadWriteTimeout = 60 * time.Second + } +} // Get returns *BeegoHttpRequest with GET method. func Get(url string) *BeegoHttpRequest { var req http.Request req.Method = "GET" req.Header = http.Header{} - req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting} } // Post returns *BeegoHttpRequest with POST method. @@ -39,8 +62,7 @@ func Post(url string) *BeegoHttpRequest { var req http.Request req.Method = "POST" req.Header = http.Header{} - req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting} } // Put returns *BeegoHttpRequest with PUT method. @@ -48,8 +70,7 @@ func Put(url string) *BeegoHttpRequest { var req http.Request req.Method = "PUT" req.Header = http.Header{} - req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting} } // Delete returns *BeegoHttpRequest DELETE GET method. @@ -57,8 +78,7 @@ func Delete(url string) *BeegoHttpRequest { var req http.Request req.Method = "DELETE" req.Header = http.Header{} - req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting} } // Head returns *BeegoHttpRequest with HEAD method. @@ -66,40 +86,64 @@ func Head(url string) *BeegoHttpRequest { var req http.Request req.Method = "HEAD" req.Header = http.Header{} - req.Header.Set("User-Agent", defaultUserAgent) - return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil} + return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting} +} + +// BeegoHttpSettings +type BeegoHttpSettings struct { + ShowDebug bool + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TlsClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + EnableCookie bool } // BeegoHttpRequest provides more useful methods for requesting one url than http.Request. type BeegoHttpRequest struct { - url string - req *http.Request - params map[string]string - files map[string]string - showdebug bool - connectTimeout time.Duration - readWriteTimeout time.Duration - tlsClientConfig *tls.Config - proxy func(*http.Request) (*url.URL, error) - transport http.RoundTripper + url string + req *http.Request + params map[string]string + files map[string]string + setting BeegoHttpSettings +} + +// Change request settings +func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest { + b.setting = setting + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest { + b.setting.EnableCookie = enable + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHttpRequest) SetAgent(useragent string) *BeegoHttpRequest { + b.setting.UserAgent = useragent + return b } // Debug sets show debug or not when executing request. func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest { - b.showdebug = isdebug + b.setting.ShowDebug = isdebug return b } // SetTimeout sets connect time out and read-write time out for BeegoRequest. func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest { - b.connectTimeout = connectTimeout - b.readWriteTimeout = readWriteTimeout + b.setting.ConnectTimeout = connectTimeout + b.setting.ReadWriteTimeout = readWriteTimeout return b } // SetTLSClientConfig sets tls connection configurations if visiting https url. func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest { - b.tlsClientConfig = config + b.setting.TlsClientConfig = config return b } @@ -134,7 +178,7 @@ func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest { // Set transport to func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest { - b.transport = transport + b.setting.Transport = transport return b } @@ -146,7 +190,7 @@ func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpR // return u, nil // } func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest { - b.proxy = proxy + b.setting.Proxy = proxy return b } @@ -242,7 +286,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { } b.req.URL = url - if b.showdebug { + if b.setting.ShowDebug { dump, err := httputil.DumpRequest(b.req, true) if err != nil { println(err.Error()) @@ -250,32 +294,47 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { println(string(dump)) } - trans := b.transport + trans := b.setting.Transport if trans == nil { // create default transport trans = &http.Transport{ - TLSClientConfig: b.tlsClientConfig, - Proxy: b.proxy, - Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout), + TLSClientConfig: b.setting.TlsClientConfig, + Proxy: b.setting.Proxy, + Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), } } else { // if b.transport is *http.Transport then set the settings. if t, ok := trans.(*http.Transport); ok { if t.TLSClientConfig == nil { - t.TLSClientConfig = b.tlsClientConfig + t.TLSClientConfig = b.setting.TlsClientConfig } if t.Proxy == nil { - t.Proxy = b.proxy + t.Proxy = b.setting.Proxy } if t.Dial == nil { - t.Dial = TimeoutDialer(b.connectTimeout, b.readWriteTimeout) + t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) } } } + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } else { + jar = nil + } + client := &http.Client{ Transport: trans, + Jar: jar, + } + + if b.setting.UserAgent != "" { + b.req.Header.Set("User-Agent", b.setting.UserAgent) } resp, err := client.Do(b.req) diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go index 6c4f9428..7325862c 100644 --- a/httplib/httplib_test.go +++ b/httplib/httplib_test.go @@ -1,4 +1,4 @@ -// Beego (http://beego.me/) +// Beego (http://beego.me) // @description beego is an open-source, high-performance web framework for the Go programming language. // @link http://github.com/astaxie/beego for the canonical source repository // @license http://github.com/astaxie/beego/blob/master/LICENSE @@ -13,7 +13,7 @@ import ( ) func TestGetUrl(t *testing.T) { - resp, err := Get("http://beego.me/").Debug(true).Response() + resp, err := Get("http://beego.me").Debug(true).Response() if err != nil { t.Fatal(err) } @@ -29,7 +29,7 @@ func TestGetUrl(t *testing.T) { t.Fatal("data is no") } - str, err := Get("http://beego.me/").String() + str, err := Get("http://beego.me").String() if err != nil { t.Fatal(err) } @@ -42,10 +42,59 @@ func ExamplePost(t *testing.T) { b := Post("http://beego.me/").Debug(true) b.Param("username", "astaxie") b.Param("password", "hello") - b.PostFile("uploadfile", "httplib.go") + b.PostFile("uploadfile", "httplib_test.go") str, err := b.String() if err != nil { t.Fatal(err) } fmt.Println(str) } + +func TestSimpleGetString(t *testing.T) { + fmt.Println("TestSimpleGetString==========================================") + html, err := Get("http://httpbin.org/headers").SetAgent("beegoooooo").String() + if err != nil { + t.Fatal(err) + } + fmt.Println(html) + fmt.Println("TestSimpleGetString==========================================") +} + +func TestSimpleGetStringWithDefaultCookie(t *testing.T) { + fmt.Println("TestSimpleGetStringWithDefaultCookie==========================================") + html, err := Get("http://httpbin.org/cookies/set?k1=v1").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + fmt.Println(html) + html, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + fmt.Println(html) + fmt.Println("TestSimpleGetStringWithDefaultCookie==========================================") +} + +func TestDefaultSetting(t *testing.T) { + fmt.Println("TestDefaultSetting==========================================") + var def BeegoHttpSettings + def.EnableCookie = true + //def.ShowDebug = true + def.UserAgent = "UserAgent" + //def.ConnectTimeout = 60*time.Second + //def.ReadWriteTimeout = 60*time.Second + def.Transport = nil //http.DefaultTransport + SetDefaultSetting(def) + + html, err := Get("http://httpbin.org/headers").String() + if err != nil { + t.Fatal(err) + } + fmt.Println(html) + html, err = Get("http://httpbin.org/headers").String() + if err != nil { + t.Fatal(err) + } + fmt.Println(html) + fmt.Println("TestDefaultSetting==========================================") +}