diff --git a/pytensor/sandbox/__init__.py b/pytensor/sandbox/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/pytensor/sandbox/linalg/__init__.py b/pytensor/sandbox/linalg/__init__.py
deleted file mode 100644
index e4428ca21f..0000000000
--- a/pytensor/sandbox/linalg/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from pytensor.sandbox.linalg.ops import spectral_radius_bound
diff --git a/pytensor/sandbox/minimal.py b/pytensor/sandbox/minimal.py
deleted file mode 100644
index c0236e6cc7..0000000000
--- a/pytensor/sandbox/minimal.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import numpy as np
-
-from pytensor.graph.basic import Apply
-from pytensor.graph.op import Op
-from pytensor.tensor.type import lscalar
-
-
-class Minimal(Op):
-    # TODO : need description for class
-
-    # if the Op has any attributes, consider using them in the eq function.
-    # If two Apply nodes have the same inputs and the ops compare equal...
-    # then they will be MERGED so they had better have computed the same thing!
-
-    __props__ = ()
-
-    def __init__(self):
-        # If you put things here, think about whether they change the outputs
-        # computed by # self.perform()
-        #  - If they do, then you should take them into consideration in
-        #    __eq__ and __hash__
-        #  - If they do not, then you should not use them in
-        #    __eq__ and __hash__
-
-        super().__init__()
-
-    def make_node(self, *args):
-        # HERE `args` must be PYTENSOR VARIABLES
-        return Apply(op=self, inputs=args, outputs=[lscalar()])
-
-    def perform(self, node, inputs, out_):
-        (output,) = out_
-        # HERE `inputs` are PYTHON OBJECTS
-
-        # do what you want here,
-        # but do not modify any of the arguments [inplace].
-        print("perform got %i arguments" % len(inputs))
-
-        print("Max of input[0] is ", np.max(inputs[0]))
-
-        # return some computed value.
-        # do not return something that is aliased to one of the inputs.
-        output[0] = np.asarray(0, dtype="int64")
-
-
-minimal = Minimal()
diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 1282cabae5..b276d7339b 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -1,4 +1,4 @@
-"""Ops and optimizations for using BLAS calls
+"""Ops for using BLAS calls
 
 BLAS = Basic Linear Algebra Subroutines
 Learn more about BLAS here:
@@ -71,60 +71,10 @@
 that system.
 
 
-Optimizations
-=============
-
-The optimization pipeline works something like this:
-
-    1. identify dot22 from dot
-    2. identify gemm from dot22
-    3. identify dot22scalar from dot22 that are not gemm
-    4. specialize gemm to gemv where applicable
-    5. specialize gemm to ger where applicable
-    6. specialize dot22 -> gemv or ger where applicable
-
-:note: GEMM is the most canonical BLAS signature that we deal with so far, it
-    would be good to turn most things into GEMM (dot, inner, outer, dot22,
-    dot22scalar), and then to specialize from gemm to the various other L2 and
-    L3 operations.
-
-Identify Dot22
---------------
-
-Numpy's dot supports arguments that are of any rank, and we should support that
-too (just for compatibility).  The BLAS optimizations work with Dot Ops whose
-inputs are each either vector or matrix.  So the first part of the optimization
-pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be
-transformed further, but they will get implemented by a BLAS call.
-
-More precisely, Dot nodes whose inputs are all vectors or matrices and whose
-inputs both have the same dtype, and whose dtype is float or complex, become
-Dot22.  This is implemented in `local_dot_to_dot22`.
-
-
-Identify Gemm from Dot22
-------------------------
-
-This is complicated, done in GemmOptimizer.
-
-Identify Dot22Scalar from Dot22
--------------------------------
-
-Dot22 Ops that remain after the GemmOptimizer is done have not
-qualified as GEMM Ops. Still they might be scaled by a factor, in
-which case we use Dot22Scalar which is like Gemm, but without the b
-and the Z.  In the future it would be good to merge this into the
-GemmOptimizer.
-
-Specialize Gemm to Gemv
------------------------
-
-If arguments to GEMM are dimshuffled vectors, then we can use GEMV
-instead. This optimization is `local_gemm_to_gemv`.
+Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
 
 """
 
-import copy
 import logging
 import os
 import time
@@ -140,38 +90,20 @@
 from typing import Tuple
 
 import pytensor.scalar
-from pytensor.compile.mode import optdb
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, view_roots
-from pytensor.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
 from pytensor.graph.op import Op
-from pytensor.graph.rewriting.basic import (
-    EquilibriumGraphRewriter,
-    GraphRewriter,
-    copy_stack_trace,
-    in2out,
-    node_rewriter,
-)
-from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
-from pytensor.printing import FunctionPrinter, debugprint, pprint
+from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as at
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, add, mul, neg, sub
-from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.math import add, mul, neg, sub
 from pytensor.tensor.shape import specify_broadcastable
-from pytensor.tensor.type import (
-    DenseTensorType,
-    TensorType,
-    integer_dtypes,
-    tensor,
-    values_eq_approx_remove_inf_nan,
-)
+from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
 from pytensor.utils import memoize
 
 
@@ -1512,150 +1444,6 @@ def _gemm_from_node2(fgraph, node):
     return None, t1 - t0, 0, 0
 
 
-class GemmOptimizer(GraphRewriter):
-    """Graph optimizer for inserting Gemm operations."""
-
-    def __init__(self):
-        super().__init__()
-        self.warned = False
-
-    def add_requirements(self, fgraph):
-        fgraph.attach_feature(ReplaceValidate())
-
-    def apply(self, fgraph):
-        did_something = True
-        nb_iter = 0
-        nb_replacement = 0
-        nb_replacement_didn_t_remove = 0
-        nb_inconsistency_make = 0
-        nb_inconsistency_replace = 0
-        time_canonicalize = 0
-        time_factor_can = 0
-        time_factor_list = 0
-        time_toposort = 0
-        if fgraph.profile:
-            validate_before = fgraph.profile.validate_time
-            callbacks_before = fgraph.execute_callbacks_times.copy()
-            callback_before = fgraph.execute_callbacks_time
-
-        def on_import(new_node):
-            if new_node is not node:
-                nodelist.append(new_node)
-
-        u = pytensor.graph.rewriting.basic.DispatchingFeature(
-            on_import, None, None, name="GemmOptimizer"
-        )
-        fgraph.attach_feature(u)
-        while did_something:
-            nb_iter += 1
-            t0 = time.perf_counter()
-            nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs)
-            time_toposort += time.perf_counter() - t0
-            did_something = False
-            nodelist.reverse()
-            for node in nodelist:
-                if not (
-                    isinstance(node.op, Elemwise)
-                    and isinstance(
-                        node.op.scalar_op,
-                        (
-                            pytensor.scalar.Add,
-                            pytensor.scalar.Sub,
-                            pytensor.scalar.Neg,
-                            pytensor.scalar.Mul,
-                        ),
-                    )
-                ):
-                    continue
-                if node not in fgraph.apply_nodes:
-                    # This mean that we already removed this node from
-                    # the graph
-                    continue
-                try:
-                    new_outputs, time1, time2, time3 = _gemm_from_node2(fgraph, node)
-                    time_canonicalize += time1
-                    time_factor_can += time2
-                    time_factor_list += time3
-                except InconsistencyError:
-                    nb_inconsistency_make += 1
-                    continue
-                if new_outputs:
-                    new_outputs, old_dot22 = new_outputs
-                    assert len(new_outputs) == len(node.outputs)
-                    new_outputs[
-                        0
-                    ].tag.values_eq_approx = values_eq_approx_remove_inf_nan
-                    try:
-                        fgraph.replace_all_validate_remove(
-                            list(zip(node.outputs, new_outputs)),
-                            [old_dot22],
-                            reason="GemmOptimizer",
-                            # For now we disable the warning as we know case
-                            # that we need to fix.
-                            warn=False,  # warn=not self.warned
-                        )
-                        did_something = True
-                        nb_replacement += 1
-                    except InconsistencyError:
-                        # TODO: retry other applications of gemm (see comment
-                        # in _gemm_from_node)
-                        nb_inconsistency_replace += 1
-                    except ReplacementDidNotRemoveError:
-                        nb_replacement_didn_t_remove += 1
-                        self.warned = True
-        fgraph.remove_feature(u)
-        if fgraph.profile:
-            validate_time = fgraph.profile.validate_time - validate_before
-            callback_time = fgraph.execute_callbacks_time - callback_before
-            callbacks_time = {}
-            for k, v in fgraph.execute_callbacks_times.items():
-                if k in callbacks_before:
-                    callbacks_time[k] = v - callbacks_before[k]
-                else:
-                    callbacks_time[k] = v
-        else:
-            validate_time = None
-            callback_time = None
-            callbacks_time = {}
-
-        return (
-            self,
-            nb_iter,
-            nb_replacement,
-            nb_replacement_didn_t_remove,
-            nb_inconsistency_make,
-            nb_inconsistency_replace,
-            time_canonicalize,
-            time_factor_can,
-            time_factor_list,
-            time_toposort,
-            validate_time,
-            callback_time,
-            callbacks_time,
-        )
-
-    @classmethod
-    def print_profile(cls, stream, prof, level=0):
-        blanc = "    " * level
-        print(blanc, cls.__name__, file=stream)
-        print(blanc, " nb_iter", prof[1], file=stream)
-        print(blanc, " nb_replacement", prof[2], file=stream)
-        print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream)
-        print(blanc, " nb_inconsistency_make", prof[4], file=stream)
-        print(blanc, " nb_inconsistency_replace", prof[5], file=stream)
-        print(blanc, " time_canonicalize", prof[6], file=stream)
-        print(blanc, " time_factor_can", prof[7], file=stream)
-        print(blanc, " time_factor_list", prof[8], file=stream)
-        print(blanc, " time_toposort", prof[9], file=stream)
-        print(blanc, " validate_time", prof[10], file=stream)
-        print(blanc, " callback_time", prof[11], file=stream)
-        if prof[11] > 1:
-            print(blanc, " callbacks_time", file=stream)
-            for i in sorted(prof[12].items(), key=lambda a: a[1]):
-                if i[1] > 0:
-                    print(i)
-
-
 class Dot22(GemmRelated):
     """Compute a matrix-matrix product.
 
@@ -1750,207 +1538,6 @@ def c_code_cache_version(self):
 _dot22 = Dot22()
 
 
-@node_rewriter([Dot])
-def local_dot_to_dot22(fgraph, node):
-    # This works for tensor.outer too because basic.outer is a macro that
-    # produces a dot(dimshuffle,dimshuffle) of form 4 below
-    if not isinstance(node.op, Dot):
-        return
-
-    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
-        return False
-
-    x, y = node.inputs
-    if y.type.dtype != x.type.dtype:
-        # TODO: upcast one so the types match
-        _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
-        return
-
-    if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
-        if x.ndim == 2 and y.ndim == 2:
-            new_out = [_dot22(*node.inputs)]
-        elif x.ndim == 2 and y.ndim == 1:
-            new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
-        elif x.ndim == 1 and y.ndim == 2:
-            new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
-        elif x.ndim == 1 and y.ndim == 1:
-            new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
-        else:
-            return
-        copy_stack_trace(node.outputs, new_out)
-        return new_out
-
-    _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
-
-
-@node_rewriter([gemm_no_inplace], inplace=True)
-def local_inplace_gemm(fgraph, node):
-    if node.op == gemm_no_inplace:
-        new_out = [gemm_inplace(*node.inputs)]
-        copy_stack_trace(node.outputs, new_out)
-        return new_out
-
-
-@node_rewriter([gemv_no_inplace], inplace=True)
-def local_inplace_gemv(fgraph, node):
-    if node.op == gemv_no_inplace:
-        new_out = [gemv_inplace(*node.inputs)]
-        copy_stack_trace(node.outputs, new_out)
-        return new_out
-
-
-@node_rewriter([ger], inplace=True)
-def local_inplace_ger(fgraph, node):
-    if node.op == ger:
-        new_out = [ger_destructive(*node.inputs)]
-        copy_stack_trace(node.outputs, new_out)
-        return new_out
-
-
-@node_rewriter([gemm_no_inplace])
-def local_gemm_to_gemv(fgraph, node):
-    """GEMM acting on row or column matrices -> GEMV."""
-    if node.op == gemm_no_inplace:
-        z, a, x, y, b = node.inputs
-        if z.broadcastable == x.broadcastable == (True, False):
-            r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
-            new_out = [r.dimshuffle("x", 0)]
-        elif z.broadcastable == y.broadcastable == (False, True):
-            r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
-            new_out = [r.dimshuffle(0, "x")]
-        else:
-            return
-        copy_stack_trace(node.outputs, new_out)
-        return new_out
-
-
-@node_rewriter([gemm_no_inplace])
-def local_gemm_to_ger(fgraph, node):
-    """GEMM computing an outer-product -> GER."""
-    if node.op == gemm_no_inplace:
-        z, a, x, y, b = node.inputs
-        if x.broadcastable[1] and y.broadcastable[0]:
-            # x and y are both vectors so this might qualifies for a GER
-            xv = x.dimshuffle(0)
-            yv = y.dimshuffle(1)
-            try:
-                bval = at.get_underlying_scalar_constant_value(b)
-            except NotScalarConstantError:
-                # b isn't a constant, GEMM is doing useful pre-scaling
-                return
-
-            if bval == 1:  # best case a natural GER
-                rval = ger(z, a, xv, yv)
-                new_out = [rval]
-            elif bval == 0:  # GER on zeros_like should be faster than GEMM
-                zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
-                rval = ger(zeros, a, xv, yv)
-                new_out = [rval]
-            else:
-                # if bval is another constant, then z is being usefully
-                # pre-scaled and GER isn't really the right tool for the job.
-                return
-            copy_stack_trace(node.outputs, new_out)
-            return new_out
-
-
-# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
-#      working
-@node_rewriter([_dot22])
-def local_dot22_to_ger_or_gemv(fgraph, node):
-    """dot22 computing an outer-product -> GER."""
-    if node.op == _dot22:
-        x, y = node.inputs
-        xb = x.broadcastable
-        yb = y.broadcastable
-        one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
-        zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
-        if xb[1] and yb[0]:
-            # x and y are both vectors so this might qualifies for a GER
-            xv = x.dimshuffle(0)
-            yv = y.dimshuffle(1)
-            zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
-            rval = ger(zeros, one, xv, yv)
-            new_out = [rval]
-        elif xb[0] and yb[1]:
-            # x and y are both vectors so this qualifies for a sdot / ddot
-            # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22
-            xv = x.dimshuffle(1)
-            zeros = at.AllocEmpty(x.dtype)(1)
-            rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
-            new_out = [rval.dimshuffle("x", 0)]
-        elif xb[0] and not yb[0] and not yb[1]:
-            # x is vector, y is matrix so try gemv
-            xv = x.dimshuffle(1)
-            zeros = at.AllocEmpty(x.dtype)(y.shape[1])
-            rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
-            new_out = [rval.dimshuffle("x", 0)]
-        elif not xb[0] and not xb[1] and yb[1]:
-            # x is matrix, y is vector, try gemv
-            yv = y.dimshuffle(0)
-            zeros = at.AllocEmpty(x.dtype)(x.shape[0])
-            rval = gemv_no_inplace(zeros, one, x, yv, zero)
-            new_out = [rval.dimshuffle(0, "x")]
-        else:
-            return
-        copy_stack_trace(node.outputs, new_out)
-        return new_out
-
-
-#################################
-#
-# Set up the BlasOpt optimizer
-#
-#################################
-
-blas_optdb = SequenceDB()
-
-# run after numerical stability optimizations (1.5)
-optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
-# run before specialize (2.0) because specialize is basically a
-# free-for-all that makes the graph crazy.
-
-# fast_compile is needed to have GpuDot22 created.
-blas_optdb.register(
-    "local_dot_to_dot22",
-    in2out(local_dot_to_dot22),
-    "fast_run",
-    "fast_compile",
-    position=0,
-)
-blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
-blas_optdb.register(
-    "local_gemm_to_gemv",
-    EquilibriumGraphRewriter(
-        [
-            local_gemm_to_gemv,
-            local_gemm_to_ger,
-            local_dot22_to_ger_or_gemv,
-            local_dimshuffle_lift,
-        ],
-        max_use_ratio=5,
-        ignore_newtrees=False,
-    ),
-    "fast_run",
-    position=15,
-)
-
-
-# After destroyhandler(49.5) but before we try to make elemwise things
-# inplace (75)
-blas_opt_inplace = in2out(
-    local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
-)
-optdb.register(
-    "InplaceBlasOpt",
-    blas_opt_inplace,
-    "fast_run",
-    "inplace",
-    "blas_opt_inplace",
-    position=70.0,
-)
-
-
 class Dot22Scalar(GemmRelated):
     """Compute a matrix-matrix product.
 
@@ -2049,133 +1636,6 @@ def c_code_cache_version(self):
 _dot22scalar = Dot22Scalar()
 
 
-@node_rewriter([mul])
-def local_dot22_to_dot22scalar(fgraph, node):
-    """
-    Notes
-    -----
-    Previous attempts to alter this optimization to replace dot22 with
-    gemm instead of dot22scalar resulted in some Scan nodes being
-    duplicated and the ScanSaveMem optimization never running on them,
-    resulting in highly increased memory usage. Until this issue is
-    resolved, this optimization should keep using dot22scalar instead of
-    gemm.
-
-    We upcast the scalar if after the multiplication with the dot this give
-    the same type.
-
-    We execute this optimizer after the gemm optimizer. This
-    allow to give more priority to gemm that give more speed up
-    then this optimizer, but allow the gemm optimizer to ignore
-    this op.
-
-    TODO: support when we can reorder the mul to generate a
-    dot22scalar or fix the canonizer to merge them(1 mul with multiple
-    inputs)
-
-    """
-    if node.op != mul:
-        return False
-    i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
-    if not any(i_dot22):
-        return False  # no dot22
-    if i_dot22.count(True) > 1:
-        # TODO: try each of them.
-        pass
-        # return False #TODO fix
-    dot22_idx = i_dot22.index(True)
-    d = node.inputs[dot22_idx]
-    i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
-    if not any(i_scalar):
-        # Check if we can reorder the graph as this mul have a mul in inputs.
-        # We support only 1 additional level of mul.
-        # The canonizer should have merged those mul together.
-        i_mul = [
-            x.owner
-            and x.owner.op == mul
-            and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs)
-            for x in node.inputs
-        ]
-        if not any(i_mul):
-            # no scalar in input and no multiplication
-            # if their was a multiplication we couls reorder the graph
-            # by the associativity of the graph.
-            return False
-
-        mul_idx = i_mul.index(True)  # The first one should always work
-        m = node.inputs[mul_idx]
-
-        scalar_idx = -1
-        for i, x in enumerate(m.owner.inputs):
-            if _as_scalar(x, dtype=d.dtype) and (
-                pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype
-            ):
-                scalar_idx = i
-                break
-
-        if scalar_idx < 0:
-            _logger.info(
-                f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the"
-                " type of the scalar cannot be upcasted to the"
-                " matrix type"
-            )
-            return False
-        a = at.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
-        assert not a.type.ndim
-        dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
-
-        # The other inputs to the original node that were
-        # neither part of the dot22 or this mul should be
-        # factors in the returned "mul" node.
-        assert dot22_idx != mul_idx
-        other_factors = [
-            inpt for i, inpt in enumerate(node.inputs) if i not in (dot22_idx, mul_idx)
-        ]
-        other_m_inputs = [
-            inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx
-        ]
-
-        return [mul(dot, *(other_factors + other_m_inputs))]
-
-    scalar_idx = -1
-    for i, x in enumerate(node.inputs):
-        if (
-            i != dot22_idx
-            and i_scalar[i] is not None
-            and (pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype)
-        ):
-            scalar_idx = i
-            break
-    if scalar_idx < 0:
-        _logger.info(
-            f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the type "
-            "of the scalar cannot be upcasted to the matrix type"
-        )
-        return False
-    assert scalar_idx < len(node.inputs)
-    s = node.inputs[scalar_idx]
-    o = copy.copy(node.inputs)
-    o.remove(d)
-    o.remove(s)
-
-    a = at.cast(i_scalar[scalar_idx], d.type.dtype)
-    assert not a.type.ndim
-    if len(o) == 0:
-        return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
-    else:
-        return [mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
-
-
-# must happen after gemm as the gemm optimizer don't understant
-# dot22scalar and gemm give more speed up then dot22scalar
-blas_optdb.register(
-    "local_dot22_to_dot22scalar",
-    in2out(local_dot22_to_dot22scalar),
-    "fast_run",
-    position=11,
-)
-
-
 class BatchedDot(COp):
     """
     Computes the batched dot product of two variables:
