Skip to content

Add metadata retrieved from the context to the user agent when a new HTTP client is created #2789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions commands/instances.go
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateReque
}
}

config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(ctx)
if err != nil {
return nil, err
}
@@ -377,7 +377,7 @@ func (s *arduinoCoreServerImpl) Init(req *rpc.InitRequest, stream rpc.ArduinoCor
responseError(err.GRPCStatus())
continue
}
config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(ctx)
if err != nil {
taskCallback(&rpc.TaskProgress{Name: i18n.Tr("Error downloading library %s", libraryRef)})
e := &cmderrors.FailedLibraryInstallError{Cause: err}
@@ -498,7 +498,7 @@ func (s *arduinoCoreServerImpl) UpdateLibrariesIndex(req *rpc.UpdateLibrariesInd
}

// Perform index update
config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(stream.Context())
if err != nil {
return err
}
@@ -608,7 +608,7 @@ func (s *arduinoCoreServerImpl) UpdateIndex(req *rpc.UpdateIndexRequest, stream
}
}

config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(stream.Context())
if err != nil {
downloadCB.Start(u, i18n.Tr("Downloading index: %s", filepath.Base(URL.Path)))
downloadCB.End(false, i18n.Tr("Invalid network configuration: %s", err))
18 changes: 9 additions & 9 deletions commands/service_board_identify.go
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar
defer release()

props := properties.NewFromHashmap(req.GetProperties())
res, err := identify(pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection())
res, err := identify(ctx, pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection())
if err != nil {
return nil, err
}
@@ -58,7 +58,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar
}

