Skip to content

Shuffle regex #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Dec 6, 2022
76 changes: 59 additions & 17 deletions automata/regex/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
parse_postfix_tokens, tokens_to_postfix,
validate_tokens)

RESERVED_CHARACTERS = frozenset({'*', '|', '(', ')', '?', ' ', '\t', '&', '+', '.'})
RESERVED_CHARACTERS = frozenset({'*', '|', '(', ')', '?', ' ', '\t', '&', '+', '.', '^'})


class NFARegexBuilder:
Expand Down Expand Up @@ -44,7 +44,7 @@ def from_string_literal(cls, literal):
end_states.add(start_state+1)

final_state = cls.__get_next_state_name()
transitions[final_state] = dict()
transitions[final_state] = {}

return cls(
transitions=transitions,
Expand All @@ -63,7 +63,7 @@ def wildcard(cls, input_symbols):

transitions = {
initial_state: {symbol: {final_state} for symbol in input_symbols},
final_state: dict()
final_state: {}
}

return cls(
Expand Down Expand Up @@ -93,25 +93,24 @@ def intersection(self, other):
Apply the intersection operation to the NFA represented by this builder and other.
Use BFS to only traverse reachable part (keeps number of states down).
"""
new_state_name_dict = dict()
new_state_name_dict = {}

def get_state_name(state_name):
return new_state_name_dict.setdefault(state_name, self.__get_next_state_name())

new_final_states = set()
new_transitions = dict()
new_transitions = {}
new_initial_state = (self._initial_state, other._initial_state)

new_initial_state_name = get_state_name(new_initial_state)
new_input_symbols = set(chain.from_iterable(
transition_dict.keys()
for transition_dict in chain(self._transitions.values(), other._transitions.values())
)) - {''}
new_input_symbols = tuple(set(chain.from_iterable(
map(dict.keys, chain(self._transitions.values(), other._transitions.values()))
)) - {''})

queue = deque()

queue.append(new_initial_state)
new_transitions[new_initial_state_name] = dict()
new_transitions[new_initial_state_name] = {}

while queue:
curr_state = queue.popleft()
Expand All @@ -129,9 +128,9 @@ def get_state_name(state_name):
# Add epsilon transitions for first set of transitions
epsilon_transitions_a = transitions_a.get('')
if epsilon_transitions_a is not None:
state_dict = new_transitions.setdefault(curr_state_name, dict())
state_dict = new_transitions.setdefault(curr_state_name, {})
state_dict.setdefault('', set()).update(
get_state_name(state) for state in product(epsilon_transitions_a, [q_b])
map(get_state_name, product(epsilon_transitions_a, [q_b]))
)
next_states_iterables.append(product(epsilon_transitions_a, [q_b]))

Expand All @@ -140,9 +139,9 @@ def get_state_name(state_name):
# Add epsilon transitions for second set of transitions
epsilon_transitions_b = transitions_b.get('')
if epsilon_transitions_b is not None:
state_dict = new_transitions.setdefault(curr_state_name, dict())
state_dict = new_transitions.setdefault(curr_state_name, {})
state_dict.setdefault('', set()).update(
get_state_name(state) for state in product([q_a], epsilon_transitions_b)
map(get_state_name, product([q_a], epsilon_transitions_b))
)
next_states_iterables.append(product([q_a], epsilon_transitions_b))

Expand All @@ -152,17 +151,17 @@ def get_state_name(state_name):
end_states_b = transitions_b.get(symbol)

if end_states_a is not None and end_states_b is not None:
state_dict = new_transitions.setdefault(curr_state_name, dict())
state_dict = new_transitions.setdefault(curr_state_name, {})
state_dict.setdefault(symbol, set()).update(
get_state_name(state) for state in product(end_states_a, end_states_b)
map(get_state_name, product(end_states_a, end_states_b))
)
next_states_iterables.append(product(end_states_a, end_states_b))

# Finally, try visiting every state we found.
for product_state in chain.from_iterable(next_states_iterables):
product_state_name = get_state_name(product_state)
if product_state_name not in new_transitions:
new_transitions[product_state_name] = dict()
new_transitions[product_state_name] = {}
queue.append(product_state)

self._final_states = new_final_states
Expand Down Expand Up @@ -216,6 +215,37 @@ def option(self):
self._initial_state = new_initial_state
self._final_states.add(new_initial_state)

def shuffle_product(self, other):
    """
    Apply the shuffle operation to the NFA represented by this builder and other.
    No need for BFS since all states are accessible.

    The builder is updated in place: its initial state, final states and
    transitions are replaced, and nothing is returned. Every new state
    stands for a pair (state of self, state of other); each transition out
    of a pair advances exactly one of the two components.
    """
    new_state_name_dict = {}

    def get_state_name(state_name):
        # Request a fresh name only the first time a pair is seen. Passing
        # the call as the default of setdefault() would evaluate it on every
        # lookup and use up a new name even for pairs that already have one.
        if state_name not in new_state_name_dict:
            new_state_name_dict[state_name] = self.__get_next_state_name()
        return new_state_name_dict[state_name]

    self._initial_state = get_state_name((self._initial_state, other._initial_state))

    new_transitions = {}

    transition_product = product(self._transitions.items(), other._transitions.items())
    for (q_a, transitions_a), (q_b, transitions_b) in transition_product:
        state_dict = new_transitions.setdefault(get_state_name((q_a, q_b)), {})

        # Moves taken from self: the first component advances, q_b is kept.
        for symbol, end_states in transitions_a.items():
            state_dict.setdefault(symbol, set()).update(
                map(get_state_name, product(end_states, [q_b]))
            )

        # Moves taken from other: the second component advances, q_a is kept.
        for symbol, end_states in transitions_b.items():
            state_dict.setdefault(symbol, set()).update(
                map(get_state_name, product([q_a], end_states))
            )

    # A pair is final only when both of its components are final.
    self._final_states = set(map(get_state_name, product(self._final_states, other._final_states)))
    self._transitions = new_transitions

@classmethod
def __get_next_state_name(cls):
    """Return the next value produced by the class's _state_name_counter."""
    fresh_name = next(cls._state_name_counter)
    return fresh_name
Expand Down Expand Up @@ -243,6 +273,17 @@ def op(self, left, right):
return left


class ShuffleToken(InfixOperator):
    """Infix operator token for the shuffle operator."""

    def get_precedence(self):
        """Return the precedence value of this operator, which is 1."""
        return 1

    def op(self, left, right):
        """Apply shuffle_product to left with right as argument, then return left."""
        left.shuffle_product(right)
        return left


class KleeneStarToken(PostfixOperator):
"""Subclass of postfix operator defining the kleene star operator."""

Expand Down Expand Up @@ -340,6 +381,7 @@ def get_regex_lexer(input_symbols):
lexer.register_token(StringToken, r'[A-Za-z0-9]')
lexer.register_token(UnionToken, r'\|')
lexer.register_token(IntersectionToken, r'\&')
lexer.register_token(ShuffleToken, r'\^')
lexer.register_token(KleeneStarToken, r'\*')
lexer.register_token(KleenePlusToken, r'\+')
lexer.register_token(OptionToken, r'\?')
Expand Down
1 change: 1 addition & 0 deletions docs/regular-expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ A regular expression with the following operations only are supported in this li
- `|`: Union. Ex: `a|b`
- `&`: Intersection. Ex: `a&b`
- `.`: Wildcard. Ex: `a.b`
- `^`: Shuffle. Ex: `a^b`
- `()`: Grouping.

This is similar to the python RE module but this library does not support any other
Expand Down
15 changes: 15 additions & 0 deletions tests/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,21 @@ def test_wildcard(self):
self.assertTrue(re.issubset('a.b', '...', input_symbols=input_symbols))
self.assertTrue(re.issuperset('.', 'a|b', input_symbols=input_symbols))

def test_shuffle(self):
    """Should correctly check shuffle"""

    alphabet = {'a', 'b', 'c', 'd'}

    # Each pair is (expression using ^, equivalent expression without it
    # or with the shuffle partially expanded).
    equivalent_pairs = (
        ('a^b', 'ab|ba'),
        ('ab^cd', 'abcd | acbd | cabd | acdb | cadb | cdab'),
        ('(a*)^(b*)^(c*)^(d*)', '.*'),
        ('ca^db', '(c^db)a | (ca^d)b'),
        ('a^(b|c)', 'ab | ac | ba | ca'),
    )
    for shuffled, expanded in equivalent_pairs:
        self.assertTrue(re.isequal(shuffled, expanded, input_symbols=alphabet))

    # The regex operator and the NFA-level shuffle product must agree.
    star_nfa = NFA.from_regex('a*')
    word_nfa = NFA.from_regex('ba')
    reference_nfa = NFA.from_regex('a*^ba')
    other_nfa = NFA.shuffle_product(star_nfa, word_nfa)
    self.assertEqual(reference_nfa, other_nfa)

def test_invalid_symbols(self):
"""Should throw exception if reserved character is in input symbols"""
with self.assertRaises(exceptions.InvalidSymbolError):
Expand Down