diff --git a/doc/modules/http.hsts.md b/doc/modules/http.hsts.md index ac893c44..857d6ae3 100644 --- a/doc/modules/http.hsts.md +++ b/doc/modules/http.hsts.md @@ -7,6 +7,14 @@ Data structures useful for HSTS (HTTP Strict Transport Security) Creates and returns a new HSTS store. +### `hsts_store.max_items` {#http.hsts.max_items} + +The maximum number of items allowed in the store. +Decreasing this value will only prevent new items from being added, it will not remove old items. + +Defaults to infinity (any number of items is allowed). + + ### `hsts_store:clone()` {#http.hsts:clone} Creates and returns a copy of a store. @@ -16,12 +24,24 @@ Creates and returns a copy of a store. Add new directives to the store about the given `host`. `directives` should be a table of directives, which *must* include the key `"max-age"`. +Returns a boolean indicating if the item was accepted. + + +### `hsts_store:remove(host)` {#http.hsts:remove} + +Removes the entry for `host` from the store (if it exists). + ### `hsts_store:check(host)` {#http.hsts:check} Returns a boolean indicating if the given `host` is a known HSTS host. +### `hsts_store:clean_due()` {#http.hsts:clean_due} + +Returns the number of seconds until the next item in the store expires. + + ### `hsts_store:clean()` {#http.hsts:clean} Removes expired entries from the store. diff --git a/http/hsts.lua b/http/hsts.lua index 84317cb6..70a4759c 100644 --- a/http/hsts.lua +++ b/http/hsts.lua @@ -3,13 +3,12 @@ Data structures useful for HSTS (HTTP Strict Transport Security) HSTS is described in RFC 6797 ]] -local EOF = require "lpeg".P(-1) -local IPv4address = require "lpeg_patterns.IPv4".IPv4address -local IPv6address = require "lpeg_patterns.IPv6".IPv6address -local IPaddress = (IPv4address + IPv6address) * EOF +local binaryheap = require "binaryheap" +local http_util = require "http.util" local store_methods = { time = function() return os.time() end; + max_items = (1e999); } local store_mt = { @@ -23,25 +22,22 @@ local store_item_mt = { __index = store_item_methods; } -local function host_is_ip(host) - if IPaddress:match(host) then - return true - else - return false - end -end - local function new_store() return setmetatable({ domains = {}; + expiry_heap = binaryheap.minUnique(); + n_items = 0; }, store_mt) end function store_methods:clone() local r = new_store() r.time = rawget(self, "time") + r.n_items = rawget(self, "n_items") + r.expiry_heap = binaryheap.minUnique() for host, item in pairs(self.domains) do r.domains[host] = item + r.expiry_heap:insert(item.expires, item) end return r end @@ -56,34 +52,62 @@ function store_methods:store(host, directives) else max_age = tonumber(max_age, 10) end - if host_is_ip(host) then - return false - end + + -- Clean now so that we can assume there are no expired items in store + self:clean() + if max_age == 0 then - -- delete from store - self.domains[host] = nil + return self:remove(host) else + if http_util.is_ip(host) then + return false + end -- add to store - self.domains[host] = setmetatable({ + local old_item = self.domains[host] + if old_item then + self.expiry_heap:remove(old_item) + else + local n_items = self.n_items + if n_items >= self.max_items then + return false + end + self.n_items = n_items + 1 + end + local expires = now + max_age + local item = setmetatable({ + host = host; includeSubdomains = directives.includeSubdomains; - expires = now + max_age; + expires = expires; }, store_item_mt) + self.domains[host] = item + self.expiry_heap:insert(expires, item) + end + return true +end + +function store_methods:remove(host) + local item = self.domains[host] + if item then + self.expiry_heap:remove(item) + self.domains[host] = nil + self.n_items = self.n_items - 1 end return true end function store_methods:check(host) - if host_is_ip(host) then + if http_util.is_ip(host) then return false end - local now = self.time() + + -- Clean now so that we can assume there are no expired items in store + self:clean() + local h = host repeat local item = self.domains[h] if item then - if item.expires < now then - self:clean() - elseif host == h or item.includeSubdomains then + if host == h or item.includeSubdomains then return true end end @@ -93,12 +117,20 @@ function store_methods:check(host) return false end +function store_methods:clean_due() + local next_expiring = self.expiry_heap:peek() + if not next_expiring then + return (1e999) + end + return next_expiring.expires +end + function store_methods:clean() local now = self.time() - for host, item in pairs(self.domains) do - if item.expires < now then - self.domains[host] = nil - end + while self:clean_due() < now do + local item = self.expiry_heap:pop() + self.domains[item.host] = nil + self.n_items = self.n_items - 1 end return true end diff --git a/spec/hsts_spec.lua b/spec/hsts_spec.lua index e01bc395..90eb6ca2 100644 --- a/spec/hsts_spec.lua +++ b/spec/hsts_spec.lua @@ -9,12 +9,26 @@ describe("hsts module", function() end) it("can be cloned", function() local s = http_hsts.new_store() - assert.same(s, s:clone()) + do + local clone = s:clone() + local old_heap = s.expiry_heap + s.expiry_heap = nil + clone.expiry_heap = nil + assert.same(s, clone) + s.expiry_heap = old_heap + end assert.truthy(s:store("foo.example.com", { ["max-age"] = "100"; })) + do + local clone = s:clone() + local old_heap = s.expiry_heap + s.expiry_heap = nil + clone.expiry_heap = nil + assert.same(s, clone) + s.expiry_heap = old_heap + end local clone = s:clone() - assert.same(s, clone) assert.truthy(s:check("foo.example.com")) assert.truthy(clone:check("foo.example.com")) end) @@ -93,4 +107,22 @@ describe("hsts module", function() assert.falsy(s:check("example.com")) assert.truthy(s:check("keep.me")) end) + it("enforces .max_items", function() + local s = http_hsts.new_store() + s.max_items = 0 + assert.falsy(s:store("example.com", { + ["max-age"] = "100"; + })) + s.max_items = 1 + assert.truthy(s:store("example.com", { + ["max-age"] = "100"; + })) + assert.falsy(s:store("other.com", { + ["max-age"] = "100"; + })) + s:remove("example.com", "/", "foo") + assert.truthy(s:store("other.com", { + ["max-age"] = "100"; + })) + end) end)