diff --git a/lib/axon.ex b/lib/axon.ex
index 7a761fcb3..8aefafe0c 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -301,9 +301,14 @@ defmodule Axon do
     to inference function except:
 
       * `:name` - layer name.
+      * `:op_name` - layer operation used for inspection and for building the parameter map.
+      * `:mode` - whether the layer should run only during `:inference` or `:train`. Defaults to `:both`.
+      * `:global_options` - a list of global option names that this layer
+        supports. Global options passed to `build/2` will be forwarded to
+        the layer, as long as they are declared here.
 
     Note this means your layer should not use these as input options,
     as they will always be dropped during inference compilation.
@@ -332,14 +337,15 @@ defmodule Axon do
     {mode, opts} = Keyword.pop(opts, :mode, :both)
     {name, opts} = Keyword.pop(opts, :name)
     {op_name, opts} = Keyword.pop(opts, :op_name, :custom)
+    {global_options, opts} = Keyword.pop(opts, :global_options, [])
     name = name(op_name, name)
     id = System.unique_integer([:positive, :monotonic])
 
-    axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts)
+    axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts, global_options)
     %Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
   end
 
-  defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts) do
+  defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts, global_options) do
     {:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
       Process.info(self(), :current_stacktrace)
 
@@ -354,6 +360,7 @@ defmodule Axon do
       policy: Axon.MixedPrecision.create_policy(),
       hooks: [],
       opts: layer_opts,
+      global_options: global_options,
       op_name: op_name,
       stacktrace: stacktrace
     }
@@ -3651,6 +3658,9 @@ defmodule Axon do
       to control differences in compilation at training or
      inference time. Defaults to `:inference`
 
+    * `:global_layer_options` - a keyword list of options forwarded to
+      layers that declare them in `:global_options`
+
   All other options are forwarded to the underlying JIT compiler.
   """
   @doc type: :model
diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index d75989029..07b9109a4 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -50,7 +50,8 @@ defmodule Axon.Compiler do
     debug? = Keyword.get(opts, :debug, false)
     mode = Keyword.get(opts, :mode, :inference)
     seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
-    config = %{mode: mode, debug?: debug?}
+    global_layer_options = Keyword.get(opts, :global_layer_options, [])
+    config = %{mode: mode, debug?: debug?, global_layer_options: global_layer_options}
 
     {time, {root_id, {cache, _op_counts, _block_cache}}} =
       :timer.tc(fn ->
@@ -718,6 +719,7 @@ defmodule Axon.Compiler do
           parameters: layer_params,
           args: args,
           opts: opts,
+          global_options: global_options,
           policy: policy,
           hooks: hooks,
           op_name: op_name,
@@ -725,7 +727,7 @@ defmodule Axon.Compiler do
         },
         nodes,
         cache_and_counts,
-        %{mode: mode, debug?: debug?} = config
+        %{mode: mode, debug?: debug?, global_layer_options: global_layer_options} = config
       )
       when (is_function(op) or is_atom(op)) and is_list(inputs) do
     # Traverse to accumulate cache and get parent_ids for
@@ -761,10 +763,12 @@ defmodule Axon.Compiler do
         name,
         args,
         opts,
+        global_options,
         policy,
         layer_params,
         hooks,
         mode,
+        global_layer_options,
         stacktrace
       )
 
@@ -841,10 +845,12 @@ defmodule Axon.Compiler do
         name,
         args,
         opts,
+        global_options,
         policy,
         layer_params,
         hooks,
         mode,
+        global_layer_options,
         layer_stacktrace
       ) do
     # Recurse graph inputs and invoke cache to get parent results,
@@ -914,7 +920,12 @@ defmodule Axon.Compiler do
 
     # Compute arguments to be forwarded and ensure `:mode` is included
     # for inference/training behavior dependent functions
-    args = Enum.reverse(tensor_inputs, [Keyword.put(opts, :mode, mode)])
+    layer_opts =
+      opts
+      |> Keyword.merge(Keyword.take(global_layer_options, global_options))
+      |> Keyword.put(:mode, mode)
+
+    args = Enum.reverse(tensor_inputs, [layer_opts])
 
     # For built-in layers we always just apply the equivalent function
     # in Axon.Layers. The implication of this is that every function which
diff --git a/lib/axon/node.ex b/lib/axon/node.ex
index e50e32387..b6521501a 100644
--- a/lib/axon/node.ex
+++ b/lib/axon/node.ex
@@ -12,6 +12,7 @@ defmodule Axon.Node do
     :policy,
     :hooks,
     :opts,
+    :global_options,
     :op_name,
     :stacktrace
   ]
diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs
index 591bf7bfa..98a045cac 100644
--- a/test/axon/compiler_test.exs
+++ b/test/axon/compiler_test.exs
@@ -5837,4 +5837,27 @@ defmodule CompilerTest do
       assert predict_fn.(params, x) == Nx.add(x, a)
     end
   end
+
+  describe "global layer options" do
+    test "global options are forwarded to the layer when declared" do
+      input = Axon.input("input")
+
+      model =
+        Axon.layer(
+          fn input, opts ->
+            assert Keyword.has_key?(opts, :option1)
+            refute Keyword.has_key?(opts, :option2)
+            input
+          end,
+          [input],
+          global_options: [:option1]
+        )
+
+      {_, predict_fn} = Axon.build(model, global_layer_options: [option1: true, option2: true])
+
+      params = %{}
+      input = random({1, 1}, type: {:f, 32})
+      predict_fn.(params, input)
+    end
+  end
 end
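
Usage sketch (not part of the patch itself): a custom layer opts in to a global option by declaring its name under `:global_options`, and `Axon.build/2` then forwards any matching entries from `:global_layer_options`. The `:scale_output` option and the `:maybe_scale` op name below are hypothetical, chosen only for illustration.

    # A layer that declares :scale_output, mirroring the test above.
    input = Axon.input("input")

    model =
      Axon.layer(
        fn x, opts ->
          # :mode is always injected by the compiler; :scale_output is
          # present only when passed to build/2 via :global_layer_options.
          if opts[:scale_output], do: Nx.multiply(x, 2), else: x
        end,
        [input],
        op_name: :maybe_scale,
        global_options: [:scale_output]
      )

    {init_fn, predict_fn} = Axon.build(model, global_layer_options: [scale_output: true])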
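
One design note on the forwarding logic in the compiler hunk above: the globals are filtered with `Keyword.take/2` against the layer's declared `global_options` and then merged over the layer's own opts with `Keyword.merge/2`, so a forwarded global option overrides a layer-local option of the same name, and undeclared globals are silently dropped. A quick sketch of those stdlib semantics, using the illustrative keys `:alpha` and `:beta`:

    opts = [alpha: 0.1]                            # layer-local option
    global_layer_options = [alpha: 0.5, beta: 1.0] # passed to build/2
    global_options = [:alpha]                      # declared by the layer

    Keyword.take(global_layer_options, global_options)
    #=> [alpha: 0.5]    (:beta is not declared, so it is dropped)

    Keyword.merge(opts, Keyword.take(global_layer_options, global_options))
    #=> [alpha: 0.5]    (the global value wins over the layer default)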