Skip to content

Commit 14b4dfe

Browse files
authored
Shuffle regex (#112)
1 parent d8088ad commit 14b4dfe

File tree

3 files changed

+75
-17
lines changed

3 files changed

+75
-17
lines changed

automata/regex/parser.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
parse_postfix_tokens, tokens_to_postfix,
1111
validate_tokens)
1212

13-
RESERVED_CHARACTERS = frozenset({'*', '|', '(', ')', '?', ' ', '\t', '&', '+', '.'})
13+
RESERVED_CHARACTERS = frozenset({'*', '|', '(', ')', '?', ' ', '\t', '&', '+', '.', '^'})
1414

1515

1616
class NFARegexBuilder:
@@ -44,7 +44,7 @@ def from_string_literal(cls, literal):
4444
end_states.add(start_state+1)
4545

4646
final_state = cls.__get_next_state_name()
47-
transitions[final_state] = dict()
47+
transitions[final_state] = {}
4848

4949
return cls(
5050
transitions=transitions,
@@ -63,7 +63,7 @@ def wildcard(cls, input_symbols):
6363

6464
transitions = {
6565
initial_state: {symbol: {final_state} for symbol in input_symbols},
66-
final_state: dict()
66+
final_state: {}
6767
}
6868

6969
return cls(
@@ -93,25 +93,24 @@ def intersection(self, other):
9393
Apply the intersection operation to the NFA represented by this builder and other.
9494
Use BFS to only traverse reachable part (keeps number of states down).
9595
"""
96-
new_state_name_dict = dict()
96+
new_state_name_dict = {}
9797

9898
def get_state_name(state_name):
9999
return new_state_name_dict.setdefault(state_name, self.__get_next_state_name())
100100

101101
new_final_states = set()
102-
new_transitions = dict()
102+
new_transitions = {}
103103
new_initial_state = (self._initial_state, other._initial_state)
104104

105105
new_initial_state_name = get_state_name(new_initial_state)
106-
new_input_symbols = set(chain.from_iterable(
107-
transition_dict.keys()
108-
for transition_dict in chain(self._transitions.values(), other._transitions.values())
109-
)) - {''}
106+
new_input_symbols = tuple(set(chain.from_iterable(
107+
map(dict.keys, chain(self._transitions.values(), other._transitions.values()))
108+
)) - {''})
110109

111110
queue = deque()
112111

113112
queue.append(new_initial_state)
114-
new_transitions[new_initial_state_name] = dict()
113+
new_transitions[new_initial_state_name] = {}
115114

116115
while queue:
117116
curr_state = queue.popleft()
@@ -129,9 +128,9 @@ def get_state_name(state_name):
129128
# Add epsilon transitions for first set of transitions
130129
epsilon_transitions_a = transitions_a.get('')
131130
if epsilon_transitions_a is not None:
132-
state_dict = new_transitions.setdefault(curr_state_name, dict())
131+
state_dict = new_transitions.setdefault(curr_state_name, {})
133132
state_dict.setdefault('', set()).update(
134-
get_state_name(state) for state in product(epsilon_transitions_a, [q_b])
133+
map(get_state_name, product(epsilon_transitions_a, [q_b]))
135134
)
136135
next_states_iterables.append(product(epsilon_transitions_a, [q_b]))
137136

@@ -140,9 +139,9 @@ def get_state_name(state_name):
140139
# Add epsilon transitions for second set of transitions
141140
epsilon_transitions_b = transitions_b.get('')
142141
if epsilon_transitions_b is not None:
143-
state_dict = new_transitions.setdefault(curr_state_name, dict())
142+
state_dict = new_transitions.setdefault(curr_state_name, {})
144143
state_dict.setdefault('', set()).update(
145-
get_state_name(state) for state in product([q_a], epsilon_transitions_b)
144+
map(get_state_name, product([q_a], epsilon_transitions_b))
146145
)
147146
next_states_iterables.append(product([q_a], epsilon_transitions_b))
148147

@@ -152,17 +151,17 @@ def get_state_name(state_name):
152151
end_states_b = transitions_b.get(symbol)
153152

154153
if end_states_a is not None and end_states_b is not None:
155-
state_dict = new_transitions.setdefault(curr_state_name, dict())
154+
state_dict = new_transitions.setdefault(curr_state_name, {})
156155
state_dict.setdefault(symbol, set()).update(
157-
get_state_name(state) for state in product(end_states_a, end_states_b)
156+
map(get_state_name, product(end_states_a, end_states_b))
158157
)
159158
next_states_iterables.append(product(end_states_a, end_states_b))
160159

161160
# Finally, try visiting every state we found.
162161
for product_state in chain.from_iterable(next_states_iterables):
163162
product_state_name = get_state_name(product_state)
164163
if product_state_name not in new_transitions:
165-
new_transitions[product_state_name] = dict()
164+
new_transitions[product_state_name] = {}
166165
queue.append(product_state)
167166

168167
self._final_states = new_final_states
@@ -216,6 +215,37 @@ def option(self):
216215
self._initial_state = new_initial_state
217216
self._final_states.add(new_initial_state)
218217

218+
def shuffle_product(self, other):
    """
    Apply the shuffle operation to the NFA represented by this builder and other.

    Every pair (q_a, q_b) made of a key of this builder's transition dict and
    a key of other's transition dict becomes one state of the result. From
    such a pair, any transition of q_a may be taken while q_b stays fixed,
    and any transition of q_b may be taken while q_a stays fixed. The full
    product is built directly; no BFS over reachable pairs is performed.

    The result is stored in this builder: its initial state, final states and
    transitions are replaced. other is only read. Returns None.
    """
    new_state_name_dict = {}

    def get_state_name(state_name):
        # Look the pair up first and only draw a fresh name on a miss.
        # Passing self.__get_next_state_name() as the default argument of
        # dict.setdefault() would evaluate it on every call and throw away
        # one counter value per lookup of an already-named pair.
        try:
            return new_state_name_dict[state_name]
        except KeyError:
            fresh_name = self.__get_next_state_name()
            new_state_name_dict[state_name] = fresh_name
            return fresh_name

    self._initial_state = get_state_name((self._initial_state, other._initial_state))

    new_transitions = {}

    transition_product = product(self._transitions.items(), other._transitions.items())
    for (q_a, transitions_a), (q_b, transitions_b) in transition_product:
        state_dict = new_transitions.setdefault(get_state_name((q_a, q_b)), {})

        # Advance the first operand, keep the second one where it is.
        for symbol, end_states in transitions_a.items():
            state_dict.setdefault(symbol, set()).update(
                get_state_name((end_state, q_b)) for end_state in end_states
            )

        # Advance the second operand, keep the first one where it is.
        for symbol, end_states in transitions_b.items():
            state_dict.setdefault(symbol, set()).update(
                get_state_name((q_a, end_state)) for end_state in end_states
            )

    # A product state is final only when both of its components are final.
    self._final_states = {
        get_state_name(final_pair)
        for final_pair in product(self._final_states, other._final_states)
    }
    self._transitions = new_transitions
248+
219249
@classmethod
def __get_next_state_name(cls):
    """Draw and return the next value from the class's state name counter."""
    counter = cls._state_name_counter
    return next(counter)
@@ -243,6 +273,17 @@ def op(self, left, right):
243273
return left
244274

245275

276+
class ShuffleToken(InfixOperator):
    """Infix operator token implementing the shuffle operation."""

    def get_precedence(self):
        """Return the precedence level of the shuffle operator."""
        precedence = 1
        return precedence

    def op(self, left, right):
        """Shuffle right into left in place and return left."""
        combined = left
        combined.shuffle_product(right)
        return combined
285+
286+
246287
class KleeneStarToken(PostfixOperator):
247288
"""Subclass of postfix operator defining the kleene star operator."""
248289

@@ -340,6 +381,7 @@ def get_regex_lexer(input_symbols):
340381
lexer.register_token(StringToken, r'[A-Za-z0-9]')
341382
lexer.register_token(UnionToken, r'\|')
342383
lexer.register_token(IntersectionToken, r'\&')
384+
lexer.register_token(ShuffleToken, r'\^')
343385
lexer.register_token(KleeneStarToken, r'\*')
344386
lexer.register_token(KleenePlusToken, r'\+')
345387
lexer.register_token(OptionToken, r'\?')

docs/regular-expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ A regular expression with the following operations only are supported in this li
1515
- `|`: Union. Ex: `a|b`
1616
- `&`: Intersection. Ex: `a&b`
1717
- `.`: Wildcard. Ex: `a.b`
18+
- `^`: Shuffle. Ex: `a^b`
1819
- `()`: Grouping.
1920

2021
This is similar to the python RE module but this library does not support any other

tests/test_regex.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,21 @@ def test_wildcard(self):
116116
self.assertTrue(re.issubset('a.b', '...', input_symbols=input_symbols))
117117
self.assertTrue(re.issuperset('.', 'a|b', input_symbols=input_symbols))
118118

119+
def test_shuffle(self):
    """Should correctly check shuffle"""

    input_symbols = {'a', 'b', 'c', 'd'}

    # Each shuffle expression paired with an equivalent shuffle-free (or
    # simpler) expression; checked in order, stopping at the first failure.
    equivalent_pairs = (
        ('a^b', 'ab|ba'),
        ('ab^cd', 'abcd | acbd | cabd | acdb | cadb | cdab'),
        ('(a*)^(b*)^(c*)^(d*)', '.*'),
        ('ca^db', '(c^db)a | (ca^d)b'),
        ('a^(b|c)', 'ab | ac | ba | ca'),
    )
    for shuffle_regex, expected_regex in equivalent_pairs:
        self.assertTrue(re.isequal(shuffle_regex, expected_regex, input_symbols=input_symbols))

    # The regex operator and the NFA-level shuffle product must agree.
    nfa_from_regex = NFA.from_regex('a*^ba')
    nfa_from_operands = NFA.shuffle_product(NFA.from_regex('a*'), NFA.from_regex('ba'))
    self.assertEqual(nfa_from_regex, nfa_from_operands)
133+
119134
def test_invalid_symbols(self):
120135
"""Should throw exception if reserved character is in input symbols"""
121136
with self.assertRaises(exceptions.InvalidSymbolError):

0 commit comments

Comments
 (0)