diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 64d6829fc8..ae1787686f 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -338,6 +338,7 @@ def _sample_external_nuts(
                 UserWarning,
             )
         compile_kwargs = {}
+        nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
         for kwarg in ("backend", "gradient_backend"):
             if kwarg in nuts_sampler_kwargs:
                 compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)