// identify returns a list of boards checking first the installed platforms or the Cloud API
func identify(pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) {
func identify(ctx context.Context, pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) {
if properties == nil {
return nil, nil
}
@@ -90,7 +90,7 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings
// if installed cores didn't recognize the board, try querying
// the builder API if the board is a USB device port
if len(boards) == 0 && !skipCloudAPI && !settings.SkipCloudApiForBoardDetection() {
items, err := identifyViaCloudAPI(properties, settings)
items, err := identifyViaCloudAPI(ctx, properties, settings)
if err != nil {
// this is bad, but keep going
logrus.WithError(err).Debug("Error querying builder API")
@@ -119,22 +119,22 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings
return boards, nil
}

func identifyViaCloudAPI(props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func identifyViaCloudAPI(ctx context.Context, props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
// If the port is not USB do not try identification via cloud
if !props.ContainsKey("vid") || !props.ContainsKey("pid") {
return nil, nil
}

logrus.Debug("Querying builder API for board identification...")
return cachedAPIByVidPid(props.Get("vid"), props.Get("pid"), settings)
return cachedAPIByVidPid(ctx, props.Get("vid"), props.Get("pid"), settings)
}

var (
vidPidURL = "https://builder.arduino.cc/v3/boards/byVidPid"
validVidPid = regexp.MustCompile(`0[xX][a-fA-F\d]{4}`)
)

func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func cachedAPIByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
var resp []*rpc.BoardListItem

cacheKey := fmt.Sprintf("cache.builder-api.v3/boards/byvid/pid/%s/%s", vid, pid)
@@ -148,7 +148,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp
}
}

resp, err := apiByVidPid(vid, pid, settings) // Perform API requrest
resp, err := apiByVidPid(ctx, vid, pid, settings) // Perform API requrest

if err == nil {
if cachedResp, err := json.Marshal(resp); err == nil {
@@ -160,7 +160,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp
return resp, err
}

func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func apiByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
// ensure vid and pid are valid before hitting the API
if !validVidPid.MatchString(vid) {
return nil, errors.New(i18n.Tr("Invalid vid value: '%s'", vid))
@@ -173,7 +173,7 @@ func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.Boar
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")

httpClient, err := settings.NewHttpClient()
httpClient, err := settings.NewHttpClient(ctx)
if err != nil {
return nil, fmt.Errorf("%s: %w", i18n.Tr("failed to initialize http client"), err)
}
15 changes: 8 additions & 7 deletions commands/service_board_identify_test.go
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
package commands

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@@ -48,15 +49,15 @@ func TestGetByVidPid(t *testing.T) {

vidPidURL = ts.URL
settings := configuration.NewSettings()
res, err := apiByVidPid("0xf420", "0XF069", settings)
res, err := apiByVidPid(context.Background(), "0xf420", "0XF069", settings)
require.Nil(t, err)
require.Len(t, res, 1)
require.Equal(t, "Arduino/Genuino MKR1000", res[0].GetName())
require.Equal(t, "arduino:samd:mkr1000", res[0].GetFqbn())

// wrong vid (too long), wrong pid (not an hex value)

_, err = apiByVidPid("0xfffff", "0xDEFG", settings)
_, err = apiByVidPid(context.Background(), "0xfffff", "0xDEFG", settings)
require.NotNil(t, err)
}

@@ -69,7 +70,7 @@ func TestGetByVidPidNotFound(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NoError(t, err)
require.Empty(t, res)
}
@@ -84,7 +85,7 @@ func TestGetByVidPid5xx(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NotNil(t, err)
require.Equal(t, "the server responded with status 500 Internal Server Error", err.Error())
require.Len(t, res, 0)
@@ -99,15 +100,15 @@ func TestGetByVidPidMalformedResponse(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NotNil(t, err)
require.Equal(t, "wrong format in server response", err.Error())
require.Len(t, res, 0)
}

func TestBoardDetectionViaAPIWithNonUSBPort(t *testing.T) {
settings := configuration.NewSettings()
items, err := identifyViaCloudAPI(properties.NewMap(), settings)
items, err := identifyViaCloudAPI(context.Background(), properties.NewMap(), settings)
require.NoError(t, err)
require.Empty(t, items)
}
@@ -156,7 +157,7 @@ func TestBoardIdentifySorting(t *testing.T) {
defer release()

settings := configuration.NewSettings()
res, err := identify(pme, idPrefs, settings, true)
res, err := identify(context.Background(), pme, idPrefs, settings, true)
require.NoError(t, err)
require.NotNil(t, res)
require.Len(t, res, 4)
6 changes: 3 additions & 3 deletions commands/service_check_for_updates.go
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ func (s *arduinoCoreServerImpl) CheckForArduinoCLIUpdates(ctx context.Context, r
inventory.WriteStore()
}()

latestVersion, err := semver.Parse(s.getLatestRelease())
latestVersion, err := semver.Parse(s.getLatestRelease(ctx))
if err != nil {
return nil, err
}
@@ -82,8 +82,8 @@ func (s *arduinoCoreServerImpl) shouldCheckForUpdate(currentVersion *semver.Vers

// getLatestRelease queries the official Arduino download server for the latest release,
// if there are no errors or issues a version string is returned, in all other case an empty string.
func (s *arduinoCoreServerImpl) getLatestRelease() string {
client, err := s.settings.NewHttpClient()
func (s *arduinoCoreServerImpl) getLatestRelease(ctx context.Context) string {
client, err := s.settings.NewHttpClient(ctx)
if err != nil {
return ""
}
12 changes: 8 additions & 4 deletions commands/service_library_download.go
Original file line number Diff line number Diff line change
@@ -82,11 +82,15 @@ func (s *arduinoCoreServerImpl) LibraryDownload(req *rpc.LibraryDownloadRequest,
})
}

func downloadLibrary(ctx context.Context, downloadsDir *paths.Path, libRelease *librariesindex.Release,
downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, queryParameter string, settings *configuration.Settings) error {

func downloadLibrary(
ctx context.Context,
downloadsDir *paths.Path, libRelease *librariesindex.Release,
downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB,
queryParameter string,
settings *configuration.Settings,
) error {
taskCB(&rpc.TaskProgress{Name: i18n.Tr("Downloading %s", libRelease)})
config, err := settings.DownloaderConfig()
config, err := settings.DownloaderConfig(ctx)
if err != nil {
return &cmderrors.FailedDownloadError{Message: i18n.Tr("Can't download library"), Cause: err}
}
2 changes: 1 addition & 1 deletion internal/arduino/resources/helpers_test.go
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ func TestDownloadApplyUserAgentHeaderUsingConfig(t *testing.T) {

settings := configuration.NewSettings()
settings.Set("network.user_agent_ext", goldUserAgentValue)
config, err := settings.DownloaderConfig()
config, err := settings.DownloaderConfig(context.Background())
require.NoError(t, err)
err = r.Download(context.Background(), tmp, config, "", func(progress *rpc.DownloadProgress) {}, "")
require.NoError(t, err)
17 changes: 13 additions & 4 deletions internal/cli/configuration/network.go
Original file line number Diff line number Diff line change
@@ -16,18 +16,21 @@
package configuration

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
"time"

"github.com/arduino/arduino-cli/commands/cmderrors"
"github.com/arduino/arduino-cli/internal/i18n"
"github.com/arduino/arduino-cli/internal/version"
"go.bug.st/downloader/v2"
"google.golang.org/grpc/metadata"
)

// UserAgent returns the user agent (mainly used by HTTP clients)
@@ -84,17 +87,23 @@ func (settings *Settings) NetworkProxy() (*url.URL, error) {
}

// NewHttpClient returns a new http client for use in the arduino-cli
func (settings *Settings) NewHttpClient() (*http.Client, error) {
func (settings *Settings) NewHttpClient(ctx context.Context) (*http.Client, error) {
proxy, err := settings.NetworkProxy()
if err != nil {
return nil, err
}
userAgent := settings.UserAgent()
if md, ok := metadata.FromIncomingContext(ctx); ok {
if extraUserAgent := strings.Join(md.Get("user-agent"), " "); extraUserAgent != "" {
userAgent += " " + extraUserAgent
}
}
return &http.Client{
Transport: &httpClientRoundTripper{
transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
userAgent: settings.UserAgent(),
userAgent: userAgent,
},
Timeout: settings.ConnectionTimeout(),
}, nil
@@ -111,8 +120,8 @@ func (h *httpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
}

// DownloaderConfig returns the downloader configuration based on current settings.
func (settings *Settings) DownloaderConfig() (downloader.Config, error) {
httpClient, err := settings.NewHttpClient()
func (settings *Settings) DownloaderConfig(ctx context.Context) (downloader.Config, error) {
httpClient, err := settings.NewHttpClient(ctx)
if err != nil {
return downloader.Config{}, &cmderrors.InvalidArgumentError{
Message: i18n.Tr("Could not connect via HTTP"),
7 changes: 4 additions & 3 deletions internal/cli/configuration/network_test.go
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
package configuration_test

import (
"context"
"fmt"
"io"
"net/http"
@@ -35,7 +36,7 @@ func TestUserAgentHeader(t *testing.T) {

settings := configuration.NewSettings()
require.NoError(t, settings.Set("network.user_agent_ext", "test-user-agent"))
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", ts.URL, nil)
@@ -59,7 +60,7 @@ func TestProxy(t *testing.T) {

settings := configuration.NewSettings()
settings.Set("network.proxy", ts.URL)
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", "http://arduino.cc", nil)
@@ -83,7 +84,7 @@ func TestConnectionTimeout(t *testing.T) {
if timeout != 0 {
require.NoError(t, settings.Set("network.connection_timeout", "2s"))
}
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", "http://arduino.cc", nil)
6 changes: 5 additions & 1 deletion internal/integrationtest/arduino-cli.go
Original file line number Diff line number Diff line change
@@ -450,7 +450,11 @@ func (cli *ArduinoCLI) StartDaemon(verbose bool) string {
for retries := 5; retries > 0; retries-- {
time.Sleep(time.Second)

conn, err := grpc.NewClient(cli.daemonAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.NewClient(
cli.daemonAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUserAgent("cli-test/0.0.0"),
)
if err != nil {
connErr = err
continue
59 changes: 59 additions & 0 deletions internal/integrationtest/daemon/daemon_test.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,10 @@ import (
"errors"
"fmt"
"io"
"maps"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

@@ -555,6 +559,61 @@ func TestDaemonCoreUpgradePlatform(t *testing.T) {
})
}

func TestDaemonUserAgent(t *testing.T) {
env, cli := integrationtest.CreateEnvForDaemon(t)
defer env.CleanUp()

// Set up an http server to serve our custom index file
// The user-agent is tested inside the HTTPServeFile function
test_index := paths.New("..", "testdata", "test_index.json")
url := env.HTTPServeFile(8000, test_index)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Test that the user-agent contains metadata from the context when the CLI is in daemon mode
userAgent := r.Header.Get("User-Agent")

require.Contains(t, userAgent, "cli-test/0.0.0")
require.Contains(t, userAgent, "grpc-go")
// Depends on how we built the client we may have git-snapshot or 0.0.0-git in dev releases
require.Condition(t, func() (success bool) {
return strings.Contains(userAgent, "arduino-cli/git-snapshot") ||
strings.Contains(userAgent, "arduino-cli/0.0.0-git")
})

proxiedReq, err := http.NewRequest(r.Method, url.String(), r.Body)
require.NoError(t, err)
maps.Copy(proxiedReq.Header, r.Header)

proxiedResp, err := http.DefaultTransport.RoundTrip(proxiedReq)
require.NoError(t, err)
defer proxiedResp.Body.Close()

// Copy the headers from the proxy response to the original response
maps.Copy(r.Header, proxiedReq.Header)
w.WriteHeader(proxiedResp.StatusCode)
io.Copy(w, proxiedResp.Body)
}))
defer ts.Close()

grpcInst := cli.Create()
require.NoError(t, grpcInst.Init("", "", func(ir *commands.InitResponse) {
fmt.Printf("INIT> %v\n", ir.GetMessage())
}))

// Set extra indexes
additionalURL := ts.URL + "/test_index.json"
err := cli.SetValue("board_manager.additional_urls", fmt.Sprintf(`["%s"]`, additionalURL))
require.NoError(t, err)

{
cl, err := grpcInst.UpdateIndex(context.Background(), false)
require.NoError(t, err)
res, err := analyzeUpdateIndexClient(t, cl)
require.NoError(t, err)
require.Len(t, res, 2)
require.True(t, res[additionalURL].GetSuccess())
}
}

func analyzeUpdateIndexClient(t *testing.T, cl commands.ArduinoCoreService_UpdateIndexClient) (map[string]*commands.DownloadProgressEnd, error) {
analyzer := NewDownloadProgressAnalyzer(t)
for {
2 changes: 1 addition & 1 deletion internal/integrationtest/http_server.go
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@ import (
// HTTPServeFile spawn an http server that serve a single file. The server
// is started on the given port. The URL to the file and a cleanup function are returned.
func (env *Environment) HTTPServeFile(port uint16, path *paths.Path) *url.URL {
t := env.T()
mux := http.NewServeMux()
mux.HandleFunc("/"+path.Base(), func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.String())
@@ -36,7 +37,6 @@ func (env *Environment) HTTPServeFile(port uint16, path *paths.Path) *url.URL {
Handler: mux,
}

t := env.T()
fileURL, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d/%s", port, path.Base()))
require.NoError(t, err)