Using context for HTTP requests

This commit is contained in:
Quentin McGaw
2020-10-17 21:54:09 +00:00
parent 0d2ca377df
commit 6f4be72785
7 changed files with 37 additions and 28 deletions

View File

@@ -52,7 +52,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
var err error
switch args[1] {
case "healthcheck":
err = cli.HealthCheck()
client := &http.Client{Timeout: time.Second}
err = cli.HealthCheck(background, client)
case "clientkey":
err = cli.ClientKey(args[2:])
case "openvpnconfig":
@@ -403,7 +404,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
if !versionInformation {
break
}
message, err := versionpkg.GetMessage(version, commit, httpClient)
message, err := versionpkg.GetMessage(ctx, version, commit, httpClient)
if err != nil {
logger.Error(err)
break

1
go.mod
View File

@@ -9,5 +9,6 @@ require (
github.com/qdm12/golibs v0.0.0-20200712151944-a0325873bf5a
github.com/qdm12/ss-server v0.0.0-20200819005413-6b516c299307
github.com/stretchr/testify v1.6.1
golang.org/x/net v0.0.0-20190620200207-3b0461eec859
golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed
)

View File

@@ -17,6 +17,7 @@ import (
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"golang.org/x/net/context/ctxhttp"
)
func ClientKey(args []string) error {
@@ -39,9 +40,9 @@ func ClientKey(args []string) error {
return nil
}
func HealthCheck() error {
client := &http.Client{Timeout: time.Second}
response, err := client.Get("http://localhost:8000/health")
func HealthCheck(ctx context.Context, client *http.Client) error {
const url = "http://localhost:8000/health"
response, err := ctxhttp.Get(ctx, client, url)
if err != nil {
return err
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"golang.org/x/net/context/ctxhttp"
)
type piaV3 struct {
@@ -89,7 +90,7 @@ func (p *piaV3) PortForward(ctx context.Context, client *http.Client,
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
response, err := client.Get(url) // TODO add ctx
response, err := ctxhttp.Get(ctx, client, url)
if err != nil {
pfLogger.Error(err)
return

View File

@@ -21,6 +21,7 @@ import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"golang.org/x/net/context/ctxhttp"
)
type piaV4 struct {
@@ -151,7 +152,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
return err
})
if ctx.Err() != nil {
@@ -163,7 +164,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
// First time binding
tryUntilSuccessful(ctx, pfLogger, func() error {
return bindPIAPort(client, gateway, data)
return bindPIAPort(ctx, client, gateway, data)
})
if ctx.Err() != nil {
return
@@ -202,7 +203,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
}
return
case <-keepAliveTimer.C:
if err := bindPIAPort(client, gateway, data); err != nil {
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
pfLogger.Error(err)
}
keepAliveTimer.Reset(keepAlivePeriod)
@@ -210,7 +211,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
if err != nil {
pfLogger.Error(err)
continue
@@ -233,7 +234,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
); err != nil {
pfLogger.Error(err)
}
if err := bindPIAPort(client, gateway, data); err != nil {
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
pfLogger.Error(err)
}
if !keepAliveTimer.Stop() {
@@ -292,12 +293,12 @@ func newPIAv4HTTPClient(serverName string) (client *http.Client, err error) {
return client, nil
}
func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
func refreshPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(fileManager, client)
if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err)
}
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token)
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token)
if err != nil {
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
}
@@ -429,7 +430,7 @@ func getOpenvpnCredentials(fileManager files.FileManager) (username, password st
return username, password, nil
}
func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
queryParams := url.Values{}
queryParams.Add("token", token)
url := url.URL{
@@ -438,7 +439,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string)
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
response, err := ctxhttp.Get(ctx, client, url.String())
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
}
@@ -465,7 +466,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string)
return port, data.Signature, expiration, err
}
func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return err
@@ -480,7 +481,7 @@ func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
response, err := ctxhttp.Get(ctx, client, url.String())
if err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}

View File

@@ -1,10 +1,13 @@
package version
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
"time"
"golang.org/x/net/context/ctxhttp"
)
type githubRelease struct {
@@ -23,9 +26,9 @@ type githubCommit struct {
}
}
func getGithubReleases(client *http.Client) (releases []githubRelease, err error) {
func getGithubReleases(ctx context.Context, client *http.Client) (releases []githubRelease, err error) {
const url = "https://api.github.com/repos/qdm12/gluetun/releases"
response, err := client.Get(url)
response, err := ctxhttp.Get(ctx, client, url)
if err != nil {
return nil, err
}
@@ -40,9 +43,9 @@ func getGithubReleases(client *http.Client) (releases []githubRelease, err error
return releases, nil
}
func getGithubCommits(client *http.Client) (commits []githubCommit, err error) {
func getGithubCommits(ctx context.Context, client *http.Client) (commits []githubCommit, err error) {
const url = "https://api.github.com/repos/qdm12/gluetun/commits"
response, err := client.Get(url)
response, err := ctxhttp.Get(ctx, client, url)
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package version
import (
"context"
"fmt"
"net/http"
"time"
@@ -10,10 +11,10 @@ import (
// GetMessage returns a message for the user describing if there is a newer version
// available. It should only be called once the tunnel is established.
func GetMessage(version, commitShort string, client *http.Client) (message string, err error) {
func GetMessage(ctx context.Context, version, commitShort string, client *http.Client) (message string, err error) {
if version == "latest" {
// Find # of commits between current commit and latest commit
commitsSince, err := getCommitsSince(client, commitShort)
commitsSince, err := getCommitsSince(ctx, client, commitShort)
if err != nil {
return "", fmt.Errorf("cannot get version information: %w", err)
} else if commitsSince == 0 {
@@ -25,7 +26,7 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
}
return fmt.Sprintf("You are running %d %s behind the most recent %s", commitsSince, commits, version), nil
}
tagName, name, releaseTime, err := getLatestRelease(client)
tagName, name, releaseTime, err := getLatestRelease(ctx, client)
if err != nil {
return "", fmt.Errorf("cannot get version information: %w", err)
}
@@ -38,8 +39,8 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
nil
}
func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) {
releases, err := getGithubReleases(client)
func getLatestRelease(ctx context.Context, client *http.Client) (tagName, name string, time time.Time, err error) {
releases, err := getGithubReleases(ctx, client)
if err != nil {
return "", "", time, err
}
@@ -52,8 +53,8 @@ func getLatestRelease(client *http.Client) (tagName, name string, time time.Time
return "", "", time, fmt.Errorf("no releases found")
}
func getCommitsSince(client *http.Client, commitShort string) (n int, err error) {
commits, err := getGithubCommits(client)
func getCommitsSince(ctx context.Context, client *http.Client, commitShort string) (n int, err error) {
commits, err := getGithubCommits(ctx, client)
if err != nil {
return 0, err
}