Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 36 additions & 32 deletions hack/code/prices_gen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/karpenter-provider-azure/pkg/auth"
"github.com/Azure/karpenter-provider-azure/pkg/providers/pricing"
"github.com/samber/lo"
)
Expand Down Expand Up @@ -112,7 +111,6 @@ func generatePricing(filePath string) {
defer f.Close() // error handling omitted for example

ctx := context.Background()
updateStarted := time.Now()
src := &bytes.Buffer{}
fmt.Fprintln(src, "//go:build !ignore_autogenerated")
license := lo.Must(os.ReadFile("hack/boilerplate.go.txt"))
Expand All @@ -126,47 +124,53 @@ func generatePricing(filePath string) {
fmt.Fprintln(src, "func init() {")

// Pin cloud to public for now
cloud := cloud.AzurePublic
env := &auth.Environment{
Cloud: cloud,
}
pricingAPI := pricing.NewAPI(cloud.AzurePublic)

// record prices for each region
var pricingProviderByRegion = map[string]chan *pricing.Provider{}
// Fetch prices for each region concurrently
type regionResult struct {
region string
onDemandPrices map[string]float64
}
resultsChan := make(chan regionResult, len(regions))
for _, region := range regions {
resultsChan := make(chan *pricing.Provider)
Comment thread
matthchr marked this conversation as resolved.
log.Println("fetching pricing data in region", region)
go func(region string, resultsChan chan *pricing.Provider) {
pricingProvider := pricing.NewProvider(ctx, env, pricing.NewAPI(cloud), region, make(chan struct{}))
attempts := 0
for {
if pricingProvider.OnDemandLastUpdated().After(updateStarted) {
go func(region string) {
var onDemandPrices map[string]float64
var err error
for attempt := range 10 {
onDemandPrices, _, err = pricing.FetchPricing(ctx, pricingAPI, region)
if err == nil {
break
}

if attempts == 0 {
log.Println("started wait loop for pricing update on region", region)
} else if attempts%10 == 0 {
log.Printf("waiting on pricing update on region %s...\n", region)
} else if time.Since(updateStarted) >= time.Minute*2 {
log.Fatalf("failed to update region %s within 2 minutes", region)
}
time.Sleep(1 * time.Second)
attempts += 1
log.Printf("attempt %d/10 failed for region %s: %v, retrying in 10s...", attempt+1, region, err)
time.Sleep(10 * time.Second)
Comment thread
matthchr marked this conversation as resolved.
}
log.Printf("fetched pricing for region %s\n", region)
resultsChan <- pricingProvider
}(region, resultsChan)
pricingProviderByRegion[region] = resultsChan
if err != nil {
log.Fatalf("failed to fetch pricing for region %s after 10 attempts: %v", region, err)
}
log.Printf("fetched pricing for region %s (%d instance types)\n", region, len(onDemandPrices))
resultsChan <- regionResult{region: region, onDemandPrices: onDemandPrices}
Comment thread
matthchr marked this conversation as resolved.
}(region)
}

// Collect results
resultsByRegion := map[string]map[string]float64{}
for range regions {
r := <-resultsChan
resultsByRegion[r.region] = r.onDemandPrices
}

// Write output in deterministic order
for _, region := range regions {
pricingProviderChan := pricingProviderByRegion[region]
var pricingProvider = <-pricingProviderChan
log.Println("writing output for", region)
instanceTypes := pricingProvider.InstanceTypes()
onDemandPrices := resultsByRegion[region]
instanceTypes := lo.Keys(onDemandPrices)
sort.Strings(instanceTypes)

writePricing(src, instanceTypes, region, pricingProvider.OnDemandPrice)
writePricing(src, instanceTypes, region, func(instanceType string) (float64, bool) {
price, ok := onDemandPrices[instanceType]
return price, ok
})
}
fmt.Fprintln(src, "}")
formatted, err := format.Source(src.Bytes())
Expand Down
108 changes: 35 additions & 73 deletions pkg/providers/pricing/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ type Provider struct {
done chan struct{}
}

type Err struct {
error
lastOnDemandUpdateTime time.Time
lastSpotUpdateTime time.Time
}

// NewPricingAPI returns a pricing API
func NewAPI(cloud cloud.Configuration) client.PricingAPI {
return client.New(cloud)
Expand Down Expand Up @@ -190,65 +184,46 @@ func (p *Provider) updatePricing(ctx context.Context) {
return
}

prices := map[client.Item]bool{}
err := p.fetchPricing(ctx, processPage(prices))
onDemandPrices, spotPrices, err := FetchPricing(ctx, p.pricing, p.region)
if err != nil {
if ctx.Err() != nil {
return
}
log.FromContext(ctx).Error(err, "failed to fetch updated pricing, using existing pricing data",
"lastOnDemandUpdateTime", err.lastOnDemandUpdateTime.Format(time.RFC3339),
"lastSpotUpdateTime", err.lastSpotUpdateTime.Format(time.RFC3339),
)
log.FromContext(ctx).Error(err, "failed to fetch updated pricing, using existing pricing data")
return
}

onDemandPrices, spotPrices := categorizePrices(prices)
p.mu.Lock()
defer p.mu.Unlock()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := p.UpdateOnDemandPricing(ctx, onDemandPrices); err != nil {
log.FromContext(ctx).Error(err, "failed to update on-demand pricing, using existing pricing data",
"lastOnDemandUpdateTime", err.lastOnDemandUpdateTime.Format(time.RFC3339),
)
}
}()

wg.Add(1)
go func() {
defer wg.Done()
if err := p.UpdateSpotPricing(ctx, spotPrices); err != nil {
log.FromContext(ctx).Error(err, "failed to update spot pricing, using existing pricing data",
"lastSpotUpdateTime", err.lastSpotUpdateTime.Format(time.RFC3339),
if len(onDemandPrices) > 0 {
p.onDemandPrices = onDemandPrices
p.onDemandUpdateTime = time.Now()
if p.cm.HasChanged("on-demand-prices", p.onDemandPrices) {
log.FromContext(ctx).Info("updated on-demand pricing",
"instanceTypeCount", len(p.onDemandPrices),
)
}
}()

wg.Wait()
}

func (p *Provider) UpdateOnDemandPricing(ctx context.Context, onDemandPrices map[string]float64) *Err {
p.mu.Lock()
defer p.mu.Unlock()
if len(onDemandPrices) == 0 {
return &Err{error: errors.New("no on-demand pricing found"), lastOnDemandUpdateTime: p.onDemandUpdateTime}
} else {
log.FromContext(ctx).Error(errors.New("no on-demand pricing found"), "using existing on-demand pricing data")
}

p.onDemandPrices = lo.Assign(onDemandPrices)
p.onDemandUpdateTime = time.Now()
if p.cm.HasChanged("on-demand-prices", p.onDemandPrices) {
log.FromContext(ctx).Info("updated on-demand pricing",
"instanceTypeCount", len(p.onDemandPrices),
)
if len(spotPrices) > 0 {
p.spotPrices = spotPrices
p.spotUpdateTime = time.Now()
if p.cm.HasChanged("spot-prices", p.spotPrices) {
log.FromContext(ctx).Info("updated spot pricing",
"instanceTypeCount", len(p.spotPrices),
)
}
} else {
log.FromContext(ctx).Error(errors.New("no spot pricing found"), "using existing spot pricing data")
}
return nil
}

func (p *Provider) fetchPricing(ctx context.Context, pageHandler func(output *client.ProductsPricePage)) *Err {
p.mu.Lock()
defer p.mu.Unlock()
// FetchPricing fetches VM pricing from the Azure retail pricing API for the given region,
// returning on-demand and spot prices keyed by ARM SKU name.
func FetchPricing(ctx context.Context, pricingAPI client.PricingAPI, region string) (onDemandPrices, spotPrices map[string]float64, err error) {
filters := []*client.Filter{
{
Field: "priceType",
Expand All @@ -273,13 +248,16 @@ func (p *Provider) fetchPricing(ctx context.Context, pageHandler func(output *cl
{
Field: "armRegionName",
Operator: client.Equals,
Value: p.region,
Value: region,
}}
err := p.pricing.GetProductsPricePages(ctx, filters, pageHandler)
if err != nil {
return &Err{error: err, lastOnDemandUpdateTime: p.onDemandUpdateTime, lastSpotUpdateTime: p.spotUpdateTime}

prices := map[client.Item]bool{}
if err := pricingAPI.GetProductsPricePages(ctx, filters, processPage(prices)); err != nil {
return nil, nil, err
}
return nil

onDemandPrices, spotPrices = categorizePrices(prices)
return onDemandPrices, spotPrices, nil
}

func processPage(prices map[client.Item]bool) func(page *client.ProductsPricePage) {
Expand All @@ -297,25 +275,9 @@ func processPage(prices map[client.Item]bool) func(page *client.ProductsPricePag
}
}

func (p *Provider) UpdateSpotPricing(ctx context.Context, spotPrices map[string]float64) *Err {
p.mu.Lock()
defer p.mu.Unlock()
if len(spotPrices) == 0 {
return &Err{error: errors.New("no spot pricing found"), lastSpotUpdateTime: p.spotUpdateTime}
}

p.spotPrices = lo.Assign(spotPrices)
p.spotUpdateTime = time.Now()
if p.cm.HasChanged("spot-prices", p.spotPrices) {
log.FromContext(ctx).Info("updated spot pricing",
"instanceTypeCount", len(p.spotPrices),
)
}
return nil
}

func categorizePrices(prices map[client.Item]bool) (map[string]float64, map[string]float64) {
var onDemandPrices, spotPrices = map[string]float64{}, map[string]float64{}
onDemandPrices := map[string]float64{}
spotPrices := map[string]float64{}
for price := range prices {
if strings.HasSuffix(price.SkuName, " Spot") {
spotPrices[price.ArmSkuName] = price.RetailPrice
Expand Down
Loading