From a3b7c4569e1a09c97a6f91b1a799fefad64502e7 Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Wed, 14 Oct 2020 14:47:13 -0700
Subject: [PATCH 1/7] ENH: IntervalArray comparisons

---
 pandas/core/arrays/_mixins.py                 |  2 -
 pandas/core/arrays/interval.py                | 92 +++++++++++++++----
 pandas/tests/arithmetic/test_interval.py      |  5 -
 pandas/tests/extension/base/methods.py        |  2 +-
 .../tests/indexes/interval/test_interval.py   |  8 +-
 5 files changed, 79 insertions(+), 30 deletions(-)

diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py
index 95a003efbe1d0..b691f425b8436 100644
--- a/pandas/core/arrays/_mixins.py
+++ b/pandas/core/arrays/_mixins.py
@@ -3,7 +3,6 @@
 import numpy as np
 
 from pandas._libs import lib
-from pandas.compat.numpy import function as nv
 from pandas.errors import AbstractMethodError
 from pandas.util._decorators import cache_readonly, doc
 from pandas.util._validators import validate_fillna_kwargs
@@ -139,7 +138,6 @@ def repeat(self: _T, repeats, axis=None) -> _T:
         --------
         numpy.ndarray.repeat
         """
-        nv.validate_repeat(tuple(), dict(axis=axis))
         new_data = self._ndarray.repeat(repeats, axis=axis)
         return self._from_backing_data(new_data)
 
diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index 09488b9576212..db1f3377b8cb9 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -1,3 +1,4 @@
+import operator
 from operator import le, lt
 import textwrap
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
@@ -48,7 +49,7 @@
 from pandas.core.construction import array, extract_array
 from pandas.core.indexers import check_array_indexer
 from pandas.core.indexes.base import ensure_index
-from pandas.core.ops import unpack_zerodim_and_defer
+from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer
 
 if TYPE_CHECKING:
     from pandas import Index
@@ -520,8 +521,7 @@ def __setitem__(self, key, value):
         self._left[key] = value_left
         self._right[key] = value_right
 
-    @unpack_zerodim_and_defer("__eq__")
-    def __eq__(self, other):
+    def _cmp_method(self, other, op):
         # ensure pandas array for list-like and eliminate non-interval scalars
         if is_list_like(other):
             if len(self) != len(other):
@@ -529,7 +529,7 @@ def __eq__(self, other):
             other = array(other)
         elif not isinstance(other, Interval):
             # non-interval scalar -> no matches
-            return np.zeros(len(self), dtype=bool)
+            return invalid_comparison(self, other, op)
 
         # determine the dtype of the elements we want to compare
         if isinstance(other, Interval):
@@ -543,33 +543,87 @@ def __eq__(self, other):
             # extract intervals if we have interval categories with matching closed
             if is_interval_dtype(other_dtype):
                 if self.closed != other.categories.closed:
-                    return np.zeros(len(self), dtype=bool)
+                    return invalid_comparison(self, other, op)
                 other = other.categories.take(other.codes)
 
         # interval-like -> need same closed and matching endpoints
         if is_interval_dtype(other_dtype):
             if self.closed != other.closed:
-                return np.zeros(len(self), dtype=bool)
-            return (self._left == other.left) & (self._right == other.right)
+                return invalid_comparison(self, other, op)
+            if isinstance(other, Interval):
+                other = type(self)._from_sequence([other])
+                if self._combined.dtype.kind in ["m", "M"]:
+                    # Need to repeat bc we do not broadcast length-1
+                    # TODO: would be helpful to have a tile method to do
+                    #  this without copies
+                    other = other.repeat(len(self))
+            else:
+                other = type(self)(other)
+
+            if op is operator.eq:
+                return (self._combined[:, 0] == other._left) & (
+                    self._combined[:, 1] == other._right
+                )
+            elif op is operator.ne:
+                return (self._combined[:, 0] != other._left) | (
+                    self._combined[:, 1] != other._right
+                )
+            elif op is operator.gt:
+                return (self._combined[:, 0] > other._combined[:, 0]) | (
+                    (self._combined[:, 0] == other._left)
+                    & (self._combined[:, 1] > other._right)
+                )
+            elif op is operator.ge:
+                return (self == other) | (self > other)
+            elif op is operator.lt:
+                return (self._combined[:, 0] < other._combined[:, 0]) | (
+                    (self._combined[:, 0] == other._left)
+                    & (self._combined[:, 1] < other._right)
+                )
+            else:
+                # operator.lt
+                return (self == other) | (self < other)
 
         # non-interval/non-object dtype -> no matches
         if not is_object_dtype(other_dtype):
-            return np.zeros(len(self), dtype=bool)
+            return invalid_comparison(self, other, op)
 
         # object dtype -> iteratively check for intervals
-        result = np.zeros(len(self), dtype=bool)
-        for i, obj in enumerate(other):
-            # need object to be an Interval with same closed and endpoints
-            if (
-                isinstance(obj, Interval)
-                and self.closed == obj.closed
-                and self._left[i] == obj.left
-                and self._right[i] == obj.right
-            ):
-                result[i] = True
-
+        try:
+            result = np.zeros(len(self), dtype=bool)
+            for i, obj in enumerate(other):
+                result[i] = op(self[i], obj)
+        except TypeError:
+            # pd.NA
+            result = np.zeros(len(self), dtype=object)
+            for i, obj in enumerate(other):
+                result[i] = op(self[i], obj)
         return result
 
+    @unpack_zerodim_and_defer("__eq__")
+    def __eq__(self, other):
+        return self._cmp_method(other, operator.eq)
+
+    @unpack_zerodim_and_defer("__ne__")
+    def __ne__(self, other):
+        return self._cmp_method(other, operator.ne)
+
+    @unpack_zerodim_and_defer("__gt__")
+    def __gt__(self, other):
+        return self._cmp_method(other, operator.gt)
+
+    @unpack_zerodim_and_defer("__ge__")
+    def __ge__(self, other):
+        return self._cmp_method(other, operator.ge)
+
+    @unpack_zerodim_and_defer("__lt__")
+    def __lt__(self, other):
+        return self._cmp_method(other, operator.lt)
+
+    @unpack_zerodim_and_defer("__le__")
+    def __le__(self, other):
+        return self._cmp_method(other, operator.le)
+
     def fillna(self, value=None, method=None, limit=None):
         """
         Fill NA/NaN values using the specified method.
diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py
index 03cc4fe2bdcb5..8ab07a358ee51 100644
--- a/pandas/tests/arithmetic/test_interval.py
+++ b/pandas/tests/arithmetic/test_interval.py
@@ -216,11 +216,6 @@ def test_compare_list_like_nan(self, op, array, nulls_fixture, request):
         result = op(array, other)
         expected = self.elementwise_comparison(op, array, other)
 
-        if nulls_fixture is pd.NA and array.dtype.subtype != "i8":
-            reason = "broken for non-integer IntervalArray; see GH 31882"
-            mark = pytest.mark.xfail(reason=reason)
-            request.node.add_marker(mark)
-
         tm.assert_numpy_array_equal(result, expected)
 
     @pytest.mark.parametrize(
diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py
index 23e20a2c0903a..94533dcc08c48 100644
--- a/pandas/tests/extension/base/methods.py
+++ b/pandas/tests/extension/base/methods.py
@@ -443,7 +443,7 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
     @pytest.mark.parametrize(
         "repeats, kwargs, error, msg",
         [
-            (2, dict(axis=1), ValueError, "'axis"),
+            (2, dict(axis=1), ValueError, "axis"),
             (-1, dict(), ValueError, "negative"),
             ([1, 2], dict(), ValueError, "shape"),
             (2, dict(foo="bar"), TypeError, "'foo'"),
diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py
index 17a1c69858c11..e99b3dcc36fbb 100644
--- a/pandas/tests/indexes/interval/test_interval.py
+++ b/pandas/tests/indexes/interval/test_interval.py
@@ -579,9 +579,11 @@ def test_comparison(self):
         actual = self.index == self.index.left
         tm.assert_numpy_array_equal(actual, np.array([False, False]))
 
-        msg = (
-            "not supported between instances of 'int' and "
-            "'pandas._libs.interval.Interval'"
+        msg = "|".join(
+            [
+                "not supported between instances of 'int' and '.*.Interval'",
+                r"Invalid comparison between dtype=interval\[int64\] and ",
+            ]
         )
         with pytest.raises(TypeError, match=msg):
             self.index > 0

From cf20846c33d435b4a797d6d3e29679a3f447d2da Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Wed, 14 Oct 2020 16:01:46 -0700
Subject: [PATCH 2/7] CLN: get IntervalIndex comparisons from IntervalArray

---
 pandas/core/indexes/interval.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py
index cb25ef1241ce0..969638367ae92 100644
--- a/pandas/core/indexes/interval.py
+++ b/pandas/core/indexes/interval.py
@@ -1105,19 +1105,6 @@ def _is_all_dates(self) -> bool:
 
     # TODO: arithmetic operations
 
-    # GH#30817 until IntervalArray implements inequalities, get them from Index
-    def __lt__(self, other):
-        return Index.__lt__(self, other)
-
-    def __le__(self, other):
-        return Index.__le__(self, other)
-
-    def __gt__(self, other):
-        return Index.__gt__(self, other)
-
-    def __ge__(self, other):
-        return Index.__ge__(self, other)
-
 
 def _is_valid_endpoint(endpoint) -> bool:
     """

From fcfe47d29eedb726344bc6e22ea307c49cf2c3cb Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Tue, 20 Oct 2020 10:34:17 -0700
Subject: [PATCH 3/7] update per requests: use _left/_right instead of _combined

---
 pandas/core/arrays/interval.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index db1f3377b8cb9..77f60dc879ce7 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -552,7 +552,7 @@ def _cmp_method(self, other, op):
                 return invalid_comparison(self, other, op)
             if isinstance(other, Interval):
                 other = type(self)._from_sequence([other])
-                if self._combined.dtype.kind in ["m", "M"]:
+                if self._left.dtype.kind in ["m", "M"]:
                     # Need to repeat bc we do not broadcast length-1
                     # TODO: would be helpful to have a tile method to do
                     #  this without copies
@@ -561,24 +561,18 @@ def _cmp_method(self, other, op):
                 other = type(self)(other)
 
             if op is operator.eq:
-                return (self._combined[:, 0] == other._left) & (
-                    self._combined[:, 1] == other._right
-                )
+                return (self._left == other._left) & (self._right == other._right)
             elif op is operator.ne:
-                return (self._combined[:, 0] != other._left) | (
-                    self._combined[:, 1] != other._right
-                )
+                return (self._left != other._left) | (self._right != other._right)
             elif op is operator.gt:
-                return (self._combined[:, 0] > other._combined[:, 0]) | (
-                    (self._combined[:, 0] == other._left)
-                    & (self._combined[:, 1] > other._right)
+                return (self._left > other._left) | (
+                    (self._left == other._left) & (self._right > other._right)
                 )
             elif op is operator.ge:
                 return (self == other) | (self > other)
             elif op is operator.lt:
-                return (self._combined[:, 0] < other._combined[:, 0]) | (
-                    (self._combined[:, 0] == other._left)
-                    & (self._combined[:, 1] < other._right)
+                return (self._left < other._left) | (
+                    (self._left == other._left) & (self._right < other._right)
                 )
             else:
                 # operator.lt
@@ -589,8 +583,8 @@ def _cmp_method(self, other, op):
             return invalid_comparison(self, other, op)
 
         # object dtype -> iteratively check for intervals
+        result = np.zeros(len(self), dtype=bool)
         try:
-            result = np.zeros(len(self), dtype=bool)
             for i, obj in enumerate(other):
                 result[i] = op(self[i], obj)
         except TypeError:

From fa6cecdebcc049f0ed278eff141e81b9740c3eb2 Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Fri, 23 Oct 2020 18:01:17 -0700
Subject: [PATCH 4/7] Avoid having to tile

---
 pandas/core/arrays/interval.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index 77f60dc879ce7..d88d5882ffc33 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -550,29 +550,22 @@ def _cmp_method(self, other, op):
         if is_interval_dtype(other_dtype):
             if self.closed != other.closed:
                 return invalid_comparison(self, other, op)
-            if isinstance(other, Interval):
-                other = type(self)._from_sequence([other])
-                if self._left.dtype.kind in ["m", "M"]:
-                    # Need to repeat bc we do not broadcast length-1
-                    # TODO: would be helpful to have a tile method to do
-                    #  this without copies
-                    other = other.repeat(len(self))
-            else:
+            elif not isinstance(other, Interval):
                 other = type(self)(other)
 
             if op is operator.eq:
-                return (self._left == other._left) & (self._right == other._right)
+                return (self._left == other.left) & (self._right == other.right)
             elif op is operator.ne:
-                return (self._left != other._left) | (self._right != other._right)
+                return (self._left != other.left) | (self._right != other.right)
             elif op is operator.gt:
-                return (self._left > other._left) | (
-                    (self._left == other._left) & (self._right > other._right)
+                return (self._left > other.left) | (
+                    (self._left == other.left) & (self._right > other.right)
                 )
             elif op is operator.ge:
                 return (self == other) | (self > other)
             elif op is operator.lt:
-                return (self._left < other._left) | (
-                    (self._left == other._left) & (self._right < other._right)
+                return (self._left < other.left) | (
+                    (self._left == other.left) & (self._right < other.right)
                 )
             else:
                 # operator.lt

From ff640ea8f8ca52059ed85aac1e71ad0955fd97c1 Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Sat, 24 Oct 2020 15:12:10 -0700
Subject: [PATCH 5/7] handle NA per suggestion

---
 pandas/core/arrays/interval.py           | 16 +++++++++-------
 pandas/tests/arithmetic/test_interval.py |  5 +++++
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index d88d5882ffc33..dae3b70ed6b74 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -13,6 +13,7 @@
     IntervalMixin,
     intervals_to_interval_bounds,
 )
+from pandas._libs.missing import NA
 from pandas._typing import ArrayLike, Dtype
 from pandas.compat.numpy import function as nv
 from pandas.util._decorators import Appender
@@ -577,14 +578,15 @@ def _cmp_method(self, other, op):
 
         # object dtype -> iteratively check for intervals
         result = np.zeros(len(self), dtype=bool)
-        try:
-            for i, obj in enumerate(other):
-                result[i] = op(self[i], obj)
-        except TypeError:
-            # pd.NA
-            result = np.zeros(len(self), dtype=object)
-            for i, obj in enumerate(other):
+        for i, obj in enumerate(other):
+            try:
                 result[i] = op(self[i], obj)
+            except TypeError:
+                if obj is NA:
+                    # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
+                    result[i] = op is operator.ne
+                else:
+                    raise
         return result
 
     @unpack_zerodim_and_defer("__eq__")
diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py
index b7011ed223166..30a23d8563ef8 100644
--- a/pandas/tests/arithmetic/test_interval.py
+++ b/pandas/tests/arithmetic/test_interval.py
@@ -216,6 +216,11 @@ def test_compare_list_like_nan(self, op, array, nulls_fixture, request):
         result = op(array, other)
         expected = self.elementwise_comparison(op, array, other)
 
+        if nulls_fixture is pd.NA and array.dtype.subtype != "i8":
+            reason = "broken for non-integer IntervalArray; see GH 31882"
+            mark = pytest.mark.xfail(reason=reason)
+            request.node.add_marker(mark)
+
         tm.assert_numpy_array_equal(result, expected)
 
     @pytest.mark.parametrize(

From 247ce9083eab161be84cebef752c318ac0c567cf Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Tue, 27 Oct 2020 11:08:38 -0700
Subject: [PATCH 6/7] restore validate_repeat in repeat; comment on NA handling

---
 pandas/core/arrays/_mixins.py  | 2 ++
 pandas/core/arrays/interval.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py
index 4f49cd6b3faf2..948ffdc1f7c01 100644
--- a/pandas/core/arrays/_mixins.py
+++ b/pandas/core/arrays/_mixins.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 from pandas._libs import lib
+from pandas.compat.numpy import function as nv
 from pandas.errors import AbstractMethodError
 from pandas.util._decorators import cache_readonly, doc
 from pandas.util._validators import validate_fillna_kwargs
@@ -138,6 +139,7 @@ def repeat(self: _T, repeats, axis=None) -> _T:
         --------
         numpy.ndarray.repeat
         """
+        nv.validate_repeat(tuple(), dict(axis=axis))
         new_data = self._ndarray.repeat(repeats, axis=axis)
         return self._from_backing_data(new_data)
 
diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index dae3b70ed6b74..b943c811c54f9 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -583,6 +583,8 @@ def _cmp_method(self, other, op):
                 result[i] = op(self[i], obj)
             except TypeError:
                 if obj is NA:
+                    # comparison returns NA, which we (for now?) treat like
+                    #  other NAs
                     # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
                     result[i] = op is operator.ne
                 else:

From badb99d0a6873ce717cd621b785ac90edb159338 Mon Sep 17 00:00:00 2001
From: Brock <jbrockmendel@gmail.com>
Date: Mon, 2 Nov 2020 12:22:34 -0800
Subject: [PATCH 7/7] update comment

---
 pandas/core/arrays/interval.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index 3488d35e3518a..f8ece2a9fe7d4 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -586,8 +586,7 @@ def _cmp_method(self, other, op):
                 result[i] = op(self[i], obj)
             except TypeError:
                 if obj is NA:
-                    # comparison returns NA, which we (for now?) treat like
-                    #  other NAs
+                    # comparison with np.nan returns NA
                     # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
                     result[i] = op is operator.ne
                 else: