From 7d8c4e1aa1fd0c4e6e0deef52784e267a7e19d22 Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Thu, 6 Mar 2025 11:46:22 +0200 Subject: [PATCH 1/8] Add support for character classes `[...]` Also, fix the quantifiers more expressive. I.e., now it supports: {,4}, {4}, {1,3}, {1,} instead of just {1,3} and {1,} --- automata/fa/gnfa.py | 24 ++--- automata/regex/parser.py | 149 +++++++++++++++++++++++++++++-- automata/regex/regex.py | 4 + tests/test_regex.py | 188 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 346 insertions(+), 19 deletions(-) diff --git a/automata/fa/gnfa.py b/automata/fa/gnfa.py index f0a0d51d..d804595b 100644 --- a/automata/fa/gnfa.py +++ b/automata/fa/gnfa.py @@ -129,9 +129,9 @@ def from_dfa(cls: Type[Self], target_dfa: dfa.DFA) -> Self: if state in target_dfa.transitions: for input_symbol, to_state in target_dfa.transitions[state].items(): if to_state in gnfa_transitions.keys(): - gnfa_transitions[to_state] = ( - f"{gnfa_transitions[to_state]}|{input_symbol}" - ) + gnfa_transitions[ + to_state + ] = f"{gnfa_transitions[to_state]}|{input_symbol}" else: gnfa_transitions[to_state] = input_symbol new_gnfa_transitions[state] = gnfa_transitions @@ -190,17 +190,17 @@ def from_nfa(cls: Type[Self], target_nfa: nfa.NFA) -> Self: gnfa_transitions[to_state] != "" and input_symbol == "" ): if cls._isbracket_req(gnfa_transitions[to_state]): - gnfa_transitions[to_state] = ( - f"({gnfa_transitions[to_state]})?" - ) + gnfa_transitions[ + to_state + ] = f"({gnfa_transitions[to_state]})?" else: - gnfa_transitions[to_state] = ( - f"{gnfa_transitions[to_state]}?" - ) + gnfa_transitions[ + to_state + ] = f"{gnfa_transitions[to_state]}?" else: - gnfa_transitions[to_state] = ( - f"{gnfa_transitions[to_state]}|{input_symbol}" - ) + gnfa_transitions[ + to_state + ] = f"{gnfa_transitions[to_state]}|{input_symbol}" else: gnfa_transitions[to_state] = input_symbol new_gnfa_transitions[state] = cast( diff --git a/automata/regex/parser.py b/automata/regex/parser.py index dc60b779..13417494 100644 --- a/automata/regex/parser.py +++ b/automata/regex/parser.py @@ -27,7 +27,7 @@ BuilderTransitionsT = Dict[int, Dict[str, Set[int]]] RESERVED_CHARACTERS = frozenset( - ("*", "|", "(", ")", "?", " ", "\t", "&", "+", ".", "^", "{", "}") + ("*", "|", "(", ")", "?", " ", "\t", "&", "+", ".", "^", "{", "}", "[", "]") ) @@ -414,12 +414,52 @@ def __init__(self, text: str, lower_bound: int, upper_bound: Optional[int]) -> N self.upper_bound = upper_bound @classmethod - def from_match(cls: Type[Self], match: re.Match) -> QuantifierToken: + def from_match(cls: Type[Self], match: re.Match) -> Self: lower_bound_str = match.group(1) upper_bound_str = match.group(2) - lower_bound = 0 if not lower_bound_str else int(lower_bound_str) - upper_bound = None if not upper_bound_str else int(upper_bound_str) + # Parse lower bound + if not lower_bound_str: + lower_bound = 0 + else: + try: + lower_bound = int(lower_bound_str) + if lower_bound < 0: + raise exceptions.InvalidRegexError( + f"Lower bound cannot be negative: {lower_bound}" + ) + except ValueError: + # This shouldn't happen with our regex pattern, but just in case + raise exceptions.InvalidRegexError( + f"Invalid lower bound: {lower_bound_str}" + ) + + # Parse upper bound + if upper_bound_str is None: + # Format {n} + upper_bound = lower_bound + elif not upper_bound_str: + # Format {n,} + upper_bound = None + else: + try: + upper_bound = int(upper_bound_str) + if upper_bound < 0: + raise exceptions.InvalidRegexError( + f"Upper bound cannot be negative: 
{upper_bound}" + ) + except ValueError: + # This shouldn't happen with our regex pattern, but just in case + raise exceptions.InvalidRegexError( + f"Invalid upper bound: {upper_bound_str}" + ) + + # Validate bounds relationship + if upper_bound is not None and lower_bound > upper_bound: + raise exceptions.InvalidRegexError( + f"Lower bound {lower_bound} cannot be " + "greater than upper bound {upper_bound}" + ) return cls(match.group(), lower_bound, upper_bound) @@ -494,6 +534,40 @@ def val(self) -> NFARegexBuilder: return NFARegexBuilder.wildcard(self.input_symbols, self.counter) +class CharacterClassToken(Literal[NFARegexBuilder]): + """Subclass of literal token defining a character class.""" + + __slots__: Tuple[str, ...] = ("input_symbols", "class_chars", "negated", "counter") + + def __init__( + self, + text: str, + class_chars: Set[str], + negated: bool, + input_symbols: AbstractSet[str], + counter: count, + ) -> None: + super().__init__(text) + self.class_chars = class_chars + self.negated = negated + self.input_symbols = input_symbols + self.counter = counter + + @classmethod + def from_match(cls: Type[Self], match: re.Match) -> NoReturn: + raise NotImplementedError + + def val(self) -> NFARegexBuilder: + if self.negated: + # For negated class, create an NFA accepting any character + # not in class_chars + acceptable_chars = self.input_symbols - self.class_chars + return NFARegexBuilder.wildcard(acceptable_chars, self.counter) + else: + # Create an NFA accepting any character in the set + return NFARegexBuilder.wildcard(self.class_chars, self.counter) + + def add_concat_and_empty_string_tokens( token_list: List[Token[NFARegexBuilder]], state_name_counter: count, @@ -501,7 +575,6 @@ def add_concat_and_empty_string_tokens( """Add concat tokens to list of parsed infix tokens.""" final_token_list = [] - # Pairs of token types to insert concat tokens in between concat_pairs = [ (Literal, Literal), @@ -524,7 +597,6 @@ def add_concat_and_empty_string_tokens( next_token, secondClass ): final_token_list.append(ConcatToken("")) - for firstClass, secondClass in empty_string_pairs: if isinstance(curr_token, firstClass) and isinstance( next_token, secondClass @@ -548,11 +620,27 @@ def get_regex_lexer( lexer.register_token(KleeneStarToken.from_match, r"\*") lexer.register_token(KleenePlusToken.from_match, r"\+") lexer.register_token(OptionToken.from_match, r"\?") - lexer.register_token(QuantifierToken.from_match, r"\{(.*?),(.*?)\}") + # Match both {n}, {n,m}, and {,m} formats for quantifiers + lexer.register_token(QuantifierToken.from_match, r"\{(-?\d*)(?:,(-?\d*))?\}") + # Register wildcard and character classes next lexer.register_token( lambda match: WildcardToken(match.group(), input_symbols, state_name_counter), r"\.", ) + + # Add character class token + def character_class_factory(match: re.Match) -> CharacterClassToken: + class_str = match.group() + negated, class_chars = process_char_class(class_str) + return CharacterClassToken( + class_str, class_chars, negated, input_symbols, state_name_counter + ) + + lexer.register_token( + character_class_factory, + r"\[[^\]]*\]", # Match anything between [ and ] + ) + lexer.register_token( lambda match: StringToken(match.group(), state_name_counter), r"\S" ) @@ -577,3 +665,50 @@ def parse_regex(regexstr: str, input_symbols: AbstractSet[str]) -> NFARegexBuild postfix = tokens_to_postfix(tokens_with_concats) return parse_postfix_tokens(postfix) + + +def process_char_class(class_str: str) -> Tuple[bool, Set[str]]: + """Process a character class 
string into a set of characters and negation flag. + + Parameters + ---------- + class_str : str + The character class string including brackets, e.g., '[a-z]' or '[^abc]' + + Returns + ------- + Tuple[bool, Set[str]] + A tuple containing (is_negated, set_of_characters) + """ + content = class_str[1:-1] + + if not content: + raise exceptions.InvalidRegexError("Empty character class '[]' is not allowed") + + negated = content.startswith("^") + if negated: + content = content[1:] + + if not content: + raise exceptions.InvalidRegexError( + "Empty negated character class '[^]' is not allowed" + ) + + chars = set() + i = 0 + while i < len(content): + # Special case: - at the beginning or end is treated as literal + if content[i] == "-" and (i == 0 or i == len(content) - 1): + chars.add("-") + i += 1 + # Handle ranges - but only when there are characters on both sides + elif i + 2 < len(content) and content[i + 1] == "-": + # Range like a-z + start, end = content[i], content[i + 2] + chars.update(chr(c) for c in range(ord(start), ord(end) + 1)) + i += 3 + else: + chars.add(content[i]) + i += 1 + + return negated, chars diff --git a/automata/regex/regex.py b/automata/regex/regex.py index 3857005d..3b53325f 100644 --- a/automata/regex/regex.py +++ b/automata/regex/regex.py @@ -12,6 +12,10 @@ - `&`: Intersection. Ex: `a&b` - `.`: Wildcard. Ex: `a.b` - `^`: Shuffle. Ex: `a^b` +- `[...]`: Character class, matching any single character from the class. + Ex: `[abc]`, `[0-9]` +- `[^...]`: Negated character class, matching any single character not in the class. + Ex: `[^abc]` - `{}`: Quantifiers expressing finite repetitions. Ex: `a{1,2}`,`a{3,}` - `()`: The empty string. - `(...)`: Grouping. diff --git a/tests/test_regex.py b/tests/test_regex.py index 809ae77c..a2a48df7 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -1,6 +1,7 @@ """Classes and functions for testing the behavior of Regex tools""" import re as regex +import string import unittest import automata.base.exceptions as exceptions @@ -232,3 +233,190 @@ def test_invalid_symbols(self) -> None: """Should throw exception if reserved character is in input symbols""" with self.assertRaises(exceptions.InvalidSymbolError): NFA.from_regex("a+", input_symbols={"a", "+"}) + + def test_character_class(self) -> None: + """Should correctly handle character classes""" + # Basic equivalence + self.assertTrue(re.isequal("[abc]", "a|b|c")) + self.assertTrue(re.isequal("a[bc]d", "abd|acd")) + # With NFA construction + nfa1 = NFA.from_regex("[abc]") + nfa2 = NFA.from_regex("a|b|c") + self.assertEqual(nfa1, nfa2) + # Character class with repetition + self.assertTrue(re.isequal("[abc]*", "(a|b|c)*")) + + input_symbols = {"a", "b", "c", "d", "e"} + # Simple range + self.assertTrue(re.isequal("[a-c]", "a|b|c", input_symbols=input_symbols)) + # Multiple ranges + self.assertTrue(re.isequal("[a-ce-e]", "a|b|c|e", input_symbols=input_symbols)) + # Range with individual characters + self.assertTrue(re.isequal("[a-cd]", "a|b|c|d", input_symbols=input_symbols)) + # Create NFAs with negated character classes + nfa1 = NFA.from_regex("[^abc]", input_symbols=input_symbols) + nfa2 = NFA.from_regex("[^a-c]", input_symbols=input_symbols) + nfa3 = NFA.from_regex("a[^abc]+", input_symbols=input_symbols) + # Test acceptance/rejection patterns for simple negation + self.assertTrue(nfa1.accepts_input("d")) + self.assertTrue(nfa1.accepts_input("e")) + self.assertFalse(nfa1.accepts_input("a")) + self.assertFalse(nfa1.accepts_input("b")) + 
self.assertFalse(nfa1.accepts_input("c")) + # Test acceptance/rejection patterns for negated range + self.assertTrue(nfa2.accepts_input("d")) + self.assertTrue(nfa2.accepts_input("e")) + self.assertFalse(nfa2.accepts_input("a")) + self.assertFalse(nfa2.accepts_input("b")) + self.assertFalse(nfa2.accepts_input("c")) + # Test negated class with kleene plus + self.assertTrue(nfa3.accepts_input("ad")) + self.assertTrue(nfa3.accepts_input("ae")) + self.assertTrue(nfa3.accepts_input("ade")) + self.assertTrue(nfa3.accepts_input("aedd")) + self.assertFalse(nfa3.accepts_input("a")) + self.assertFalse(nfa3.accepts_input("aa")) + self.assertFalse(nfa3.accepts_input("ab")) + self.assertFalse(nfa3.accepts_input("abc")) + + input_symbols = {"a", "b", "c", "d", "e", "0", "1", "2", "3"} + # Character class with quantifiers + self.assertTrue( + re.isequal( + "[abc]{2}", "aa|ab|ac|ba|bb|bc|ca|cb|cc", input_symbols=input_symbols + ) + ) + # Character class with operators + self.assertTrue( + re.isequal("[abc]|[0-3]", "a|b|c|0|1|2|3", input_symbols=input_symbols) + ) + # Intersection with character classes + self.assertTrue(re.isequal("[abc]&[bc0]", "b|c", input_symbols=input_symbols)) + # Shuffle with character classes + self.assertTrue( + re.isequal( + "[ab]^[cd]", "ac|ad|bc|bd|ca|cb|da|db", input_symbols=input_symbols + ) + ) + + # Empty character class should raise error + with self.assertRaises(exceptions.InvalidRegexError): + re.validate("[]") + # Special character as range boundary + input_symbols = {"a", "b", "c", "-", "#"} + self.assertTrue(re.isequal("[a-c-]", "a|b|c|-", input_symbols=input_symbols)) + # Hyphen at the beginning of class (literal interpretation) + self.assertTrue(re.isequal("[-abc]", "-|a|b|c", input_symbols=input_symbols)) + # Hyphen at both beginning and end + self.assertTrue(re.isequal("[-abc-]", "-|a|b|c", input_symbols=input_symbols)) + # Special character with literal interpretation + input_symbols = {"a", "b", "c", "#"} + self.assertTrue(re.isequal("[a#c]", "a|#|c", input_symbols=input_symbols)) + + input_symbols = {"a", "b", "c"} + # Exact repetition {n} + self.assertTrue( + re.isequal("[abc]{2}", "(a|b|c)(a|b|c)", input_symbols=input_symbols) + ) + # Range repetition {n,m} + self.assertTrue( + re.isequal( + "[abc]{1,2}", "(a|b|c)|(a|b|c)(a|b|c)", input_symbols=input_symbols + ) + ) + # Lower bound only {n,} + nfa1 = NFA.from_regex("[abc]{2,}", input_symbols=input_symbols) + nfa2 = NFA.from_regex("(a|b|c)(a|b|c)((a|b|c)*)", input_symbols=input_symbols) + self.assertEqual(nfa1, nfa2) + + def test_unicode_character_classes(self) -> None: + """Should correctly handle Unicode character ranges in character classes""" + + def create_range(start_char, end_char): + return {chr(i) for i in range(ord(start_char), ord(end_char) + 1)} + + latin_ext_chars = create_range("¡", "ƿ") + greek_chars = create_range("Ͱ", "Ͽ") + cyrillic_chars = create_range("Ѐ", "ӿ") + + input_symbols = set() + input_symbols.update(latin_ext_chars) + input_symbols.update(greek_chars) + input_symbols.update(cyrillic_chars) + + ascii_chars = set(string.printable) + input_symbols.update(ascii_chars) + + from automata.fa.nfa import RESERVED_CHARACTERS + + input_symbols = input_symbols - RESERVED_CHARACTERS + + latin_nfa = NFA.from_regex("[¡-ƿ]+", input_symbols=input_symbols) + greek_nfa = NFA.from_regex("[Ͱ-Ͽ]+", input_symbols=input_symbols) + cyrillic_nfa = NFA.from_regex("[Ѐ-ӿ]+", input_symbols=input_symbols) + + latin_samples = ["¡", "£", "Ā", "ŕ", "ƿ"] + greek_samples = ["Ͱ", "Α", "Θ", "Ͽ"] + cyrillic_samples = 
["Ѐ", "Ё", "Џ", "ӿ"] + + for char in latin_samples: + self.assertTrue(latin_nfa.accepts_input(char), f"Should accept {char}") + self.assertTrue(latin_nfa.accepts_input("¡Āŕƿ")) # Multiple characters + self.assertFalse(latin_nfa.accepts_input("a")) # ASCII - not in range + self.assertFalse(latin_nfa.accepts_input("Α")) # Greek - not in range + self.assertFalse(latin_nfa.accepts_input("Ё")) # Cyrillic - not in range + self.assertFalse(latin_nfa.accepts_input("¡a")) # Mixed with non-matching + + for char in greek_samples: + self.assertTrue(greek_nfa.accepts_input(char), f"Should accept {char}") + self.assertTrue(greek_nfa.accepts_input("ͰΑΘϿ")) # Multiple characters + self.assertFalse(greek_nfa.accepts_input("a")) # ASCII - not in range + self.assertFalse(greek_nfa.accepts_input("Ā")) # Latin Ext - not in range + self.assertFalse(greek_nfa.accepts_input("Ё")) # Cyrillic - not in range + self.assertFalse(greek_nfa.accepts_input("Αa")) # Mixed with non-matching + + for char in cyrillic_samples: + self.assertTrue(cyrillic_nfa.accepts_input(char), f"Should accept {char}") + self.assertTrue(cyrillic_nfa.accepts_input("ЀЁЏӿ")) # Multiple characters + self.assertFalse(cyrillic_nfa.accepts_input("a")) # ASCII - not in range + self.assertFalse(cyrillic_nfa.accepts_input("Ā")) # Latin Ext - not in range + self.assertFalse(cyrillic_nfa.accepts_input("Α")) # Greek - not in range + self.assertFalse(cyrillic_nfa.accepts_input("Ёa")) # Mixed with non-matching + + combined_regex = "Latin-Extension[¡-ƿ]+Greek[Ͱ-Ͽ]+Cyrillic[Ѐ-ӿ]+" + combined_nfa = NFA.from_regex(combined_regex, input_symbols=input_symbols) + + self.assertTrue(combined_nfa.accepts_input("Latin-Extension¡GreekͰCyrillicЀ")) + self.assertTrue( + combined_nfa.accepts_input("Latin-ExtensionĀāGreekΑΒΓCyrillicЀЁЂ") + ) + + self.assertFalse(combined_nfa.accepts_input("Latin-ExtensionaGreekͰCyrillicЀ")) + self.assertFalse(combined_nfa.accepts_input("Latin-Extension¡GreekACyrillicЀ")) + self.assertFalse(combined_nfa.accepts_input("Latin-Extension¡GreekͰCyrillicA")) + + non_latin_nfa = NFA.from_regex("[^¡-ƿ]+", input_symbols=input_symbols) + self.assertTrue(non_latin_nfa.accepts_input("abc")) + self.assertTrue(non_latin_nfa.accepts_input("ЀЁЏӿ")) + self.assertTrue(non_latin_nfa.accepts_input("ͰΑΘ")) + self.assertFalse(non_latin_nfa.accepts_input("¡")) + self.assertFalse(non_latin_nfa.accepts_input("Ā")) + self.assertFalse(non_latin_nfa.accepts_input("a¡")) + + alphabet = set("abcdefghijklmnopqrstuvwxyz") + alphabet = alphabet - RESERVED_CHARACTERS + safe_input_symbols = input_symbols.union(alphabet) + + ascii_range_nfa = NFA.from_regex("[i-p]+", input_symbols=safe_input_symbols) + for char in "ijklmnop": + self.assertTrue( + ascii_range_nfa.accepts_input(char), f"Should accept {char}" + ) + for char in "abcdefgh": + self.assertFalse( + ascii_range_nfa.accepts_input(char), f"Should reject {char}" + ) + for char in "qrstuvwxyz": + self.assertFalse( + ascii_range_nfa.accepts_input(char), f"Should reject {char}" + ) From 7c291b6fd32bbe45622d507b8f715c6d39bcc943 Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Thu, 6 Mar 2025 14:55:48 +0200 Subject: [PATCH 2/8] Support class characters when no input_symbols are given --- automata/fa/nfa.py | 27 ++++++++++++++++++++++++--- automata/regex/parser.py | 29 +++++++++++++++++++++++++++-- tests/test_regex.py | 8 ++++++++ 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/automata/fa/nfa.py b/automata/fa/nfa.py index 548664b1..c51034d4 100644 --- a/automata/fa/nfa.py +++ b/automata/fa/nfa.py @@ -21,6 
+21,7 @@ ) import networkx as nx +import re from cached_method import cached_method from frozendict import frozendict from typing_extensions import Self, TypeAlias @@ -216,16 +217,36 @@ def from_regex( Self The NFA accepting the language of the input regex. """ - if input_symbols is None: - input_symbols = frozenset(regex) - RESERVED_CHARACTERS + input_symbols_set = set() + + range_pattern = re.compile(r'\[([^\]]*)\]') + for match in range_pattern.finditer(regex): + class_content = match.group(1) + pos = 0 + while pos < len(class_content): + if pos + 2 < len(class_content) and class_content[pos+1] == '-': + start_char, end_char = class_content[pos], class_content[pos+2] + if ord(start_char) <= ord(end_char): + for i in range(ord(start_char), ord(end_char) + 1): + input_symbols_set.add(chr(i)) + pos += 3 + else: + if class_content[pos] != '^': + input_symbols_set.add(class_content[pos]) + pos += 1 + for char in regex: + if char not in RESERVED_CHARACTERS: + input_symbols_set.add(char) + + input_symbols = frozenset(input_symbols_set) else: conflicting_symbols = RESERVED_CHARACTERS & input_symbols if conflicting_symbols: raise exceptions.InvalidSymbolError( f"Invalid input symbols: {conflicting_symbols}" ) - + nfa_builder = parse_regex(regex, input_symbols) return cls( diff --git a/automata/regex/parser.py b/automata/regex/parser.py index 13417494..e86a7df6 100644 --- a/automata/regex/parser.py +++ b/automata/regex/parser.py @@ -554,8 +554,33 @@ def __init__( self.counter = counter @classmethod - def from_match(cls: Type[Self], match: re.Match) -> NoReturn: - raise NotImplementedError + def from_match(cls: Type[Self], match: re.Match) -> Self: + range_pattern = re.compile(r"([^\\])-([^\\])") + content = match.group(1) + + # Process character ranges and build full content + pos = 0 + expanded_content = "" + while pos < len(content): + if pos + 2 < len(content) and content[pos+1] == '-': + start_char, end_char = content[pos], content[pos+2] + if ord(start_char) <= ord(end_char): + # Include all characters in the range + expanded_content += ''.join(chr(i) for i in range(ord(start_char), ord(end_char) + 1)) + pos += 3 + else: + # Invalid range - just add characters as is + expanded_content += content[pos] + pos += 1 + else: + expanded_content += content[pos] + pos += 1 + + is_negated = content.startswith("^") + if is_negated: + expanded_content = expanded_content[1:] # Remove ^ from the content + + return cls(match.group(), expanded_content, is_negated) def val(self) -> NFARegexBuilder: if self.negated: diff --git a/tests/test_regex.py b/tests/test_regex.py index a2a48df7..8b9b1452 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -279,6 +279,14 @@ def test_character_class(self) -> None: self.assertFalse(nfa3.accepts_input("ab")) self.assertFalse(nfa3.accepts_input("abc")) + # Test character class with provided input symbols + nfa4 = NFA.from_regex("[a-zA-Z]+") + self.assertTrue(nfa4.accepts_input("Hello")) + self.assertTrue(nfa4.accepts_input("world")) + self.assertFalse(nfa4.accepts_input("123")) + self.assertFalse(nfa4.accepts_input("123abc")) + self.assertFalse(nfa4.accepts_input("abc123")) + input_symbols = {"a", "b", "c", "d", "e", "0", "1", "2", "3"} # Character class with quantifiers self.assertTrue( From db5b6eb5f3e695f89390915fcdd110347c879750 Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Thu, 6 Mar 2025 15:20:41 +0200 Subject: [PATCH 3/8] Lint fixes and one more test --- automata/fa/gnfa.py | 24 ++++++++++++------------ automata/fa/nfa.py | 19 +++++++++++-------- 
automata/regex/parser.py | 15 ++++++++------- tests/test_regex.py | 6 ++++++ 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/automata/fa/gnfa.py b/automata/fa/gnfa.py index d804595b..f0a0d51d 100644 --- a/automata/fa/gnfa.py +++ b/automata/fa/gnfa.py @@ -129,9 +129,9 @@ def from_dfa(cls: Type[Self], target_dfa: dfa.DFA) -> Self: if state in target_dfa.transitions: for input_symbol, to_state in target_dfa.transitions[state].items(): if to_state in gnfa_transitions.keys(): - gnfa_transitions[ - to_state - ] = f"{gnfa_transitions[to_state]}|{input_symbol}" + gnfa_transitions[to_state] = ( + f"{gnfa_transitions[to_state]}|{input_symbol}" + ) else: gnfa_transitions[to_state] = input_symbol new_gnfa_transitions[state] = gnfa_transitions @@ -190,17 +190,17 @@ def from_nfa(cls: Type[Self], target_nfa: nfa.NFA) -> Self: gnfa_transitions[to_state] != "" and input_symbol == "" ): if cls._isbracket_req(gnfa_transitions[to_state]): - gnfa_transitions[ - to_state - ] = f"({gnfa_transitions[to_state]})?" + gnfa_transitions[to_state] = ( + f"({gnfa_transitions[to_state]})?" + ) else: - gnfa_transitions[ - to_state - ] = f"{gnfa_transitions[to_state]}?" + gnfa_transitions[to_state] = ( + f"{gnfa_transitions[to_state]}?" + ) else: - gnfa_transitions[ - to_state - ] = f"{gnfa_transitions[to_state]}|{input_symbol}" + gnfa_transitions[to_state] = ( + f"{gnfa_transitions[to_state]}|{input_symbol}" + ) else: gnfa_transitions[to_state] = input_symbol new_gnfa_transitions[state] = cast( diff --git a/automata/fa/nfa.py b/automata/fa/nfa.py index c51034d4..ca193e90 100644 --- a/automata/fa/nfa.py +++ b/automata/fa/nfa.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from collections import deque from itertools import chain, count, product, repeat from typing import ( @@ -21,7 +22,6 @@ ) import networkx as nx -import re from cached_method import cached_method from frozendict import frozendict from typing_extensions import Self, TypeAlias @@ -219,26 +219,29 @@ def from_regex( """ if input_symbols is None: input_symbols_set = set() - - range_pattern = re.compile(r'\[([^\]]*)\]') + + range_pattern = re.compile(r"\[([^\]]*)\]") for match in range_pattern.finditer(regex): class_content = match.group(1) pos = 0 while pos < len(class_content): - if pos + 2 < len(class_content) and class_content[pos+1] == '-': - start_char, end_char = class_content[pos], class_content[pos+2] + if pos + 2 < len(class_content) and class_content[pos + 1] == "-": + start_char, end_char = ( + class_content[pos], + class_content[pos + 2], + ) if ord(start_char) <= ord(end_char): for i in range(ord(start_char), ord(end_char) + 1): input_symbols_set.add(chr(i)) pos += 3 else: - if class_content[pos] != '^': + if class_content[pos] != "^": input_symbols_set.add(class_content[pos]) pos += 1 for char in regex: if char not in RESERVED_CHARACTERS: input_symbols_set.add(char) - + input_symbols = frozenset(input_symbols_set) else: conflicting_symbols = RESERVED_CHARACTERS & input_symbols @@ -246,7 +249,7 @@ def from_regex( raise exceptions.InvalidSymbolError( f"Invalid input symbols: {conflicting_symbols}" ) - + nfa_builder = parse_regex(regex, input_symbols) return cls( diff --git a/automata/regex/parser.py b/automata/regex/parser.py index e86a7df6..e76ae80d 100644 --- a/automata/regex/parser.py +++ b/automata/regex/parser.py @@ -555,18 +555,19 @@ def __init__( @classmethod def from_match(cls: Type[Self], match: re.Match) -> Self: - range_pattern = re.compile(r"([^\\])-([^\\])") content = match.group(1) - + # Process 
character ranges and build full content pos = 0 expanded_content = "" while pos < len(content): - if pos + 2 < len(content) and content[pos+1] == '-': - start_char, end_char = content[pos], content[pos+2] + if pos + 2 < len(content) and content[pos + 1] == "-": + start_char, end_char = content[pos], content[pos + 2] if ord(start_char) <= ord(end_char): # Include all characters in the range - expanded_content += ''.join(chr(i) for i in range(ord(start_char), ord(end_char) + 1)) + expanded_content += "".join( + chr(i) for i in range(ord(start_char), ord(end_char) + 1) + ) pos += 3 else: # Invalid range - just add characters as is @@ -575,11 +576,11 @@ def from_match(cls: Type[Self], match: re.Match) -> Self: else: expanded_content += content[pos] pos += 1 - + is_negated = content.startswith("^") if is_negated: expanded_content = expanded_content[1:] # Remove ^ from the content - + return cls(match.group(), expanded_content, is_negated) def val(self) -> NFARegexBuilder: diff --git a/tests/test_regex.py b/tests/test_regex.py index 8b9b1452..3778e49c 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -337,6 +337,12 @@ def test_character_class(self) -> None: nfa2 = NFA.from_regex("(a|b|c)(a|b|c)((a|b|c)*)", input_symbols=input_symbols) self.assertEqual(nfa1, nfa2) + # Test character class with reserved characters + nfa1 = NFA.from_regex("[a+]") + self.assertTrue(nfa1.accepts_input("a")) + self.assertTrue(nfa1.accepts_input("+")) + self.assertFalse(nfa1.accepts_input("b")) + def test_unicode_character_classes(self) -> None: """Should correctly handle Unicode character ranges in character classes""" From 298f55930b3c47c51d7cbd041c566569fc315d10 Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Thu, 6 Mar 2025 16:20:31 +0200 Subject: [PATCH 4/8] Fixx issue with reserved characters inside character class --- automata/fa/nfa.py | 67 +++++++++++++++++++++++++++------------------ tests/test_regex.py | 25 +++++++++++++++-- 2 files changed, 62 insertions(+), 30 deletions(-) diff --git a/automata/fa/nfa.py b/automata/fa/nfa.py index ca193e90..30927f1e 100644 --- a/automata/fa/nfa.py +++ b/automata/fa/nfa.py @@ -217,44 +217,57 @@ def from_regex( Self The NFA accepting the language of the input regex. 
""" + # First check user-provided input_symbols for reserved characters + if input_symbols is not None: + conflicting_symbols = RESERVED_CHARACTERS & input_symbols + if conflicting_symbols: + raise exceptions.InvalidSymbolError( + f"Invalid input symbols: {conflicting_symbols}" + ) + + # Extract all characters from character classes + class_symbols = set() + range_pattern = re.compile(r"\[([^\]]*)\]") + for match in range_pattern.finditer(regex): + class_content = match.group(1) + pos = 0 + while pos < len(class_content): + if pos + 2 < len(class_content) and class_content[pos + 1] == "-": + start_char, end_char = ( + class_content[pos], + class_content[pos + 2], + ) + if ord(start_char) <= ord(end_char): + for i in range(ord(start_char), ord(end_char) + 1): + class_symbols.add(chr(i)) + pos += 3 + else: + if class_content[pos] != "^": # Skip negation symbol + class_symbols.add(class_content[pos]) + pos += 1 + + # Set up the final input symbols if input_symbols is None: + # If no input_symbols provided, collect all non-reserved chars from regex input_symbols_set = set() - - range_pattern = re.compile(r"\[([^\]]*)\]") - for match in range_pattern.finditer(regex): - class_content = match.group(1) - pos = 0 - while pos < len(class_content): - if pos + 2 < len(class_content) and class_content[pos + 1] == "-": - start_char, end_char = ( - class_content[pos], - class_content[pos + 2], - ) - if ord(start_char) <= ord(end_char): - for i in range(ord(start_char), ord(end_char) + 1): - input_symbols_set.add(chr(i)) - pos += 3 - else: - if class_content[pos] != "^": - input_symbols_set.add(class_content[pos]) - pos += 1 for char in regex: if char not in RESERVED_CHARACTERS: input_symbols_set.add(char) - input_symbols = frozenset(input_symbols_set) + # Include all character class symbols + input_symbols_set.update(class_symbols) + final_input_symbols = frozenset(input_symbols_set) else: - conflicting_symbols = RESERVED_CHARACTERS & input_symbols - if conflicting_symbols: - raise exceptions.InvalidSymbolError( - f"Invalid input symbols: {conflicting_symbols}" - ) + # For user-provided input_symbols, we need to update with character class + # Create a copy to avoid modifying the original input_symbols + final_input_symbols = frozenset(input_symbols).union(class_symbols) - nfa_builder = parse_regex(regex, input_symbols) + # Build the NFA + nfa_builder = parse_regex(regex, final_input_symbols) return cls( states=frozenset(nfa_builder._transitions.keys()), - input_symbols=input_symbols, + input_symbols=final_input_symbols, transitions=nfa_builder._transitions, initial_state=nfa_builder._initial_state, final_states=nfa_builder._final_states, diff --git a/tests/test_regex.py b/tests/test_regex.py index 3778e49c..fae595f1 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -6,7 +6,7 @@ import automata.base.exceptions as exceptions import automata.regex.regex as re -from automata.fa.nfa import NFA +from automata.fa.nfa import NFA, RESERVED_CHARACTERS from automata.regex.parser import StringToken, WildcardToken @@ -343,6 +343,27 @@ def test_character_class(self) -> None: self.assertTrue(nfa1.accepts_input("+")) self.assertFalse(nfa1.accepts_input("b")) + # One more more complex test with and without input symbols + input_symbols = set(string.printable) - RESERVED_CHARACTERS + nfa1 = NFA.from_regex("[a-zA-Z0-9._%+-]+", input_symbols=input_symbols) + self.assertTrue(nfa1.accepts_input("a")) + self.assertTrue(nfa1.accepts_input("1")) + self.assertTrue(nfa1.accepts_input(".")) + 
self.assertTrue(nfa1.accepts_input("%")) + self.assertTrue(nfa1.accepts_input("+")) + self.assertFalse(nfa1.accepts_input("")) + self.assertFalse(nfa1.accepts_input("$")) + self.assertFalse(nfa1.accepts_input("{")) + nfa2 = NFA.from_regex("[a-zA-Z0-9._%+-]+") + self.assertTrue(nfa2.accepts_input("a")) + self.assertTrue(nfa2.accepts_input("1")) + self.assertTrue(nfa2.accepts_input(".")) + self.assertTrue(nfa2.accepts_input("%")) + self.assertTrue(nfa2.accepts_input("+")) + self.assertFalse(nfa2.accepts_input("")) + self.assertFalse(nfa2.accepts_input("$")) + self.assertFalse(nfa2.accepts_input("{")) + def test_unicode_character_classes(self) -> None: """Should correctly handle Unicode character ranges in character classes""" @@ -361,8 +382,6 @@ def create_range(start_char, end_char): ascii_chars = set(string.printable) input_symbols.update(ascii_chars) - from automata.fa.nfa import RESERVED_CHARACTERS - input_symbols = input_symbols - RESERVED_CHARACTERS latin_nfa = NFA.from_regex("[¡-ƿ]+", input_symbols=input_symbols) From ce0a5acf3b7b5be13ac4c6e54a1d91ea28168aad Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Thu, 6 Mar 2025 21:08:36 +0200 Subject: [PATCH 5/8] Add support for escaped characters and properly handle special chars like ('\') --- automata/fa/nfa.py | 83 +++++++++++++++++++++++++++++++------- automata/regex/parser.py | 86 ++++++++++++++++++++++++++++++++++++---- tests/test_regex.py | 86 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 21 deletions(-) diff --git a/automata/fa/nfa.py b/automata/fa/nfa.py index 30927f1e..57a10ad2 100644 --- a/automata/fa/nfa.py +++ b/automata/fa/nfa.py @@ -217,30 +217,82 @@ def from_regex( Self The NFA accepting the language of the input regex. """ - # First check user-provided input_symbols for reserved characters if input_symbols is not None: - conflicting_symbols = RESERVED_CHARACTERS & input_symbols + # Create a modified set of reserved characters that doesn't include + # whitespace + whitespace_chars = {" ", "\t", "\n", "\r", "\f", "\v"} + non_whitespace_reserved = RESERVED_CHARACTERS - whitespace_chars + + conflicting_symbols = non_whitespace_reserved & input_symbols if conflicting_symbols: raise exceptions.InvalidSymbolError( f"Invalid input symbols: {conflicting_symbols}" ) - # Extract all characters from character classes + # Extract escaped sequences from the regex + escape_chars = set() + i = 0 + while i < len(regex): + if regex[i] == "\\" and i + 1 < len(regex): + from automata.regex.parser import _handle_escape_sequences + + escaped_char = _handle_escape_sequences(regex[i + 1]) + escape_chars.add(escaped_char) + i += 2 + else: + i += 1 + class_symbols = set() range_pattern = re.compile(r"\[([^\]]*)\]") for match in range_pattern.finditer(regex): class_content = match.group(1) pos = 0 while pos < len(class_content): - if pos + 2 < len(class_content) and class_content[pos + 1] == "-": - start_char, end_char = ( - class_content[pos], - class_content[pos + 2], - ) - if ord(start_char) <= ord(end_char): + if class_content[pos] == "\\" and pos + 1 < len(class_content): + # Handle escape sequence in character class + from automata.regex.parser import _handle_escape_sequences + + escaped_char = _handle_escape_sequences(class_content[pos + 1]) + class_symbols.add(escaped_char) + + # Check if this is part of a range + if ( + pos + 2 < len(class_content) + and class_content[pos + 2] == "-" + and pos + 3 < len(class_content) + ): + # Handle range with escaped start character + start_char = escaped_char + + # 
Check if end character is also escaped + if class_content[pos + 3] == "\\" and pos + 4 < len( + class_content + ): + end_char = _handle_escape_sequences(class_content[pos + 4]) + pos += 5 + else: + end_char = class_content[pos + 3] + pos += 4 + + # Add all characters in the range to input symbols for i in range(ord(start_char), ord(end_char) + 1): class_symbols.add(chr(i)) - pos += 3 + continue + + pos += 2 + elif pos + 2 < len(class_content) and class_content[pos + 1] == "-": + # Handle normal range + start_char = class_content[pos] + + if class_content[pos + 2] == "\\" and pos + 3 < len(class_content): + end_char = _handle_escape_sequences(class_content[pos + 3]) + pos += 4 + else: + end_char = class_content[pos + 2] + pos += 3 + + for i in range(ord(start_char), ord(end_char) + 1): + class_symbols.add(chr(i)) else: if class_content[pos] != "^": # Skip negation symbol class_symbols.add(class_content[pos]) @@ -254,13 +306,16 @@ def from_regex( if char not in RESERVED_CHARACTERS: input_symbols_set.add(char) - # Include all character class symbols + # Include all character class symbols and escape sequences input_symbols_set.update(class_symbols) + input_symbols_set.update(escape_chars) final_input_symbols = frozenset(input_symbols_set) else: - # For user-provided input_symbols, we need to update with character class - # Create a copy to avoid modifying the original input_symbols - final_input_symbols = frozenset(input_symbols).union(class_symbols) + # For user-provided input_symbols, we need to update + # with character class symbols and escape sequences + final_input_symbols = ( + frozenset(input_symbols).union(class_symbols).union(escape_chars) + ) # Build the NFA nfa_builder = parse_regex(regex, final_input_symbols) diff --git a/automata/regex/parser.py b/automata/regex/parser.py index e76ae80d..63dd9c92 100644 --- a/automata/regex/parser.py +++ b/automata/regex/parser.py @@ -638,6 +638,13 @@ def get_regex_lexer( """Get lexer for parsing regular expressions.""" lexer: Lexer[NFARegexBuilder] = Lexer() + # Process the input string first to handle escape sequences + def process_string_factory(match: re.Match) -> StringToken: + text = match.group() + if text.startswith("\\") and len(text) > 1: + return StringToken(_handle_escape_sequences(text[1]), state_name_counter) + return StringToken(text, state_name_counter) + lexer.register_token(LeftParen.from_match, r"\(") lexer.register_token(RightParen.from_match, r"\)") lexer.register_token(UnionToken.from_match, r"\|") @@ -648,13 +655,11 @@ def get_regex_lexer( lexer.register_token(OptionToken.from_match, r"\?") # Match both {n}, {n,m}, and {,m} formats for quantifiers lexer.register_token(QuantifierToken.from_match, r"\{(-?\d*)(?:,(-?\d*))?\}") - # Register wildcard and character classes next lexer.register_token( lambda match: WildcardToken(match.group(), input_symbols, state_name_counter), r"\.", ) - # Add character class token def character_class_factory(match: re.Match) -> CharacterClassToken: class_str = match.group() negated, class_chars = process_char_class(class_str) @@ -667,6 +672,11 @@ def character_class_factory(match: re.Match) -> CharacterClassToken: r"\[[^\]]*\]", # Match anything between [ and ] ) + # Handle escaped sequences first - must come before general character match + lexer.register_token( + process_string_factory, + r"\\.", + ) lexer.register_token( lambda match: StringToken(match.group(), state_name_counter), r"\S" ) @@ -674,6 +684,38 @@ def character_class_factory(match: re.Match) -> CharacterClassToken: return lexer 
+def _handle_escape_sequences(char: str) -> str: + """Convert escape sequences to their actual character representation.""" + escape_map = { + "n": "\n", + "r": "\r", + "t": "\t", + "v": "\v", + "f": "\f", + "a": "\a", + "b": "\b", + "\\": "\\", + "+": "+", + "*": "*", + "?": "?", + ".": ".", + "|": "|", + "(": "(", + ")": ")", + "[": "[", + "]": "]", + "{": "{", + "}": "}", + "^": "^", + "$": "$", + "&": "&", + } + + if char in escape_map: + return escape_map[char] + return char + + def parse_regex(regexstr: str, input_symbols: AbstractSet[str]) -> NFARegexBuilder: """Return an NFARegexBuilder corresponding to regexstr.""" @@ -723,16 +765,46 @@ def process_char_class(class_str: str) -> Tuple[bool, Set[str]]: chars = set() i = 0 while i < len(content): + # Handle escape sequences + if content[i] == "\\" and i + 1 < len(content): + escaped_char = _handle_escape_sequences(content[i + 1]) + + if i + 2 < len(content) and content[i + 2] == "-" and i + 3 < len(content): + start_char = escaped_char + + if content[i + 3] == "\\" and i + 4 < len(content): + end_char = _handle_escape_sequences(content[i + 4]) + i += 5 + else: + end_char = content[i + 3] + i += 4 + + for code in range(ord(start_char), ord(end_char) + 1): + chars.add(chr(code)) + else: + chars.add(escaped_char) + i += 2 # Special case: - at the beginning or end is treated as literal - if content[i] == "-" and (i == 0 or i == len(content) - 1): + elif content[i] == "-" and (i == 0 or i == len(content) - 1): chars.add("-") i += 1 # Handle ranges - but only when there are characters on both sides elif i + 2 < len(content) and content[i + 1] == "-": - # Range like a-z - start, end = content[i], content[i + 2] - chars.update(chr(c) for c in range(ord(start), ord(end) + 1)) - i += 3 + # Check if end is an escape sequence + if content[i + 2] == "\\" and i + 3 < len(content): + start_char = content[i] + end_char = _handle_escape_sequences(content[i + 3]) + # Add all characters in the range + for code in range(ord(start_char), ord(end_char) + 1): + chars.add(chr(code)) + i += 4 + else: + # Regular range like a-z + start_char, end_char = content[i], content[i + 2] + # Add all characters in the range + for code in range(ord(start_char), ord(end_char) + 1): + chars.add(chr(code)) + i += 3 else: chars.add(content[i]) i += 1 diff --git a/tests/test_regex.py b/tests/test_regex.py index fae595f1..1e75441c 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -453,3 +453,89 @@ def create_range(start_char, end_char): self.assertFalse( ascii_range_nfa.accepts_input(char), f"Should reject {char}" ) + + def test_escape_characters(self) -> None: + """Should correctly handle escape sequences in regex patterns""" + # Basic escape sequences + nfa_newline = NFA.from_regex("\\n") + self.assertTrue(nfa_newline.accepts_input("\n")) + self.assertFalse(nfa_newline.accepts_input("n")) + self.assertFalse(nfa_newline.accepts_input("\\n")) + + nfa_carriage = NFA.from_regex("\\r") + self.assertTrue(nfa_carriage.accepts_input("\r")) + self.assertFalse(nfa_carriage.accepts_input("r")) + + nfa_tab = NFA.from_regex("\\t") + self.assertTrue(nfa_tab.accepts_input("\t")) + self.assertFalse(nfa_tab.accepts_input("t")) + + # Escaping special regex characters + nfa_plus = NFA.from_regex("\\+") + self.assertTrue(nfa_plus.accepts_input("+")) + self.assertFalse(nfa_plus.accepts_input("\\+")) + + nfa_star = NFA.from_regex("\\*") + self.assertTrue(nfa_star.accepts_input("*")) + self.assertFalse(nfa_star.accepts_input("\\*")) + + # Multiple escape sequences + nfa_multi = 
NFA.from_regex("\\n\\r\\t") + self.assertTrue(nfa_multi.accepts_input("\n\r\t")) + self.assertFalse(nfa_multi.accepts_input("\n\r")) + + # Character classes with escape sequences + nfa_class = NFA.from_regex("[\\n\\r\\t]") + self.assertTrue(nfa_class.accepts_input("\n")) + self.assertTrue(nfa_class.accepts_input("\r")) + self.assertTrue(nfa_class.accepts_input("\t")) + self.assertFalse(nfa_class.accepts_input("n")) + + # Character class ranges with escape sequences + nfa_range = NFA.from_regex("[\\n-\\r]") + self.assertTrue(nfa_range.accepts_input("\n")) + self.assertTrue(nfa_range.accepts_input("\r")) + self.assertTrue(nfa_range.accepts_input("\013")) + self.assertFalse(nfa_range.accepts_input("\t")) + + # Escape sequences with repetition operators + nfa_repeat = NFA.from_regex("\\n+") + self.assertTrue(nfa_repeat.accepts_input("\n")) + self.assertTrue(nfa_repeat.accepts_input("\n\n")) + self.assertTrue(nfa_repeat.accepts_input("\n\n\n")) + self.assertFalse(nfa_repeat.accepts_input("")) + + # Complex patterns with escape sequences + nfa_complex = NFA.from_regex("a\\nb+\\r(c|d)*") + self.assertTrue(nfa_complex.accepts_input("a\nb\r")) + self.assertTrue(nfa_complex.accepts_input("a\nbb\r")) + self.assertTrue(nfa_complex.accepts_input("a\nb\rcd")) + self.assertTrue(nfa_complex.accepts_input("a\nb\rcddc")) + self.assertFalse(nfa_complex.accepts_input("anb\r")) + + # Backslash escaping itself + nfa_backslash = NFA.from_regex("\\\\") + self.assertTrue(nfa_backslash.accepts_input("\\")) + self.assertFalse(nfa_backslash.accepts_input("\\\\")) + + # Testing another regex + nfa_complex2 = NFA.from_regex("a\\nb+\\r(c|d)*") + self.assertTrue(nfa_complex2.accepts_input("a\nb\r")) + self.assertTrue(nfa_complex2.accepts_input("a\nbb\r")) + self.assertTrue(nfa_complex2.accepts_input("a\nb\rcd")) + self.assertTrue(nfa_complex2.accepts_input("a\nb\rcddc")) + self.assertFalse(nfa_complex2.accepts_input("anb\r")) + + # Escaped whitespace in character classes + nfa_whitespace = NFA.from_regex("[\\n\\r\\t ]+") + self.assertTrue(nfa_whitespace.accepts_input("\n\r\t ")) + self.assertTrue(nfa_whitespace.accepts_input(" \n")) + self.assertTrue(nfa_whitespace.accepts_input("\t\r")) + self.assertFalse(nfa_whitespace.accepts_input("a")) + + # Common escape sequences with input symbols + input_symbols = set("abc\n\r\t ") + nfa_with_symbols = NFA.from_regex("a\\nb[\\r\\t]c", input_symbols=input_symbols) + self.assertTrue(nfa_with_symbols.accepts_input("a\nb\rc")) + self.assertTrue(nfa_with_symbols.accepts_input("a\nb\tc")) + self.assertFalse(nfa_with_symbols.accepts_input("a\nbc")) From 00ac1616e3844b8767a78d247fc0ae955a6c21c5 Mon Sep 17 00:00:00 2001 From: Eliot Robson Date: Thu, 6 Mar 2025 13:10:16 -0600 Subject: [PATCH 6/8] Add missing annotation --- tests/test_regex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_regex.py b/tests/test_regex.py index fae595f1..0ba98c9c 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -367,7 +367,7 @@ def test_character_class(self) -> None: def test_unicode_character_classes(self) -> None: """Should correctly handle Unicode character ranges in character classes""" - def create_range(start_char, end_char): + def create_range(start_char: str, end_char: str) -> set[str]: return {chr(i) for i in range(ord(start_char), ord(end_char) + 1)} latin_ext_chars = create_range("¡", "ƿ") From f400c345487acf614652b48f152a60f77c632dfb Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Fri, 7 Mar 2025 08:52:51 +0200 Subject: [PATCH 7/8] Add support 
for shorthands --- automata/fa/nfa.py | 67 ++++++++++++++++++++++++++- automata/regex/parser.py | 98 +++++++++++++++++++++++++++++++++------- tests/test_regex.py | 86 +++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 17 deletions(-) diff --git a/automata/fa/nfa.py b/automata/fa/nfa.py index 57a10ad2..5f5d612b 100644 --- a/automata/fa/nfa.py +++ b/automata/fa/nfa.py @@ -229,6 +229,33 @@ def from_regex( f"Invalid input symbols: {conflicting_symbols}" ) + # Import the shorthand character classes + from automata.regex.parser import ( + DIGIT_CHARS, + NON_DIGIT_CHARS, + NON_WHITESPACE_CHARS, + NON_WORD_CHARS, + WHITESPACE_CHARS, + WORD_CHARS, + ) + + # Create a set for additional symbols from shorthand classes + additional_symbols = set() + + # Check for shorthand classes in the regex + if "\\s" in regex: + additional_symbols.update(WHITESPACE_CHARS) + if "\\S" in regex: + additional_symbols.update(NON_WHITESPACE_CHARS) + if "\\d" in regex: + additional_symbols.update(DIGIT_CHARS) + if "\\D" in regex: + additional_symbols.update(NON_DIGIT_CHARS) + if "\\w" in regex: + additional_symbols.update(WORD_CHARS) + if "\\W" in regex: + additional_symbols.update(NON_WORD_CHARS) + # Extract escaped sequences from the regex escape_chars = set() i = 0 @@ -236,6 +263,11 @@ def from_regex( if regex[i] == "\\" and i + 1 < len(regex): from automata.regex.parser import _handle_escape_sequences + # Skip shorthand classes + if regex[i + 1] in "sSwWdD": + i += 2 + continue + escaped_char = _handle_escape_sequences(regex[i + 1]) escape_chars.add(escaped_char) i += 2 @@ -249,6 +281,32 @@ def from_regex( pos = 0 while pos < len(class_content): if class_content[pos] == "\\" and pos + 1 < len(class_content): + # Check for shorthand classes in character classes + if class_content[pos + 1] == "s": + additional_symbols.update(WHITESPACE_CHARS) + pos += 2 + continue + elif class_content[pos + 1] == "d": + additional_symbols.update(DIGIT_CHARS) + pos += 2 + continue + elif class_content[pos + 1] == "w": + additional_symbols.update(WORD_CHARS) + pos += 2 + continue + elif class_content[pos + 1] in "S": + additional_symbols.update(NON_WHITESPACE_CHARS) + pos += 2 + continue + elif class_content[pos + 1] in "D": + additional_symbols.update(NON_DIGIT_CHARS) + pos += 2 + continue + elif class_content[pos + 1] in "W": + additional_symbols.update(NON_WORD_CHARS) + pos += 2 + continue + # Handle escape sequence in character class from automata.regex.parser import _handle_escape_sequences @@ -309,12 +367,19 @@ def from_regex( # Include all character class symbols and escape sequences input_symbols_set.update(class_symbols) input_symbols_set.update(escape_chars) + + # Add the shorthand characters + input_symbols_set.update(additional_symbols) + final_input_symbols = frozenset(input_symbols_set) else: # For user-provided input_symbols, we need to update # with character class symbols and escape sequences final_input_symbols = ( - frozenset(input_symbols).union(class_symbols).union(escape_chars) + frozenset(input_symbols) + .union(class_symbols) + .union(escape_chars) + .union(additional_symbols) ) # Build the NFA diff --git a/automata/regex/parser.py b/automata/regex/parser.py index 63dd9c92..de30e73f 100644 --- a/automata/regex/parser.py +++ b/automata/regex/parser.py @@ -4,6 +4,7 @@ import copy import re +import string from collections import deque from itertools import chain, count, product, repeat from typing import AbstractSet, Deque, Dict, Iterable, List, Optional, Set, Tuple, Type @@ -24,6 +25,17 @@ 
validate_tokens, ) +# Add these at the top of the file to define our shorthand character sets +ASCII_PRINTABLE_CHARS = frozenset(string.printable) +WHITESPACE_CHARS = frozenset(" \t\n\r\f\v") +NON_WHITESPACE_CHARS = ASCII_PRINTABLE_CHARS - WHITESPACE_CHARS +DIGIT_CHARS = frozenset("0123456789") +NON_DIGIT_CHARS = ASCII_PRINTABLE_CHARS - DIGIT_CHARS +WORD_CHARS = frozenset( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" +) +NON_WORD_CHARS = ASCII_PRINTABLE_CHARS - WORD_CHARS + BuilderTransitionsT = Dict[int, Dict[str, Set[int]]] RESERVED_CHARACTERS = frozenset( @@ -638,13 +650,7 @@ def get_regex_lexer( """Get lexer for parsing regular expressions.""" lexer: Lexer[NFARegexBuilder] = Lexer() - # Process the input string first to handle escape sequences - def process_string_factory(match: re.Match) -> StringToken: - text = match.group() - if text.startswith("\\") and len(text) > 1: - return StringToken(_handle_escape_sequences(text[1]), state_name_counter) - return StringToken(text, state_name_counter) - + # Register all token types lexer.register_token(LeftParen.from_match, r"\(") lexer.register_token(RightParen.from_match, r"\)") lexer.register_token(UnionToken.from_match, r"\|") @@ -653,13 +659,50 @@ def process_string_factory(match: re.Match) -> StringToken: lexer.register_token(KleeneStarToken.from_match, r"\*") lexer.register_token(KleenePlusToken.from_match, r"\+") lexer.register_token(OptionToken.from_match, r"\?") - # Match both {n}, {n,m}, and {,m} formats for quantifiers lexer.register_token(QuantifierToken.from_match, r"\{(-?\d*)(?:,(-?\d*))?\}") lexer.register_token( lambda match: WildcardToken(match.group(), input_symbols, state_name_counter), r"\.", ) + # Add specific handlers for shorthand character classes + # These need to come BEFORE the general escape handler + lexer.register_token( + lambda match: WildcardToken( + match.group(), WHITESPACE_CHARS, state_name_counter + ), + r"\\s", + ) + lexer.register_token( + lambda match: WildcardToken( + match.group(), + frozenset(input_symbols) - WHITESPACE_CHARS, + state_name_counter, + ), + r"\\S", + ) + lexer.register_token( + lambda match: WildcardToken(match.group(), DIGIT_CHARS, state_name_counter), + r"\\d", + ) + lexer.register_token( + lambda match: WildcardToken( + match.group(), frozenset(input_symbols) - DIGIT_CHARS, state_name_counter + ), + r"\\D", + ) + lexer.register_token( + lambda match: WildcardToken(match.group(), WORD_CHARS, state_name_counter), + r"\\w", + ) + lexer.register_token( + lambda match: WildcardToken( + match.group(), frozenset(input_symbols) - WORD_CHARS, state_name_counter + ), + r"\\W", + ) + + # Character class tokenizer def character_class_factory(match: re.Match) -> CharacterClassToken: class_str = match.group() negated, class_chars = process_char_class(class_str) @@ -672,11 +715,15 @@ def character_class_factory(match: re.Match) -> CharacterClassToken: r"\[[^\]]*\]", # Match anything between [ and ] ) - # Handle escaped sequences first - must come before general character match + # Handle escaped sequences (must come AFTER shorthand handlers) lexer.register_token( - process_string_factory, - r"\\.", + lambda match: StringToken( + _handle_escape_sequences(match.group()[1]), state_name_counter + ), + r"\\.", # Match any escaped character ) + + # Handle regular characters lexer.register_token( lambda match: StringToken(match.group(), state_name_counter), r"\S" ) @@ -767,11 +814,29 @@ def process_char_class(class_str: str) -> Tuple[bool, Set[str]]: while i < len(content): # Handle 
escape sequences if content[i] == "\\" and i + 1 < len(content): + # Check for shorthand character classes + if content[i + 1] == "s": + chars.update(WHITESPACE_CHARS) + i += 2 + continue + elif content[i + 1] == "d": + chars.update(DIGIT_CHARS) + i += 2 + continue + elif content[i + 1] == "w": + chars.update(WORD_CHARS) + i += 2 + continue + + # Process regular escape sequence escaped_char = _handle_escape_sequences(content[i + 1]) + # Check if this is part of a range if i + 2 < len(content) and content[i + 2] == "-" and i + 3 < len(content): + # Handle range with escaped start character start_char = escaped_char + # Check if end character is also escaped if content[i + 3] == "\\" and i + 4 < len(content): end_char = _handle_escape_sequences(content[i + 4]) i += 5 @@ -779,29 +844,30 @@ def process_char_class(class_str: str) -> Tuple[bool, Set[str]]: end_char = content[i + 3] i += 4 + # Add all characters in the range for code in range(ord(start_char), ord(end_char) + 1): chars.add(chr(code)) else: chars.add(escaped_char) i += 2 - # Special case: - at the beginning or end is treated as literal + # Handle hyphen at the beginning or end elif content[i] == "-" and (i == 0 or i == len(content) - 1): chars.add("-") i += 1 - # Handle ranges - but only when there are characters on both sides + # Handle ranges elif i + 2 < len(content) and content[i + 1] == "-": # Check if end is an escape sequence if content[i + 2] == "\\" and i + 3 < len(content): start_char = content[i] end_char = _handle_escape_sequences(content[i + 3]) - # Add all characters in the range + for code in range(ord(start_char), ord(end_char) + 1): chars.add(chr(code)) i += 4 else: - # Regular range like a-z + # Regular range start_char, end_char = content[i], content[i + 2] - # Add all characters in the range + for code in range(ord(start_char), ord(end_char) + 1): chars.add(chr(code)) i += 3 diff --git a/tests/test_regex.py b/tests/test_regex.py index fc983b47..217ffb9f 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -539,3 +539,89 @@ def test_escape_characters(self) -> None: self.assertTrue(nfa_with_symbols.accepts_input("a\nb\rc")) self.assertTrue(nfa_with_symbols.accepts_input("a\nb\tc")) self.assertFalse(nfa_with_symbols.accepts_input("a\nbc")) + + def test_shorthand_character_classes(self) -> None: + """Should correctly handle shorthand character classes""" + + # \s - Any whitespace character + whitespace_nfa = NFA.from_regex("\\s+") + self.assertTrue(whitespace_nfa.accepts_input(" ")) + self.assertTrue(whitespace_nfa.accepts_input("\t")) + self.assertTrue(whitespace_nfa.accepts_input("\n")) + self.assertTrue(whitespace_nfa.accepts_input("\r")) + self.assertTrue(whitespace_nfa.accepts_input("\f")) # form feed + self.assertTrue(whitespace_nfa.accepts_input(" \t\n")) # multiple whitespace + self.assertFalse(whitespace_nfa.accepts_input("a")) + self.assertFalse(whitespace_nfa.accepts_input("1")) + self.assertFalse(whitespace_nfa.accepts_input("a ")) # contains non-whitespace + + # \S - Any non-whitespace character + non_whitespace_nfa = NFA.from_regex("\\S+") + self.assertTrue(non_whitespace_nfa.accepts_input("a")) + self.assertTrue(non_whitespace_nfa.accepts_input("1")) + self.assertTrue(non_whitespace_nfa.accepts_input("abc123")) + self.assertTrue(non_whitespace_nfa.accepts_input("!@#$%^&*()")) + self.assertFalse(non_whitespace_nfa.accepts_input(" ")) + self.assertFalse(non_whitespace_nfa.accepts_input("\t")) + self.assertFalse(non_whitespace_nfa.accepts_input("a ")) # contains whitespace + + # \d - Any digit + 
digit_nfa = NFA.from_regex("\\d+") + self.assertTrue(digit_nfa.accepts_input("0")) + self.assertTrue(digit_nfa.accepts_input("9")) + self.assertTrue(digit_nfa.accepts_input("0123456789")) + self.assertFalse(digit_nfa.accepts_input("a")) + self.assertFalse(digit_nfa.accepts_input("a1")) # contains non-digit + + # \D - Any non-digit + non_digit_nfa = NFA.from_regex("\\D+") + self.assertTrue(non_digit_nfa.accepts_input("a")) + self.assertTrue(non_digit_nfa.accepts_input("xyz")) + self.assertTrue(non_digit_nfa.accepts_input("!@#$%^&*()")) + self.assertTrue(non_digit_nfa.accepts_input(" \t\n")) # whitespace is non-digit + self.assertFalse(non_digit_nfa.accepts_input("0")) + self.assertFalse(non_digit_nfa.accepts_input("12345")) + self.assertFalse(non_digit_nfa.accepts_input("a1")) # contains digit + + # \w - Any word character (alphanumeric or underscore) + word_nfa = NFA.from_regex("\\w+") + self.assertTrue(word_nfa.accepts_input("a")) + self.assertTrue(word_nfa.accepts_input("Z")) + self.assertTrue(word_nfa.accepts_input("0")) + self.assertTrue(word_nfa.accepts_input("_")) + self.assertTrue(word_nfa.accepts_input("a1_Z")) + self.assertFalse(word_nfa.accepts_input("!")) + self.assertFalse(word_nfa.accepts_input(" ")) + self.assertFalse(word_nfa.accepts_input("a!")) # contains non-word + + # \W - Any non-word character + non_word_nfa = NFA.from_regex("\\W+") + self.assertTrue(non_word_nfa.accepts_input("!")) + self.assertTrue(non_word_nfa.accepts_input("@#$%^&*()")) + self.assertTrue(non_word_nfa.accepts_input(" \t\n")) # whitespace is non-word + self.assertFalse(non_word_nfa.accepts_input("a")) + self.assertFalse(non_word_nfa.accepts_input("Z")) + self.assertFalse(non_word_nfa.accepts_input("0")) + self.assertFalse(non_word_nfa.accepts_input("_")) + self.assertFalse(non_word_nfa.accepts_input("!a")) # contains word + + # Combinations + mixed_nfa = NFA.from_regex("\\d+\\s+\\w+") + self.assertTrue(mixed_nfa.accepts_input("123 abc")) + self.assertTrue(mixed_nfa.accepts_input("456\t_7")) + self.assertFalse(mixed_nfa.accepts_input("abc 123")) # wrong order + self.assertFalse(mixed_nfa.accepts_input("123abc")) # missing whitespace + + # Inside character classes + class_nfa = NFA.from_regex("[\\d\\s]+") + self.assertTrue(class_nfa.accepts_input("123")) + self.assertTrue(class_nfa.accepts_input(" \t\n")) + self.assertTrue(class_nfa.accepts_input("1 2\t3\n")) + self.assertFalse(class_nfa.accepts_input("a")) + + # With other escape sequences + complex_nfa = NFA.from_regex("\\w+\\t\\d+\\n") + self.assertTrue(complex_nfa.accepts_input("abc\t123\n")) + self.assertTrue(complex_nfa.accepts_input("_\t0\n")) + self.assertFalse(complex_nfa.accepts_input("abc 123\n")) # space instead of tab + self.assertFalse(complex_nfa.accepts_input("abc\t123")) # missing newline From df53edcc873dadb2586facfa65cba75dbf0eb063 Mon Sep 17 00:00:00 2001 From: Stefanos Chaliasos Date: Fri, 7 Mar 2025 10:41:32 +0200 Subject: [PATCH 8/8] Allow reserved chars in input symbols and tokenize spaces We also added more complex tests --- automata/fa/nfa.py | 12 -- automata/regex/parser.py | 8 +- tests/test_nfa.py | 8 +- tests/test_regex.py | 251 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 248 insertions(+), 31 deletions(-) diff --git a/automata/fa/nfa.py b/automata/fa/nfa.py index 5f5d612b..e65fb8b4 100644 --- a/automata/fa/nfa.py +++ b/automata/fa/nfa.py @@ -217,18 +217,6 @@ def from_regex( Self The NFA accepting the language of the input regex. 
""" - if input_symbols is not None: - # Create a modified set of reserved characters that doesn't include - # whitespace - whitespace_chars = {" ", "\t", "\n", "\r", "\f", "\v"} - non_whitespace_reserved = RESERVED_CHARACTERS - whitespace_chars - - conflicting_symbols = non_whitespace_reserved & input_symbols - if conflicting_symbols: - raise exceptions.InvalidSymbolError( - f"Invalid input symbols: {conflicting_symbols}" - ) - # Import the shorthand character classes from automata.regex.parser import ( DIGIT_CHARS, diff --git a/automata/regex/parser.py b/automata/regex/parser.py index de30e73f..7a9e2878 100644 --- a/automata/regex/parser.py +++ b/automata/regex/parser.py @@ -648,7 +648,7 @@ def get_regex_lexer( input_symbols: AbstractSet[str], state_name_counter: count ) -> Lexer[NFARegexBuilder]: """Get lexer for parsing regular expressions.""" - lexer: Lexer[NFARegexBuilder] = Lexer() + lexer: Lexer[NFARegexBuilder] = Lexer(blank_chars=set()) # Register all token types lexer.register_token(LeftParen.from_match, r"\(") @@ -723,6 +723,12 @@ def character_class_factory(match: re.Match) -> CharacterClassToken: r"\\.", # Match any escaped character ) + # Add specific token for space character - this is the key fix + lexer.register_token( + lambda match: StringToken(match.group(), state_name_counter), + r" ", # Match a space character + ) + # Handle regular characters lexer.register_token( lambda match: StringToken(match.group(), state_name_counter), r"\S" diff --git a/tests/test_nfa.py b/tests/test_nfa.py index b895fdde..a356a121 100644 --- a/tests/test_nfa.py +++ b/tests/test_nfa.py @@ -785,7 +785,7 @@ def test_nfa_equality(self) -> None: self.assertEqual( nfa2, NFA.from_regex( - "(((01) | 1)*)((0*1) | (1*0))(((10) | 0)*)", input_symbols=input_symbols + "(((01)|1)*)((0*1)|(1*0))(((10)|0)*)", input_symbols=input_symbols ), ) @@ -807,7 +807,7 @@ def test_nfa_equality(self) -> None: self.assertEqual( nfa3, - NFA.from_regex("(0(0 | 1)*0) | (1(0 | 1)*1)", input_symbols=input_symbols), + NFA.from_regex("(0(0|1)*0)|(1(0|1)*1)", input_symbols=input_symbols), ) nfa4 = NFA( @@ -828,7 +828,7 @@ def test_nfa_equality(self) -> None: self.assertEqual( nfa4, - NFA.from_regex("((0 | 1)*00) | ((0 | 1)*11)", input_symbols=input_symbols), + NFA.from_regex("((0|1)*00)|((0|1)*11)", input_symbols=input_symbols), ) input_symbols_2 = {"0", "1", "2"} @@ -853,7 +853,7 @@ def test_nfa_equality(self) -> None: self.assertEqual( nfa5, NFA.from_regex( - "((((01)*0) | 2)(100)*1)*(1* | (0*2*))", input_symbols=input_symbols_2 + "((((01)*0)|2)(100)*1)*(1*|(0*2*))", input_symbols=input_symbols_2 ), ) diff --git a/tests/test_regex.py b/tests/test_regex.py index 217ffb9f..362e26bd 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -6,7 +6,7 @@ import automata.base.exceptions as exceptions import automata.regex.regex as re -from automata.fa.nfa import NFA, RESERVED_CHARACTERS +from automata.fa.nfa import NFA from automata.regex.parser import StringToken, WildcardToken @@ -114,13 +114,13 @@ def test_intersection(self) -> None: # Test intersection subset regex_3 = "bcdaaa" nfa_5 = NFA.from_regex(regex_3) - nfa_6 = NFA.from_regex(f"({regex_3}) & (bcda*)") + nfa_6 = NFA.from_regex(f"({regex_3})&(bcda*)") self.assertEqual(nfa_5, nfa_6) # Test distributive law - regex_4 = f"{regex_1} & (({regex_2}) | ({regex_3}))" - regex_5 = f"(({regex_1}) & ({regex_2})) | (({regex_1}) & ({regex_3}))" + regex_4 = f"{regex_1}&(({regex_2})|({regex_3}))" + regex_5 = f"(({regex_1})&({regex_2}))|(({regex_1})&({regex_3}))" nfa_7 = 
NFA.from_regex(regex_4)
         nfa_8 = NFA.from_regex(regex_5)
 
@@ -159,7 +159,7 @@ def test_shuffle(self) -> None:
         self.assertTrue(
             re.isequal(
                 "ab^cd",
-                "abcd | acbd | cabd | acdb | cadb | cdab",
+                "abcd|acbd|cabd|acdb|cadb|cdab",
                 input_symbols=input_symbols,
             )
         )
@@ -167,10 +167,10 @@ def test_shuffle(self) -> None:
             re.isequal("(a*)^(b*)^(c*)^(d*)", ".*", input_symbols=input_symbols)
         )
         self.assertTrue(
-            re.isequal("ca^db", "(c^db)a | (ca^d)b", input_symbols=input_symbols)
+            re.isequal("ca^db", "(c^db)a|(ca^d)b", input_symbols=input_symbols)
         )
         self.assertTrue(
-            re.isequal("a^(b|c)", "ab | ac | ba | ca", input_symbols=input_symbols)
+            re.isequal("a^(b|c)", "ab|ac|ba|ca", input_symbols=input_symbols)
         )
 
         reference_nfa = NFA.from_regex("a*^ba")
@@ -229,10 +229,14 @@ def test_blank(self) -> None:
         self.assertTrue(re.isequal("a()", "a"))
         self.assertTrue(re.isequal("a()b()()c()", "abc"))
 
-    def test_invalid_symbols(self) -> None:
-        """Should throw exception if reserved character is in input symbols"""
-        with self.assertRaises(exceptions.InvalidSymbolError):
-            NFA.from_regex("a+", input_symbols={"a", "+"})
+    def test_reserved_characters_handled_correctly(self) -> None:
+        """Should treat reserved characters in input symbols as literal symbols"""
+        nfa = NFA.from_regex("a+", input_symbols={"a", "+"})
+        self.assertTrue(nfa.accepts_input("a"))
+        self.assertTrue(nfa.accepts_input("aa"))
+        self.assertFalse(nfa.accepts_input("a+"))
+        self.assertFalse(nfa.accepts_input(""))
+        self.assertFalse(nfa.accepts_input("+"))
 
     def test_character_class(self) -> None:
         """Should correctly handle character classes"""
@@ -344,7 +348,7 @@ def test_character_class(self) -> None:
         self.assertFalse(nfa1.accepts_input("b"))
 
         # One more more complex test with and without input symbols
-        input_symbols = set(string.printable) - RESERVED_CHARACTERS
+        input_symbols = set(string.printable)
         nfa1 = NFA.from_regex("[a-zA-Z0-9._%+-]+", input_symbols=input_symbols)
         self.assertTrue(nfa1.accepts_input("a"))
         self.assertTrue(nfa1.accepts_input("1"))
@@ -382,8 +386,6 @@ def create_range(start_char: str, end_char: str) -> set[str]:
         ascii_chars = set(string.printable)
         input_symbols.update(ascii_chars)
 
-        input_symbols = input_symbols - RESERVED_CHARACTERS
-
         latin_nfa = NFA.from_regex("[¡-ƿ]+", input_symbols=input_symbols)
         greek_nfa = NFA.from_regex("[Ͱ-Ͽ]+", input_symbols=input_symbols)
         cyrillic_nfa = NFA.from_regex("[Ѐ-ӿ]+", input_symbols=input_symbols)
@@ -437,7 +439,6 @@ def create_range(start_char: str, end_char: str) -> set[str]:
         self.assertFalse(non_latin_nfa.accepts_input("a¡"))
 
         alphabet = set("abcdefghijklmnopqrstuvwxyz")
-        alphabet = alphabet - RESERVED_CHARACTERS
         safe_input_symbols = input_symbols.union(alphabet)
 
         ascii_range_nfa = NFA.from_regex("[i-p]+", input_symbols=safe_input_symbols)
@@ -625,3 +626,224 @@ def test_shorthand_character_classes(self) -> None:
         self.assertTrue(complex_nfa.accepts_input("_\t0\n"))
         self.assertFalse(complex_nfa.accepts_input("abc 123\n"))  # space instead of tab
         self.assertFalse(complex_nfa.accepts_input("abc\t123"))  # missing newline
+
+    def test_negated_class_with_period(self) -> None:
+        """Test that negated character classes can match the period character"""
+
+        # Create an NFA with a negated character class
+        nfa = NFA.from_regex(r"[.]+.", input_symbols={"a"})
+        self.assertTrue(nfa.accepts_input(".a"))
+        self.assertFalse(nfa.accepts_input("]+", input_symbols={"a", "."})
+        self.assertTrue(nfa.accepts_input("."))
+        self.assertTrue(nfa.accepts_input("..."))
+
+        nfa = NFA.from_regex(r"[^<>]+", input_symbols=set(string.printable))
+        # This should match any character
except < and > + self.assertTrue(nfa.accepts_input("abc")) + self.assertTrue(nfa.accepts_input("123")) + self.assertTrue(nfa.accepts_input('!@#$%^&*()_+{}|:",./?`~')) + + # These should not match + self.assertFalse(nfa.accepts_input("<")) + self.assertFalse(nfa.accepts_input(">")) + self.assertFalse(nfa.accepts_input("ab")) # contains > + + def test_slash_character(self) -> None: + """Should correctly handle the slash character""" + nfa = NFA.from_regex(r"/", input_symbols=set(string.printable)) + self.assertTrue(nfa.accepts_input("/")) + self.assertFalse(nfa.accepts_input("a/b")) + + def test_email_like_regexes(self) -> None: + """Should correctly handle email-like regexes""" + input_symbols = set(string.printable) + + # Pattern for bracketed email content: ">content[^<>]+<.*", input_symbols=input_symbols) + self.assertTrue(bracketed_nfa.accepts_input(">user@example.com<")) + self.assertTrue(bracketed_nfa.accepts_input(">John Doe + self.assertFalse(bracketed_nfa.accepts_input("><")) # empty content + + # Pattern for "To:" header field + to_header_nfa = NFA.from_regex(r"to:[^\r\n]+\r\n", input_symbols=input_symbols) + self.assertTrue(to_header_nfa.accepts_input("to:user@example.com\r\n")) + self.assertTrue( + to_header_nfa.accepts_input( + "to:Multiple Recipients \r\n" + ) + ) + self.assertFalse( + to_header_nfa.accepts_input("to:user@example.com") + ) # missing newline + self.assertFalse( + to_header_nfa.accepts_input("from:user@example.com\r\n") + ) # wrong header + + # Pattern for "Subject:" header field + subject_nfa = NFA.from_regex( + r"\)subject:[^\r\n]+\r\n", input_symbols=input_symbols + ) + self.assertTrue(subject_nfa.accepts_input(")subject:Hello World\r\n")) + self.assertTrue( + subject_nfa.accepts_input(")subject:Re: Meeting Tomorrow at 10AM\r\n") + ) + self.assertFalse( + subject_nfa.accepts_input("subject:Hello World\r\n") + ) # missing ) + + # Pattern for standard email address + email_nfa = NFA.from_regex( + r"[A-Za-z0-9!#$%&'*+=?\-\^_`{|}~.\/]+@[A-Za-z0-9.\-@]+", + input_symbols=input_symbols, + ) + self.assertTrue(email_nfa.accepts_input("user@example.com")) + self.assertTrue(email_nfa.accepts_input("user.name+tag@sub.example-site.co.uk")) + self.assertTrue(email_nfa.accepts_input("unusual!#$%&'*character@example.com")) + self.assertFalse(email_nfa.accepts_input("@example.com")) # missing local part + self.assertFalse(email_nfa.accepts_input("user@")) # missing domain + + # Pattern for DKIM signature with Base64 hash + dkim_bh_nfa = NFA.from_regex( + r"dkim-signature:([a-z]+=[^;]+; )+bh=[a-zA-Z0-9+/=]+;", + input_symbols=input_symbols, + ) + self.assertTrue( + dkim_bh_nfa.accepts_input( + "dkim-signature:v=1; a=rsa-sha256; bh=47DEQpj8HBSa+/TImW+5JCeuQeR;" + ) + ) + self.assertTrue( + dkim_bh_nfa.accepts_input( + "dkim-signature:v=1; a=rsa-sha256; d=example.org; bh=base64+/hash=;" + ) + ) + self.assertFalse( + dkim_bh_nfa.accepts_input("dkim-signature:v=1; bh=;") + ) # empty hash + + # Pattern for alternative email address format + alt_email_nfa = NFA.from_regex( + r"[A-Za-z0-9!#$%&'*+=?\-\^_`{|}~.\/@]+@[A-Za-z0-9.\-]+", + input_symbols=input_symbols, + ) + self.assertTrue(alt_email_nfa.accepts_input("user@example.com")) + self.assertTrue( + alt_email_nfa.accepts_input("user/dept@example.com") + ) # with slash + self.assertFalse(alt_email_nfa.accepts_input("user@")) # missing domain + + # Pattern for "From:" header field + from_header_nfa = NFA.from_regex( + r"from:[^\r\n]+\r\n", input_symbols=input_symbols + ) + 
self.assertTrue(from_header_nfa.accepts_input("from:sender@example.com\r\n")) + self.assertTrue( + from_header_nfa.accepts_input("from:John Doe \r\n") + ) + self.assertFalse( + from_header_nfa.accepts_input("from:sender@example.com") + ) # missing newline + + # Pattern for DKIM signature with timestamp + dkim_time_nfa = NFA.from_regex( + r"dkim-signature:([a-z]+=[^;]+; )+t=[0-9]+;", input_symbols=input_symbols + ) + self.assertTrue( + dkim_time_nfa.accepts_input( + "dkim-signature:v=1; a=rsa-sha256; t=1623456789;" + ) + ) + self.assertTrue( + dkim_time_nfa.accepts_input( + "dkim-signature:v=1; a=rsa-sha256; s=selector; t=1623456789;" + ) + ) + self.assertFalse( + dkim_time_nfa.accepts_input("dkim-signature:v=1; t=;") + ) # empty timestamp + + # Pattern for Message-ID header + msgid_nfa = NFA.from_regex( + r"message-id:<[A-Za-z0-9=@\.\+_-]+>\r\n", input_symbols=input_symbols + ) + self.assertTrue(msgid_nfa.accepts_input("message-id:<123abc@example.com>\r\n")) + self.assertTrue( + msgid_nfa.accepts_input("message-id:\r\n") + ) + self.assertFalse( + msgid_nfa.accepts_input("message-id:\r\n") + ) # invalid chars + self.assertFalse( + msgid_nfa.accepts_input("message-id:") + ) # missing newline + + def test_repeating_group_with_space(self) -> None: + """Test a simpler version of the DKIM signature pattern to isolate the issue""" + input_symbols = set(string.printable) + + # Try another variation without the space in the pattern + no_space = NFA.from_regex(r"([a-z]+=[^;]+;)+", input_symbols=input_symbols) + self.assertTrue(no_space.accepts_input("v=1;")) + self.assertTrue(no_space.accepts_input("v=1;a=2;")) + + # Test with explicit space character instead of relying on character class + explicit_space = NFA.from_regex( + r"([a-z]+=[^;]+; )+", input_symbols=input_symbols + ) + self.assertTrue(explicit_space.accepts_input("v=1; ")) + + # Simplified version of the problematic pattern + simple_repeat = NFA.from_regex( + r"([a-z]+=[^;]+; )+", input_symbols=input_symbols + ) + self.assertTrue(simple_repeat.accepts_input("v=1; ")) + self.assertTrue(simple_repeat.accepts_input("v=1; a=2; ")) + + # Test the full pattern but simplified + full_simple = NFA.from_regex( + r"header:([a-z]+=[^;]+; )+value;", input_symbols=input_symbols + ) + self.assertTrue(full_simple.accepts_input("header:v=1; value;")) + self.assertTrue(full_simple.accepts_input("header:v=1; a=2; value;")) + + def test_space_in_patterns(self) -> None: + """Test different patterns with spaces to isolate the issue""" + input_symbols = set(string.printable) + + # Test 1: Basic pattern with space at the end + basic = NFA.from_regex(r"a ", input_symbols=input_symbols) + self.assertTrue(basic.accepts_input("a ")) + + # Test 2: Character class with space + with_class = NFA.from_regex(r"a[b ]", input_symbols=input_symbols) + self.assertTrue(with_class.accepts_input("a ")) + self.assertTrue(with_class.accepts_input("ab")) + + # Test 3: Simple repetition with space + simple_repeat = NFA.from_regex(r"(a )+", input_symbols=input_symbols) + self.assertTrue(simple_repeat.accepts_input("a ")) + self.assertTrue(simple_repeat.accepts_input("a a ")) + + # Test 4: Specific repeating pattern without the semicolon + no_semicolon = NFA.from_regex(r"([a-z]+=. 
)+", input_symbols=input_symbols) + self.assertTrue(no_semicolon.accepts_input("v=1 ")) + self.assertTrue(no_semicolon.accepts_input("v=1 a=2 ")) + + # Test 5: With semicolon but space before + space_before = NFA.from_regex(r"([a-z]+=[^;]+ ;)+", input_symbols=input_symbols) + self.assertTrue(space_before.accepts_input("v=1 ;")) + self.assertTrue(space_before.accepts_input("v=1 ;a=2 ;")) + + # Test 6: Space as part of negated class + space_in_neg = NFA.from_regex(r"([a-z]+=[^; ]+;)+", input_symbols=input_symbols) + self.assertTrue(space_in_neg.accepts_input("v=1;")) + + # Test 7: Bare minimum to reproduce + minimal = NFA.from_regex(r"(a; )+", input_symbols=input_symbols) + self.assertTrue(minimal.accepts_input("a; ")) + self.assertTrue(minimal.accepts_input("a; a; "))