Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 46 additions & 25 deletions internal/clients/clientimpl/localmatcher/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package localmatcher

import (
"archive/zip"
"bytes"
"context"
"encoding/base64"
"encoding/binary"
Expand Down Expand Up @@ -78,37 +77,56 @@ func fetchRemoteArchiveCRC32CHash(ctx context.Context, url string) (uint32, erro
return 0, errors.New("could not find crc32c= checksum")
}

func fetchLocalArchiveCRC32CHash(data []byte) uint32 {
return crc32.Checksum(data, crc32.MakeTable(crc32.Castagnoli))
func fetchLocalArchiveCRC32CHash(f *os.File) (uint32, int64, error) {
h := crc32.New(crc32.MakeTable(crc32.Castagnoli))
size, err := io.Copy(h, f)

if err != nil {
return 0, 0, err
}

return h.Sum32(), size, nil
}

func (db *ZipDB) fetchZip(ctx context.Context) ([]byte, error) {
cache, err := os.ReadFile(db.StoredAt)
func (db *ZipDB) fetchZip(ctx context.Context) (*os.File, int64, error) {
f, err := os.Open(db.StoredAt)

if db.Offline {
if err != nil {
return nil, ErrOfflineDatabaseNotFound
return nil, 0, ErrOfflineDatabaseNotFound
}

s, err := f.Stat()

if err != nil {
return nil, 0, err
}

return cache, nil
return f, s.Size(), nil
}

if err == nil {
remoteHash, err := fetchRemoteArchiveCRC32CHash(ctx, db.ArchiveURL)

if err != nil {
return nil, err
return nil, 0, err
}

localHash, size, err := fetchLocalArchiveCRC32CHash(f)

if err != nil {
return nil, 0, err
}

if fetchLocalArchiveCRC32CHash(cache) == remoteHash {
return cache, nil
if remoteHash == localHash {
return f, size, nil
}
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, db.ArchiveURL, nil)

if err != nil {
return nil, fmt.Errorf("could not retrieve OSV database archive: %w", err)
return nil, 0, fmt.Errorf("could not retrieve OSV database archive: %w", err)
}

if db.UserAgent != "" {
Expand All @@ -117,35 +135,36 @@ func (db *ZipDB) fetchZip(ctx context.Context) ([]byte, error) {

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("could not retrieve OSV database archive: %w", err)
return nil, 0, fmt.Errorf("could not retrieve OSV database archive: %w", err)
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("db host returned %s", resp.Status)
return nil, 0, fmt.Errorf("db host returned %s", resp.Status)
}

var body []byte

body, err = io.ReadAll(resp.Body)
err = os.MkdirAll(path.Dir(db.StoredAt), 0750)

if err != nil {
return nil, fmt.Errorf("could not read OSV database archive from response: %w", err)
return nil, 0, fmt.Errorf("could not create cache directory: %w", err)
}

err = os.MkdirAll(path.Dir(db.StoredAt), 0750)
f, err = os.OpenFile(db.StoredAt, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)

if err == nil {
//nolint:gosec // being world readable is fine
err = os.WriteFile(db.StoredAt, body, 0644)
if err != nil {
return nil, 0, fmt.Errorf("could not create cache file: %w", err)
}

size, err := io.Copy(f, resp.Body)

if err != nil {
cmdlogger.Warnf("Failed to save database to %s: %v", db.StoredAt, err)
return nil, 0, fmt.Errorf("could not write cache file: %w", err)
}

return body, nil
_, _ = f.Seek(0, io.SeekStart)

return f, size, nil
}

func mightAffectPackages(v *osvschema.Vulnerability, names []string) bool {
Expand Down Expand Up @@ -211,13 +230,15 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File, names []string) {
func (db *ZipDB) load(ctx context.Context, names []string) error {
db.Vulnerabilities = []*osvschema.Vulnerability{}

body, err := db.fetchZip(ctx)
f, size, err := db.fetchZip(ctx)
Comment thread
G-Rath marked this conversation as resolved.
Outdated

if err != nil {
return err
}

zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body)))
defer f.Close()

zipReader, err := zip.NewReader(f, size)
if err != nil {
return fmt.Errorf("could not read OSV database archive: %w", err)
}
Expand Down
Loading