@@ -2669,14 +2129,6 @@ def infer_shape(self, fgraph, node, shapes):
 _batched_dot = BatchedDot()
 
 
-# from opt import register_specialize, register_canonicalize
-# @register_specialize
-@node_rewriter([sub, add])
-def local_print_as_we_go_along(fgraph, node):
-    if node.op in (sub, add):
-        debugprint(node)
-
-
 def batched_dot(a, b):
     """Compute the batched dot product of two variables.
 
diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py
index e4e90066b0..704970b5ef 100644
--- a/pytensor/tensor/blas_c.py
+++ b/pytensor/tensor/blas_c.py
@@ -1,22 +1,12 @@
-from pytensor.configdefaults import config
-from pytensor.graph.rewriting.basic import in2out
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.scalar import bool as bool_t
-from pytensor.tensor import basic as at
 from pytensor.tensor.blas import (
     Gemv,
     Ger,
     blas_header_text,
     blas_header_version,
-    blas_optdb,
-    gemv_inplace,
-    gemv_no_inplace,
-    ger,
-    ger_destructive,
     ldflags,
-    node_rewriter,
-    optdb,
 )
 
 
@@ -344,23 +334,6 @@ def c_code_cache_version(self):
 cger_no_inplace = CGer(False)
 
 
-@node_rewriter([ger, ger_destructive])
-def use_c_ger(fgraph, node):
-    if not config.blas__ldflags:
-        return
-    # Only float32 and float64 are supported for now.
-    if node.op == ger and node.outputs[0].dtype in ("float32", "float64"):
-        return [CGer(False)(*node.inputs)]
-    if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"):
-        return [CGer(True)(*node.inputs)]
-
-
-@node_rewriter([CGer(False)])
-def make_c_ger_destructive(fgraph, node):
-    if isinstance(node.op, CGer) and not node.op.destructive:
-        return [cger_inplace(*node.inputs)]
-
-
 # ##### ####### #######
 # GEMV
 # ##### ####### #######
@@ -697,48 +670,3 @@ def check_force_gemv_init():
 
 
 check_force_gemv_init._force_init_beta = None
-
-
-@node_rewriter([gemv_inplace, gemv_no_inplace])
-def use_c_gemv(fgraph, node):
-    if not config.blas__ldflags:
-        return
-    # Only float32 and float64 are supported for now.
-    if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"):
-        return [cgemv_no_inplace(*node.inputs)]
-    if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"):
-        return [cgemv_inplace(*node.inputs)]
-
-
-@node_rewriter([CGemv(inplace=False)])
-def make_c_gemv_destructive(fgraph, node):
-    if isinstance(node.op, CGemv) and not node.op.inplace:
-        inputs = list(node.inputs)
-        dest = inputs[0]
-        if (
-            dest.owner
-            and isinstance(dest.owner.op, at.AllocEmpty)
-            and len(fgraph.clients[dest]) > 1
-        ):
-            inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs)
-
-        return [cgemv_inplace(*inputs)]
-
-
-# ##### ####### #######
-# Optimizers
-# ##### ####### #######
-
-blas_optdb.register(
-    "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
-)
-
-# this matches the InplaceBlasOpt defined in blas.py
-optdb.register(
-    "c_blas_destructive",
-    in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
-    "fast_run",
-    "inplace",
-    "c_blas",
-    position=70.0,
-)
diff --git a/pytensor/tensor/blas_scipy.py b/pytensor/tensor/blas_scipy.py
index 4d1be6e322..527d5150a1 100644
--- a/pytensor/tensor/blas_scipy.py
+++ b/pytensor/tensor/blas_scipy.py
@@ -4,16 +4,7 @@
 
 import numpy as np
 
-from pytensor.graph.rewriting.basic import in2out
-from pytensor.tensor.blas import (
-    Ger,
-    blas_optdb,
-    ger,
-    ger_destructive,
-    have_fblas,
-    node_rewriter,
-    optdb,
-)
+from pytensor.tensor.blas import Ger, have_fblas
 
 
 if have_fblas:
@@ -56,36 +47,3 @@ def perform(self, node, inputs, output_storage):
 
 scipy_ger_no_inplace = ScipyGer(False)
 scipy_ger_inplace = ScipyGer(True)
-
-
-@node_rewriter([ger, ger_destructive])
-def use_scipy_ger(fgraph, node):
-    if node.op == ger:
-        return [scipy_ger_no_inplace(*node.inputs)]
-
-
-@node_rewriter([scipy_ger_no_inplace])
-def make_ger_destructive(fgraph, node):
-    if node.op == scipy_ger_no_inplace:
-        return [scipy_ger_inplace(*node.inputs)]
-
-
-use_scipy_blas = in2out(use_scipy_ger)
-make_scipy_blas_destructive = in2out(make_ger_destructive)
-
-if have_fblas:
-    # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
-    # sucks, but it is almost always present.
-    # C implementations should be scheduled earlier than this, so that they take
-    # precedence. Once the original Ger is replaced, then these optimizations
-    # have no effect.
-    blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
-
-    # this matches the InplaceBlasOpt defined in blas.py
-    optdb.register(
-        "make_scipy_blas_destructive",
-        make_scipy_blas_destructive,
-        "fast_run",
-        "inplace",
-        position=70.0,
-    )
diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py
index cb244afb7e..80946d524c 100644
--- a/pytensor/tensor/rewriting/__init__.py
+++ b/pytensor/tensor/rewriting/__init__.py
@@ -1,9 +1,13 @@
 import pytensor.tensor.rewriting.basic
+import pytensor.tensor.rewriting.blas
+import pytensor.tensor.rewriting.blas_c
+import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops
 
 # Register JAX specializations
 import pytensor.tensor.rewriting.jax
+import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py
new file mode 100644
index 0000000000..a310cb5837
--- /dev/null
+++ b/pytensor/tensor/rewriting/blas.py
@@ -0,0 +1,907 @@
+"""optimizations for using BLAS calls
+
+Optimizations
+=============
+
+The optimization pipeline works something like this:
+
+    1. identify dot22 from dot
+    2. identify gemm from dot22
+    3. identify dot22scalar from dot22 that are not gemm
+    4. specialize gemm to gemv where applicable
+    5. specialize gemm to ger where applicable
+    6. specialize dot22 -> gemv or ger where applicable
+
+:note: GEMM is the most canonical BLAS signature that we deal with so far, it
+    would be good to turn most things into GEMM (dot, inner, outer, dot22,
+    dot22scalar), and then to specialize from gemm to the various other L2 and
+    L3 operations.
+
+Identify Dot22
+--------------
+
+Numpy's dot supports arguments that are of any rank, and we should support that
+too (just for compatibility).  The BLAS optimizations work with Dot Ops whose
+inputs are each either vector or matrix.  So the first part of the optimization
+pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be
+transformed further, but they will get implemented by a BLAS call.
+
+More precisely, Dot nodes whose inputs are all vectors or matrices and whose
+inputs both have the same dtype, and whose dtype is float or complex, become
+Dot22.  This is implemented in `local_dot_to_dot22`.
+
+
+Identify Gemm from Dot22
+------------------------
+
+This is complicated, done in GemmOptimizer.
+
+Identify Dot22Scalar from Dot22
+-------------------------------
+
+Dot22 Ops that remain after the GemmOptimizer is done have not
+qualified as GEMM Ops. Still they might be scaled by a factor, in
+which case we use Dot22Scalar which is like Gemm, but without the b
+and the Z.  In the future it would be good to merge this into the
+GemmOptimizer.
+
+Specialize Gemm to Gemv
+-----------------------
+
+If arguments to GEMM are dimshuffled vectors, then we can use GEMV
+instead. This optimization is `local_gemm_to_gemv`.
+
+"""
+
+import copy
+import logging
+import time
+
+import numpy as np
+
+
+try:
+    import numpy.__config__  # noqa
+except ImportError:
+    pass
+
+
+import pytensor.scalar
+from pytensor.compile.mode import optdb
+from pytensor.configdefaults import config
+from pytensor.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
+from pytensor.graph.rewriting.basic import (
+    EquilibriumGraphRewriter,
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
+from pytensor.graph.rewriting.db import SequenceDB
+from pytensor.graph.utils import InconsistencyError
+from pytensor.printing import debugprint
+from pytensor.tensor import basic as at
+from pytensor.tensor.blas import (
+    Dot22,
+    _dot22,
+    _dot22scalar,
+    gemm_inplace,
+    gemm_no_inplace,
+    gemv_inplace,
+    gemv_no_inplace,
+    ger,
+    ger_destructive,
+)
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.math import Dot, add, mul, neg, sub
+from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
+from pytensor.tensor.type import (
+    DenseTensorType,
+    TensorType,
+    integer_dtypes,
+    values_eq_approx_remove_inf_nan,
+)
+
+
+_logger = logging.getLogger("pytensor.tensor.rewriting.blas")
+
+
+def res_is_a(fgraph, var, op, maxclients=None):
+    if maxclients is not None and var in fgraph.clients:
+        retval = len(fgraph.get_clients(var)) <= maxclients
+    else:
+        retval = True
+
+    return var.owner and var.owner.op == op and retval
+
+
+def _as_scalar(res, dtype=None):
+    """Return ``None`` or a `TensorVariable` of float type"""
+    if dtype is None:
+        dtype = config.floatX
+    if all(s == 1 for s in res.type.shape):
+        while res.owner and isinstance(res.owner.op, DimShuffle):
+            res = res.owner.inputs[0]
+        # may still have some number of True's
+        if res.type.ndim > 0:
+            rval = res.dimshuffle()
+        else:
+            rval = res
+        if rval.type.dtype in integer_dtypes:
+            # We check that the upcast of res and dtype won't change dtype.
+            # If dtype is float64, we will cast int64 to float64.
+            # This is valid when res is a scalar used as input to a dot22
+            # as the cast of the scalar can be done before or after the dot22
+            # and this will give the same result.
+            if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
+                return at.cast(rval, dtype)
+            else:
+                return None
+
+        return rval
+
+
+def _is_real_matrix(res):
+    return (
+        res.type.dtype in ("float16", "float32", "float64")
+        and res.type.ndim == 2
+        and res.type.shape[0] != 1
+        and res.type.shape[1] != 1
+    )  # cope with tuple vs. list
+
+
+def _is_real_vector(res):
+    return (
+        res.type.dtype in ("float16", "float32", "float64")
+        and res.type.ndim == 1
+        and res.type.shape[0] != 1
+    )
+
+
+def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
+    # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
+    # EXPRESSION: (beta * L) + (alpha * M)
+
+    # we've already checked the client counts, now just make the type check.
+    # if res_is_a(M, _dot22, 1):
+    if M.owner and M.owner.op == _dot22:
+        Ml, Mr = M.owner.inputs
+        rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
+        return rval, M
+
+    # it also might be the case that there is a dimshuffle between the +
+    # and the dot22. local_dot_to_dot22 in particular will put in such things.
+    if (
+        M.owner
+        and isinstance(M.owner.op, DimShuffle)
+        and M.owner.inputs[0].owner
+        and isinstance(M.owner.inputs[0].owner.op, Dot22)
+    ):
+        MM = M.owner.inputs[0]
+        if M.owner.op.new_order == (0,):
+            # it is making a column MM into a vector
+            MMl, MMr = MM.owner.inputs
+            g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
+            rval = [g.dimshuffle(0)]
+            return rval, MM
+        if M.owner.op.new_order == (1,):
+            # it is making a row MM into a vector
+            MMl, MMr = MM.owner.inputs
+            g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
+            rval = [g.dimshuffle(1)]
+            return rval, MM
+        if len(M.owner.op.new_order) == 0:
+            # it is making a row MM into a vector
+            MMl, MMr = MM.owner.inputs
+            g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
+            rval = [g.dimshuffle()]
+            return rval, MM
+
+    if recurse_flip:
+        return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
+    else:
+        return False, False
+
+
+def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
+    # Tries to interpret node as a sum of scalars * (vectors or matrices)
+    def scaled(thing):
+        if scale == 1:
+            return thing
+        if scale == -1 and thing.type.dtype != "bool":
+            return -thing
+        else:
+            return scale * thing
+
+    if not isinstance(r.type, TensorType):
+        return None
+
+    if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
+        "float16",
+        "float32",
+        "float64",
+        "complex64",
+        "complex128",
+    ):
+        rval.append(scaled(r))
+        return rval
+
+    if maxclients and len(fgraph.clients[r]) > maxclients:
+        rval.append((scale, r))
+        return rval
+
+    if r.owner and r.owner.op == sub:
+        _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
+        _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
+
+    elif r.owner and r.owner.op == add:
+        for i in r.owner.inputs:
+            _gemm_canonicalize(fgraph, i, scale, rval, 1)
+
+    elif r.owner and r.owner.op == neg:
+        _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
+
+    elif r.owner and r.owner.op == mul:
+        scalars = []
+        vectors = []
+        matrices = []
+        for i in r.owner.inputs:
+            if all(s == 1 for s in i.type.shape):
+                while i.owner and isinstance(i.owner.op, DimShuffle):
+                    i = i.owner.inputs[0]
+                if i.type.ndim > 0:
+                    scalars.append(i.dimshuffle())
+                else:
+                    scalars.append(i)
+            elif _is_real_vector(i):
+                vectors.append(i)
+            elif _is_real_matrix(i):
+                matrices.append(i)
+            else:
+                # just put the original arguments as in the base case
+                rval.append((scale, r))
+                return rval
+        if len(matrices) == 1:
+            assert len(vectors) == 0
+            m = matrices[0]
+            if len(scalars) == 0:
+                _gemm_canonicalize(fgraph, m, scale, rval, 1)
+            elif len(scalars) == 1:
+                _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1)
+            else:
+                _gemm_canonicalize(
+                    fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
+                )
+        elif len(vectors) == 1:
+            assert len(matrices) == 0
+            v = vectors[0]
+            if len(scalars) == 0:
+                _gemm_canonicalize(fgraph, v, scale, rval, 1)
+            elif len(scalars) == 1:
+                _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1)
+            else:
+                _gemm_canonicalize(
+                    fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
+                )
+        else:  # lets not open this up
+            rval.append((scale, r))
+    else:
+        rval.append((scale, r))
+    return rval
+
+
+def _factor_canonicalized(lst):
+    # remove duplicates from canonicalized list
+
+    # we only delete out of the right end of the list,
+    # once i has touched a list element, it is permantent
+    lst = list(lst)
+    # print 'FACTOR', lst
+    # for t in lst:
+    #    if not isinstance(t, (list, tuple)):
+    #        t = (t,)
+    #    for e in t:
+    #        try:
+    #            pytensor.printing.debugprint(e)
+    #        except TypeError:
+    #            print e, type(e)
+    i = 0
+    while i < len(lst) - 1:
+        try:
+            s_i, M_i = lst[i]
+        except Exception:
+            i += 1
+            continue
+
+        j = i + 1
+        while j < len(lst):
+            try:
+                s_j, M_j = lst[j]
+            except Exception:
+                j += 1
+                continue
+
+            if M_i is M_j:
+                s_i = s_i + s_j
+                lst[i] = (s_i, M_i)
+                del lst[j]
+            else:
+                j += 1
+        i += 1
+    return lst
+
+
+def _gemm_from_factored_list(fgraph, lst):
+    """
+    Returns None, or a list to replace node.outputs.
+
+    """
+    lst2 = []
+    # Remove the tuple that can't be cast correctly.
+    # This can happen when we try to cast a complex to a real
+    for sM in lst:
+        # Make every pair in list have matching dtypes
+        # sM can be a tuple of 2 elements or an PyTensor variable.
+        if isinstance(sM, tuple):
+            sm0, sm1 = sM
+            sm0 = at.as_tensor_variable(sm0)
+            if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
+                lst2.append((at.cast(sm0, sm1.dtype), sM[1]))
+
+    lst = lst2
+
+    def item_to_var(t):
+        try:
+            s, M = t
+        except Exception:
+            return t
+        if s == 1:
+            return M
+        if s == -1:
+            return -M
+        return s * M
+
+    # Try every pair in the sM_list, trying to turn it into a gemm operation
+    for i in range(len(lst) - 1):
+        s_i, M_i = lst[i]
+
+        for j in range(i + 1, len(lst)):
+            s_j, M_j = lst[j]
+
+            if not M_j.type.in_same_class(M_i.type):
+                continue
+
+            # print 'TRYING', (s_i, M_i, s_j, M_j)
+
+            gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
+                fgraph, s_i, M_i, s_j, M_j
+            )
+            # print 'GOT IT', gemm_of_sM_list
+            if gemm_of_sM_list:
+                assert len(gemm_of_sM_list) == 1
+                add_inputs = [
+                    item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
+                ]
+                add_inputs.extend(gemm_of_sM_list)
+                if len(add_inputs) > 1:
+                    rval = [add(*add_inputs)]
+                else:
+                    rval = add_inputs
+                # print "RETURNING GEMM THING", rval
+                return rval, old_dot22
+
+
+def _gemm_from_node2(fgraph, node):
+    """
+
+    TODO: In many expressions, there are many ways to turn it into a
+    gemm.  For example dot(a,b) + c + d.  This function should return all
+    of them, so that if one version of gemm causes a cycle in the graph, then
+    another application of gemm can be tried.
+
+    """
+    lst = []
+    t0 = time.perf_counter()
+    _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0)
+    t1 = time.perf_counter()
+
+    if len(lst) > 1:
+        lst = _factor_canonicalized(lst)
+        t2 = time.perf_counter()
+        rval = _gemm_from_factored_list(fgraph, lst)
+        t3 = time.perf_counter()
+
+        # It can happen that _factor_canonicalized and
+        # _gemm_from_factored_list return a node with an incorrect
+        # type.  This happens in particular when one of the scalar
+        # factors forces the upcast of the whole expression.  In that
+        # case, we simply skip that candidate for Gemm.  This was
+        # discussed in
+        # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
+        # but never made it into a trac ticket.
+
+        if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
+            return rval, t1 - t0, t2 - t1, t3 - t2
+
+    return None, t1 - t0, 0, 0
+
+
+class GemmOptimizer(GraphRewriter):
+    """Graph optimizer for inserting Gemm operations."""
+
+    def __init__(self):
+        super().__init__()
+        self.warned = False
+
+    def add_requirements(self, fgraph):
+        fgraph.attach_feature(ReplaceValidate())
+
+    def apply(self, fgraph):
+        did_something = True
+        nb_iter = 0
+        nb_replacement = 0
+        nb_replacement_didn_t_remove = 0
+        nb_inconsistency_make = 0
+        nb_inconsistency_replace = 0
+        time_canonicalize = 0
+        time_factor_can = 0
+        time_factor_list = 0
+        time_toposort = 0
+        if fgraph.profile:
+            validate_before = fgraph.profile.validate_time
+            callbacks_before = fgraph.execute_callbacks_times.copy()
+            callback_before = fgraph.execute_callbacks_time
+
+        def on_import(new_node):
+            if new_node is not node:
+                nodelist.append(new_node)
+
+        u = pytensor.graph.rewriting.basic.DispatchingFeature(
+            on_import, None, None, name="GemmOptimizer"
+        )
+        fgraph.attach_feature(u)
+        while did_something:
+            nb_iter += 1
+            t0 = time.perf_counter()
+            nodelist = pytensor.graph.basic.io_toposort(fgraph.inputs, fgraph.outputs)
+            time_toposort += time.perf_counter() - t0
+            did_something = False
+            nodelist.reverse()
+            for node in nodelist:
+                if not (
+                    isinstance(node.op, Elemwise)
+                    and isinstance(
+                        node.op.scalar_op,
+                        (
+                            pytensor.scalar.Add,
+                            pytensor.scalar.Sub,
+                            pytensor.scalar.Neg,
+                            pytensor.scalar.Mul,
+                        ),
+                    )
+                ):
+                    continue
+                if node not in fgraph.apply_nodes:
+                    # This mean that we already removed this node from
+                    # the graph
+                    continue
+                try:
+                    new_outputs, time1, time2, time3 = _gemm_from_node2(fgraph, node)
+                    time_canonicalize += time1
+                    time_factor_can += time2
+                    time_factor_list += time3
+                except InconsistencyError:
+                    nb_inconsistency_make += 1
+                    continue
+                if new_outputs:
+                    new_outputs, old_dot22 = new_outputs
+                    assert len(new_outputs) == len(node.outputs)
+                    new_outputs[
+                        0
+                    ].tag.values_eq_approx = values_eq_approx_remove_inf_nan
+                    try:
+                        fgraph.replace_all_validate_remove(
+                            list(zip(node.outputs, new_outputs)),
+                            [old_dot22],
+                            reason="GemmOptimizer",
+                            # For now we disable the warning as we know case
+                            # that we need to fix.
+                            warn=False,  # warn=not self.warned
+                        )
+                        did_something = True
+                        nb_replacement += 1
+                    except InconsistencyError:
+                        # TODO: retry other applications of gemm (see comment
+                        # in _gemm_from_node)
+                        nb_inconsistency_replace += 1
+                    except ReplacementDidNotRemoveError:
+                        nb_replacement_didn_t_remove += 1
+                        self.warned = True
+        fgraph.remove_feature(u)
+        if fgraph.profile:
+            validate_time = fgraph.profile.validate_time - validate_before
+            callback_time = fgraph.execute_callbacks_time - callback_before
+            callbacks_time = {}
+            for k, v in fgraph.execute_callbacks_times.items():
+                if k in callbacks_before:
+                    callbacks_time[k] = v - callbacks_before[k]
+                else:
+                    callbacks_time[k] = v
+        else:
+            validate_time = None
+            callback_time = None
+            callbacks_time = {}
+
+        return (
+            self,
+            nb_iter,
+            nb_replacement,
+            nb_replacement_didn_t_remove,
+            nb_inconsistency_make,
+            nb_inconsistency_replace,
+            time_canonicalize,
+            time_factor_can,
+            time_factor_list,
+            time_toposort,
+            validate_time,
+            callback_time,
+            callbacks_time,
+        )
+
+    @classmethod
+    def print_profile(cls, stream, prof, level=0):
+        blanc = "    " * level
+        print(blanc, cls.__name__, file=stream)
+        print(blanc, " nb_iter", prof[1], file=stream)
+        print(blanc, " nb_replacement", prof[2], file=stream)
+        print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream)
+        print(blanc, " nb_inconsistency_make", prof[4], file=stream)
+        print(blanc, " nb_inconsistency_replace", prof[5], file=stream)
+        print(blanc, " time_canonicalize", prof[6], file=stream)
+        print(blanc, " time_factor_can", prof[7], file=stream)
+        print(blanc, " time_factor_list", prof[8], file=stream)
+        print(blanc, " time_toposort", prof[9], file=stream)
+        print(blanc, " validate_time", prof[10], file=stream)
+        print(blanc, " callback_time", prof[11], file=stream)
+        if prof[11] > 1:
+            print(blanc, " callbacks_time", file=stream)
+            for i in sorted(prof[12].items(), key=lambda a: a[1]):
+                if i[1] > 0:
+                    print(i)
+
+
+@node_rewriter([Dot])
+def local_dot_to_dot22(fgraph, node):
+    # This works for tensor.outer too because basic.outer is a macro that
+    # produces a dot(dimshuffle,dimshuffle) of form 4 below
+    if not isinstance(node.op, Dot):
+        return
+
+    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
+        return False
+
+    x, y = node.inputs
+    if y.type.dtype != x.type.dtype:
+        # TODO: upcast one so the types match
+        _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
+        return
+
+    if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
+        if x.ndim == 2 and y.ndim == 2:
+            new_out = [_dot22(*node.inputs)]
+        elif x.ndim == 2 and y.ndim == 1:
+            new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
+        elif x.ndim == 1 and y.ndim == 2:
+            new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
+        elif x.ndim == 1 and y.ndim == 1:
+            new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
+        else:
+            return
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+    _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
+
+
+@node_rewriter([gemm_no_inplace], inplace=True)
+def local_inplace_gemm(fgraph, node):
+    if node.op == gemm_no_inplace:
+        new_out = [gemm_inplace(*node.inputs)]
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+@node_rewriter([gemv_no_inplace], inplace=True)
+def local_inplace_gemv(fgraph, node):
+    if node.op == gemv_no_inplace:
+        new_out = [gemv_inplace(*node.inputs)]
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+@node_rewriter([ger], inplace=True)
+def local_inplace_ger(fgraph, node):
+    if node.op == ger:
+        new_out = [ger_destructive(*node.inputs)]
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+@node_rewriter([gemm_no_inplace])
+def local_gemm_to_gemv(fgraph, node):
+    """GEMM acting on row or column matrices -> GEMV."""
+    if node.op == gemm_no_inplace:
+        z, a, x, y, b = node.inputs
+        if z.broadcastable == x.broadcastable == (True, False):
+            r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
+            new_out = [r.dimshuffle("x", 0)]
+        elif z.broadcastable == y.broadcastable == (False, True):
+            r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
+            new_out = [r.dimshuffle(0, "x")]
+        else:
+            return
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+@node_rewriter([gemm_no_inplace])
+def local_gemm_to_ger(fgraph, node):
+    """GEMM computing an outer-product -> GER."""
+    if node.op == gemm_no_inplace:
+        z, a, x, y, b = node.inputs
+        if x.broadcastable[1] and y.broadcastable[0]:
+            # x and y are both vectors so this might qualifies for a GER
+            xv = x.dimshuffle(0)
+            yv = y.dimshuffle(1)
+            try:
+                bval = at.get_underlying_scalar_constant_value(b)
+            except NotScalarConstantError:
+                # b isn't a constant, GEMM is doing useful pre-scaling
+                return
+
+            if bval == 1:  # best case a natural GER
+                rval = ger(z, a, xv, yv)
+                new_out = [rval]
+            elif bval == 0:  # GER on zeros_like should be faster than GEMM
+                zeros = at.zeros([x.shape[0], y.shape[1]], x.dtype)
+                rval = ger(zeros, a, xv, yv)
+                new_out = [rval]
+            else:
+                # if bval is another constant, then z is being usefully
+                # pre-scaled and GER isn't really the right tool for the job.
+                return
+            copy_stack_trace(node.outputs, new_out)
+            return new_out
+
+
+# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
+#      working
+@node_rewriter([_dot22])
+def local_dot22_to_ger_or_gemv(fgraph, node):
+    """dot22 computing an outer-product -> GER."""
+    if node.op == _dot22:
+        x, y = node.inputs
+        xb = x.broadcastable
+        yb = y.broadcastable
+        one = at.as_tensor_variable(np.asarray(1, dtype=x.dtype))
+        zero = at.as_tensor_variable(np.asarray(0, dtype=x.dtype))
+        if xb[1] and yb[0]:
+            # x and y are both vectors so this might qualifies for a GER
+            xv = x.dimshuffle(0)
+            yv = y.dimshuffle(1)
+            zeros = at.zeros([x.shape[0], y.shape[1]], dtype=x.dtype)
+            rval = ger(zeros, one, xv, yv)
+            new_out = [rval]
+        elif xb[0] and yb[1]:
+            # x and y are both vectors so this qualifies for a sdot / ddot
+            # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22
+            xv = x.dimshuffle(1)
+            zeros = at.AllocEmpty(x.dtype)(1)
+            rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
+            new_out = [rval.dimshuffle("x", 0)]
+        elif xb[0] and not yb[0] and not yb[1]:
+            # x is vector, y is matrix so try gemv
+            xv = x.dimshuffle(1)
+            zeros = at.AllocEmpty(x.dtype)(y.shape[1])
+            rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
+            new_out = [rval.dimshuffle("x", 0)]
+        elif not xb[0] and not xb[1] and yb[1]:
+            # x is matrix, y is vector, try gemv
+            yv = y.dimshuffle(0)
+            zeros = at.AllocEmpty(x.dtype)(x.shape[0])
+            rval = gemv_no_inplace(zeros, one, x, yv, zero)
+            new_out = [rval.dimshuffle(0, "x")]
+        else:
+            return
+        copy_stack_trace(node.outputs, new_out)
+        return new_out
+
+
+#################################
+#
+# Set up the BlasOpt optimizer
+#
+#################################
+
+blas_optdb = SequenceDB()
+
+# run after numerical stability optimizations (1.5)
+optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
+# run before specialize (2.0) because specialize is basically a
+# free-for-all that makes the graph crazy.
+
+# fast_compile is needed to have GpuDot22 created.
+blas_optdb.register(
+    "local_dot_to_dot22",
+    in2out(local_dot_to_dot22),
+    "fast_run",
+    "fast_compile",
+    position=0,
+)
+blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
+blas_optdb.register(
+    "local_gemm_to_gemv",
+    EquilibriumGraphRewriter(
+        [
+            local_gemm_to_gemv,
+            local_gemm_to_ger,
+            local_dot22_to_ger_or_gemv,
+            local_dimshuffle_lift,
+        ],
+        max_use_ratio=5,
+        ignore_newtrees=False,
+    ),
+    "fast_run",
+    position=15,
+)
+
+
+# After destroyhandler(49.5) but before we try to make elemwise things
+# inplace (75)
+blas_opt_inplace = in2out(
+    local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
+)
+optdb.register(
+    "InplaceBlasOpt",
+    blas_opt_inplace,
+    "fast_run",
+    "inplace",
+    "blas_opt_inplace",
+    position=70.0,
+)
+
+
+@node_rewriter([mul])
+def local_dot22_to_dot22scalar(fgraph, node):
+    """
+    Notes
+    -----
+    Previous attempts to alter this optimization to replace dot22 with
+    gemm instead of dot22scalar resulted in some Scan nodes being
+    duplicated and the ScanSaveMem optimization never running on them,
+    resulting in highly increased memory usage. Until this issue is
+    resolved, this optimization should keep using dot22scalar instead of
+    gemm.
+
+    We upcast the scalar if after the multiplication with the dot this give
+    the same type.
+
+    We execute this optimizer after the gemm optimizer. This
+    allow to give more priority to gemm that give more speed up
+    then this optimizer, but allow the gemm optimizer to ignore
+    this op.
+
+    TODO: support when we can reorder the mul to generate a
+    dot22scalar or fix the canonizer to merge them(1 mul with multiple
+    inputs)
+
+    """
+    if node.op != mul:
+        return False
+    i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs]
+    if not any(i_dot22):
+        return False  # no dot22
+    if i_dot22.count(True) > 1:
+        # TODO: try each of them.
+        pass
+        # return False #TODO fix
+    dot22_idx = i_dot22.index(True)
+    d = node.inputs[dot22_idx]
+    i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
+    if not any(i_scalar):
+        # Check if we can reorder the graph as this mul have a mul in inputs.
+        # We support only 1 additional level of mul.
+        # The canonizer should have merged those mul together.
+        i_mul = [
+            x.owner
+            and x.owner.op == mul
+            and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs)
+            for x in node.inputs
+        ]
+        if not any(i_mul):
+            # no scalar in input and no multiplication
+            # if their was a multiplication we couls reorder the graph
+            # by the associativity of the graph.
+            return False
+
+        mul_idx = i_mul.index(True)  # The first one should always work
+        m = node.inputs[mul_idx]
+
+        scalar_idx = -1
+        for i, x in enumerate(m.owner.inputs):
+            if _as_scalar(x, dtype=d.dtype) and (
+                pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype
+            ):
+                scalar_idx = i
+                break
+
+        if scalar_idx < 0:
+            _logger.info(
+                f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the"
+                " type of the scalar cannot be upcasted to the"
+                " matrix type"
+            )
+            return False
+        a = at.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
+        assert not a.type.ndim
+        dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
+
+        # The other inputs to the original node that were
+        # neither part of the dot22 or this mul should be
+        # factors in the returned "mul" node.
+        assert dot22_idx != mul_idx
+        other_factors = [
+            inpt for i, inpt in enumerate(node.inputs) if i not in (dot22_idx, mul_idx)
+        ]
+        other_m_inputs = [
+            inpt for i, inpt in enumerate(m.owner.inputs) if i != scalar_idx
+        ]
+
+        return [mul(dot, *(other_factors + other_m_inputs))]
+
+    scalar_idx = -1
+    for i, x in enumerate(node.inputs):
+        if (
+            i != dot22_idx
+            and i_scalar[i] is not None
+            and (pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype)
+        ):
+            scalar_idx = i
+            break
+    if scalar_idx < 0:
+        _logger.info(
+            f"Not optimizing dot22 with inputs {node.inputs} {[x.type for x in node.inputs]}, as the type "
+            "of the scalar cannot be upcasted to the matrix type"
+        )
+        return False
+    assert scalar_idx < len(node.inputs)
+    s = node.inputs[scalar_idx]
+    o = copy.copy(node.inputs)
+    o.remove(d)
+    o.remove(s)
+
+    a = at.cast(i_scalar[scalar_idx], d.type.dtype)
+    assert not a.type.ndim
+    if len(o) == 0:
+        return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
+    else:
+        return [mul(_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a), *o)]
+
+
+# must happen after gemm as the gemm optimizer don't understant
+# dot22scalar and gemm give more speed up then dot22scalar
+blas_optdb.register(
+    "local_dot22_to_dot22scalar",
+    in2out(local_dot22_to_dot22scalar),
+    "fast_run",
+    position=11,
+)
+
+
+# from opt import register_specialize, register_canonicalize
+# @register_specialize
+@node_rewriter([sub, add])
+def local_print_as_we_go_along(fgraph, node):
+    if node.op in (sub, add):
+        debugprint(node)
diff --git a/pytensor/tensor/rewriting/blas_c.py b/pytensor/tensor/rewriting/blas_c.py
new file mode 100644
index 0000000000..77629dccca
--- /dev/null
+++ b/pytensor/tensor/rewriting/blas_c.py
@@ -0,0 +1,70 @@
+from pytensor.configdefaults import config
+from pytensor.graph.rewriting.basic import in2out
+from pytensor.tensor import basic as at
+from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive
+from pytensor.tensor.blas_c import (
+    CGemv,
+    CGer,
+    cgemv_inplace,
+    cgemv_no_inplace,
+    cger_inplace,
+)
+from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
+
+
+@node_rewriter([ger, ger_destructive])
+def use_c_ger(fgraph, node):
+    if not config.blas__ldflags:
+        return
+    # Only float32 and float64 are supported for now.
+    if node.op == ger and node.outputs[0].dtype in ("float32", "float64"):
+        return [CGer(False)(*node.inputs)]
+    if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"):
+        return [CGer(True)(*node.inputs)]
+
+
+@node_rewriter([CGer(False)])
+def make_c_ger_destructive(fgraph, node):
+    if isinstance(node.op, CGer) and not node.op.destructive:
+        return [cger_inplace(*node.inputs)]
+
+
+@node_rewriter([gemv_inplace, gemv_no_inplace])
+def use_c_gemv(fgraph, node):
+    if not config.blas__ldflags:
+        return
+    # Only float32 and float64 are supported for now.
+    if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"):
+        return [cgemv_no_inplace(*node.inputs)]
+    if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"):
+        return [cgemv_inplace(*node.inputs)]
+
+
+@node_rewriter([CGemv(inplace=False)])
+def make_c_gemv_destructive(fgraph, node):
+    if isinstance(node.op, CGemv) and not node.op.inplace:
+        inputs = list(node.inputs)
+        dest = inputs[0]
+        if (
+            dest.owner
+            and isinstance(dest.owner.op, at.AllocEmpty)
+            and len(fgraph.clients[dest]) > 1
+        ):
+            inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs)
+
+        return [cgemv_inplace(*inputs)]
+
+
+blas_optdb.register(
+    "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
+)
+
+# this matches the InplaceBlasOpt defined in blas.py
+optdb.register(
+    "c_blas_destructive",
+    in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
+    "fast_run",
+    "inplace",
+    "c_blas",
+    position=70.0,
+)
diff --git a/pytensor/tensor/rewriting/blas_scipy.py b/pytensor/tensor/rewriting/blas_scipy.py
new file mode 100644
index 0000000000..2b2aa94eef
--- /dev/null
+++ b/pytensor/tensor/rewriting/blas_scipy.py
@@ -0,0 +1,37 @@
+from pytensor.graph.rewriting.basic import in2out
+from pytensor.tensor.blas import ger, ger_destructive, have_fblas
+from pytensor.tensor.blas_scipy import scipy_ger_inplace, scipy_ger_no_inplace
+from pytensor.tensor.rewriting.blas import blas_optdb, node_rewriter, optdb
+
+
+@node_rewriter([ger, ger_destructive])
+def use_scipy_ger(fgraph, node):
+    if node.op == ger:
+        return [scipy_ger_no_inplace(*node.inputs)]
+
+
+@node_rewriter([scipy_ger_no_inplace])
+def make_ger_destructive(fgraph, node):
+    if node.op == scipy_ger_no_inplace:
+        return [scipy_ger_inplace(*node.inputs)]
+
+
+use_scipy_blas = in2out(use_scipy_ger)
+make_scipy_blas_destructive = in2out(make_ger_destructive)
+
+if have_fblas:
+    # scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
+    # sucks, but it is almost always present.
+    # C implementations should be scheduled earlier than this, so that they take
+    # precedence. Once the original Ger is replaced, then these optimizations
+    # have no effect.
+    blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
+
+    # this matches the InplaceBlasOpt defined in blas.py
+    optdb.register(
+        "make_scipy_blas_destructive",
+        make_scipy_blas_destructive,
+        "fast_run",
+        "inplace",
+        position=70.0,
+    )
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index afc51a9e3c..6bf4b5b902 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -349,7 +349,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
 
 
 inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
-compile.optdb.register(  # type: ignore
+compile.optdb.register(
     "inplace_elemwise_opt",
     inplace_elemwise_optimizer,
     "inplace_opt",  # for historic reason
@@ -1097,7 +1097,7 @@ def print_profile(stream, prof, level=0):
         "fusion",
         position=1,
     )
-    compile.optdb.register(  # type: ignore
+    compile.optdb.register(
         "elemwise_fusion",
         fuse_seqopt,
         "fast_run",
@@ -1211,7 +1211,7 @@ def local_careduce_fusion(fgraph, node):
     return [new_car_op(*elm_inputs)]
 
 
-compile.optdb.register(  # type: ignore
+compile.optdb.register(
     "local_careduce_fusion",
     in2out(local_careduce_fusion),
     "fusion",
@@ -1321,7 +1321,7 @@ def split_2f1grad_loop(fgraph, node):
     return replacements
 
 
-compile.optdb["py_only"].register(  # type: ignore
+compile.optdb["py_only"].register(
     "split_2f1grad_loop",
     split_2f1grad_loop,
     "fast_compile",
diff --git a/pytensor/sandbox/linalg/ops.py b/pytensor/tensor/rewriting/linalg.py
similarity index 83%
rename from pytensor/sandbox/linalg/ops.py
rename to pytensor/tensor/rewriting/linalg.py
index 0a53924801..8f09e52261 100644
--- a/pytensor/sandbox/linalg/ops.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -109,6 +109,50 @@ def psd_solve_with_chol(fgraph, node):
             return [x]
 
 
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Cholesky])
+def cholesky_ldotlt(fgraph, node):
+    """
+    rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
+    or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
+
+    This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
+    """
+    if not isinstance(node.op, Cholesky):
+        return
+
+    A = node.inputs[0]
+    if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
+        return
+
+    l, r = A.owner.inputs
+
+    # cholesky(dot(L,L.T)) case
+    if (
+        getattr(l.tag, "lower_triangular", False)
+        and r.owner
+        and isinstance(r.owner.op, DimShuffle)
+        and r.owner.op.new_order == (1, 0)
+        and r.owner.inputs[0] == l
+    ):
+        if node.op.lower:
+            return [l]
+        return [r]
+
+    # cholesky(dot(U.T,U)) case
+    if (
+        getattr(r.tag, "upper_triangular", False)
+        and l.owner
+        and isinstance(l.owner.op, DimShuffle)
+        and l.owner.op.new_order == (1, 0)
+        and l.owner.inputs[0] == r
+    ):
+        if node.op.lower:
+            return [l]
+        return [r]
+
+
 @register_stabilize
 @register_specialize
 @node_rewriter([Det])
diff --git a/tests/sandbox/__init__.py b/tests/sandbox/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/sandbox/linalg/__init__.py b/tests/sandbox/linalg/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/sandbox/test_minimal.py b/tests/sandbox/test_minimal.py
deleted file mode 100644
index 82e346eaf3..0000000000
--- a/tests/sandbox/test_minimal.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import numpy as np
-import pytest
-
-from pytensor import function
-from pytensor.sandbox.minimal import minimal
-from pytensor.tensor.type import matrix, vector
-from tests import unittest_tools as utt
-
-
-@pytest.mark.skip(reason="Unfinished test")
-class TestMinimal:
-    """
-    TODO: test dtype conversion
-    TODO: test that invalid types are rejected by make_node
-    TODO: test that each valid type for A and b works correctly
-    """
-
-    def setup_method(self):
-        self.rng = np.random.default_rng(utt.fetch_seed(666))
-
-    def test_minimal(self):
-        A = matrix()
-        b = vector()
-
-        print("building function")
-        f = function([A, b], minimal(A, A, b, b, A))
-        print("built")
-
-        Aval = self.rng.standard_normal((5, 5))
-        bval = np.arange(5, dtype=float)
-        f(Aval, bval)
-        print("done")
diff --git a/tests/sandbox/linalg/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
similarity index 59%
rename from tests/sandbox/linalg/test_linalg.py
rename to tests/tensor/rewriting/test_linalg.py
index f2cb67221c..9ec182cb21 100644
--- a/tests/sandbox/linalg/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -1,14 +1,17 @@
 import numpy as np
 import numpy.linalg
+import pytest
+import scipy.linalg
 
 import pytensor
 from pytensor import function
 from pytensor import tensor as at
+from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
+from pytensor.tensor.rewriting.linalg import inv_as_solve
 from pytensor.tensor.slinalg import Cholesky, Solve, solve
 from pytensor.tensor.type import dmatrix, matrix, vector
 from tests import unittest_tools as utt
@@ -65,53 +68,6 @@ def test_rop_lop():
     assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
 
 
-def test_spectral_radius_bound():
-    tol = 10 ** (-6)
-    rng = np.random.default_rng(utt.fetch_seed())
-    x = matrix()
-    radius_bound = spectral_radius_bound(x, 5)
-    f = pytensor.function([x], radius_bound)
-
-    shp = (3, 4)
-    m = rng.random(shp)
-    m = np.cov(m).astype(config.floatX)
-    radius_bound_pytensor = f(m)
-
-    # test the approximation
-    mm = m
-    for i in range(5):
-        mm = np.dot(mm, mm)
-    radius_bound_numpy = np.trace(mm) ** (2 ** (-5))
-    assert abs(radius_bound_numpy - radius_bound_pytensor) < tol
-
-    # test the bound
-    eigen_val = numpy.linalg.eig(m)
-    assert (eigen_val[0].max() - radius_bound_pytensor) < tol
-
-    # test type errors
-    xx = vector()
-    ok = False
-    try:
-        spectral_radius_bound(xx, 5)
-    except TypeError:
-        ok = True
-    assert ok
-    ok = False
-    try:
-        spectral_radius_bound(x, 5.0)
-    except TypeError:
-        ok = True
-    assert ok
-
-    # test value error
-    ok = False
-    try:
-        spectral_radius_bound(x, -5)
-    except ValueError:
-        ok = True
-    assert ok
-
-
 def test_transinv_to_invtrans():
     X = matrix("X")
     Y = matrix_inverse(X)
@@ -152,3 +108,75 @@ def test_matrix_inverse_solve():
     node = matrix_inverse(A).dot(b).owner
     [out] = inv_as_solve.transform(None, node)
     assert isinstance(out.owner.op, Solve)
+
+
+@pytest.mark.parametrize("tag", ("lower", "upper", None))
+@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
+@pytest.mark.parametrize("product", ("lower", "upper", None))
+def test_cholesky_ldotlt(tag, cholesky_form, product):
+    cholesky = Cholesky(lower=(cholesky_form == "lower"))
+
+    transform_removes_chol = tag is not None and product == tag
+    transform_transposes = transform_removes_chol and cholesky_form != tag
+
+    A = matrix("L")
+    if tag:
+        setattr(A.tag, tag + "_triangular", True)
+
+    if product == "lower":
+        M = A.dot(A.T)
+    elif product == "upper":
+        M = A.T.dot(A)
+    else:
+        M = A
+
+    C = cholesky(M)
+    f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
+
+    print(f.maker.fgraph.apply_nodes)
+
+    no_cholesky_in_graph = not any(
+        isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
+    )
+
+    assert no_cholesky_in_graph == transform_removes_chol
+
+    if transform_transposes:
+        assert any(
+            isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
+            for node in f.maker.fgraph.apply_nodes
+        )
+
+    # Test some concrete value through f
+    # there must be lower triangular (f assumes they are)
+    Avs = [
+        np.eye(1, dtype=pytensor.config.floatX),
+        np.eye(10, dtype=pytensor.config.floatX),
+        np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX),
+    ]
+    if not tag:
+        # these must be positive def
+        Avs.extend(
+            [
+                np.ones((4, 4), dtype=pytensor.config.floatX)
+                + np.eye(4, dtype=pytensor.config.floatX),
+            ]
+        )
+
+    for Av in Avs:
+        if tag == "upper":
+            Av = Av.T
+
+        if product == "lower":
+            Mv = Av.dot(Av.T)
+        elif product == "upper":
+            Mv = Av.T.dot(Av)
+        else:
+            Mv = Av
+
+        assert np.all(
+            np.isclose(
+                scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
+                f(Av),
+            )
+        )
diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py
index 0ce7640d38..035f9e036b 100644
--- a/tests/tensor/test_blas.py
+++ b/tests/tensor/test_blas.py
@@ -44,12 +44,11 @@
     gemv_no_inplace,
     ger,
     ger_destructive,
-    local_dot22_to_dot22scalar,
-    local_gemm_to_ger,
     res_is_a,
 )
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
+from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
 from pytensor.tensor.type import (
     cmatrix,
     col,