Skip to content

Add support for character classes [...] #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: develop
Choose a base branch
from
161 changes: 153 additions & 8 deletions automata/fa/nfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import re
from collections import deque
from itertools import chain, count, product, repeat
from typing import (
Expand Down Expand Up @@ -216,21 +217,165 @@ def from_regex(
Self
The NFA accepting the language of the input regex.
"""
# Import the shorthand character classes
from automata.regex.parser import (
DIGIT_CHARS,
NON_DIGIT_CHARS,
NON_WHITESPACE_CHARS,
NON_WORD_CHARS,
WHITESPACE_CHARS,
WORD_CHARS,
)
Comment on lines +221 to +228
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos Can you please keep all imports at the top of the file? There's no particular need for the tighter scoping here, IMO.

cc @eliotwrobson

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think ruff will complain about the imports.


# Create a set for additional symbols from shorthand classes
additional_symbols = set()

# Check for shorthand classes in the regex
if "\\s" in regex:
additional_symbols.update(WHITESPACE_CHARS)
if "\\S" in regex:
additional_symbols.update(NON_WHITESPACE_CHARS)
if "\\d" in regex:
additional_symbols.update(DIGIT_CHARS)
if "\\D" in regex:
additional_symbols.update(NON_DIGIT_CHARS)
if "\\w" in regex:
additional_symbols.update(WORD_CHARS)
if "\\W" in regex:
additional_symbols.update(NON_WORD_CHARS)
Comment on lines +234 to +245
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos Can you please refactor this to use a dict-based lookup table? That would make this much less repetitive.

cc @eliotwrobson

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this could make things much cleaner, especially since this can be done in a loop 👍🏽


# Extract escaped sequences from the regex
escape_chars = set()
i = 0
while i < len(regex):
if regex[i] == "\\" and i + 1 < len(regex):
from automata.regex.parser import _handle_escape_sequences

# Skip shorthand classes
if regex[i + 1] in "sSwWdD":
i += 2
continue

escaped_char = _handle_escape_sequences(regex[i + 1])
escape_chars.add(escaped_char)
i += 2
else:
i += 1

class_symbols = set()
range_pattern = re.compile(r"\[([^\]]*)\]")
for match in range_pattern.finditer(regex):
class_content = match.group(1)
pos = 0
while pos < len(class_content):
if class_content[pos] == "\\" and pos + 1 < len(class_content):
# Check for shorthand classes in character classes
if class_content[pos + 1] == "s":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might be able to use a lookup table from a dictionary here? Just use the character as a key and a tuple of additional symbols and position increment as the value.

additional_symbols.update(WHITESPACE_CHARS)
pos += 2
continue
elif class_content[pos + 1] == "d":
additional_symbols.update(DIGIT_CHARS)
pos += 2
continue
elif class_content[pos + 1] == "w":
additional_symbols.update(WORD_CHARS)
pos += 2
continue
elif class_content[pos + 1] in "S":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos What is the intention of using `in` here as opposed to `==`? If the right-hand side is just a single character, the only difference that seems to make is permitting `class_content[pos + 1]` to be the empty string (in addition to the character itself). In other words:

"S" in "S"  # True
"" in "S"  # True

additional_symbols.update(NON_WHITESPACE_CHARS)
pos += 2
continue
elif class_content[pos + 1] in "D":
additional_symbols.update(NON_DIGIT_CHARS)
pos += 2
continue
elif class_content[pos + 1] in "W":
additional_symbols.update(NON_WORD_CHARS)
pos += 2
continue

# Handle escape sequence in character class
from automata.regex.parser import _handle_escape_sequences
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@StefanosChaliasos Can you also please move this import to the top of the file?


escaped_char = _handle_escape_sequences(class_content[pos + 1])
class_symbols.add(escaped_char)

# Check if this is part of a range
if (
pos + 2 < len(class_content)
and class_content[pos + 2] == "-"
and pos + 3 < len(class_content)
):
# Handle range with escaped start character
start_char = escaped_char

# Check if end character is also escaped
if class_content[pos + 3] == "\\" and pos + 4 < len(
class_content
):
end_char = _handle_escape_sequences(class_content[pos + 4])
pos += 5
else:
end_char = class_content[pos + 3]
pos += 4

# Add all characters in the range to input symbols
for i in range(ord(start_char), ord(end_char) + 1):
class_symbols.add(chr(i))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be done with a single Python `set.update()` call instead of a loop.

continue

pos += 2
elif pos + 2 < len(class_content) and class_content[pos + 1] == "-":
# Handle normal range
start_char = class_content[pos]

if class_content[pos + 2] == "\\" and pos + 3 < len(class_content):
end_char = _handle_escape_sequences(class_content[pos + 3])
pos += 4
else:
end_char = class_content[pos + 2]
pos += 3

for i in range(ord(start_char), ord(end_char) + 1):
class_symbols.add(chr(i))
else:
if class_content[pos] != "^": # Skip negation symbol
class_symbols.add(class_content[pos])
pos += 1

# Set up the final input symbols
if input_symbols is None:
input_symbols = frozenset(regex) - RESERVED_CHARACTERS
# If no input_symbols provided, collect all non-reserved chars from regex
input_symbols_set = set()
for char in regex:
if char not in RESERVED_CHARACTERS:
input_symbols_set.add(char)

# Include all character class symbols and escape sequences
input_symbols_set.update(class_symbols)
input_symbols_set.update(escape_chars)

# Add the shorthand characters
input_symbols_set.update(additional_symbols)

final_input_symbols = frozenset(input_symbols_set)
else:
conflicting_symbols = RESERVED_CHARACTERS & input_symbols
if conflicting_symbols:
raise exceptions.InvalidSymbolError(
f"Invalid input symbols: {conflicting_symbols}"
)
# For user-provided input_symbols, we need to update
# with character class symbols and escape sequences
final_input_symbols = (
frozenset(input_symbols)
.union(class_symbols)
.union(escape_chars)
.union(additional_symbols)
)

nfa_builder = parse_regex(regex, input_symbols)
# Build the NFA
nfa_builder = parse_regex(regex, final_input_symbols)

return cls(
states=frozenset(nfa_builder._transitions.keys()),
input_symbols=input_symbols,
input_symbols=final_input_symbols,
transitions=nfa_builder._transitions,
initial_state=nfa_builder._initial_state,
final_states=nfa_builder._final_states,
Expand Down
Loading
Loading