diff --git a/src/parse.rs b/src/parse.rs index 5ec0ca0..cf304ac 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -20,6 +20,7 @@ //! A regex parser yielding an AST. +use regex::escape; use bit_set::BitSet; use std::str::FromStr; use std::usize; @@ -199,7 +200,13 @@ impl<'a> Parser<'a> { } )), b'(' => self.parse_group(ix, depth), - b'\\' => self.parse_escape(ix), + b'\\' => { + let (next, expr) = try!(self.parse_escape(ix)); + if let Expr::Backref(group) = expr { + self.backrefs.insert(group); + } + Ok((next, expr)) + }, b'+' | b'*' | b'?' | b'|' | b')' => Ok((ix, Expr::Empty)), b'[' => self.parse_class(ix), @@ -221,7 +228,7 @@ impl<'a> Parser<'a> { } // ix points to \ character - fn parse_escape(&mut self, ix: usize) -> Result<(usize, Expr)> { + fn parse_escape(&self, ix: usize) -> Result<(usize, Expr)> { if ix + 1 == self.re.len() { return Err(Error::TrailingBackslash); } @@ -233,7 +240,6 @@ impl<'a> Parser<'a> { if let Some((end, group)) = parse_decimal(self.re, ix + 1) { // protect BitSet against unreasonably large value if group < self.re.len() / 2 { - self.backrefs.insert(group); return Ok((end, Expr::Backref(group))); } } @@ -331,9 +337,9 @@ impl<'a> Parser<'a> { fn parse_class(&self, ix: usize) -> Result<(usize, Expr)> { let bytes = self.re.as_bytes(); let mut ix = ix + 1; // skip opening '[' - let mut inner = String::new(); + let mut class = String::new(); let mut nest = 1; - inner.push('['); + class.push('['); loop { ix = self.optional_whitespace(ix); if ix == self.re.len() { @@ -344,10 +350,25 @@ impl<'a> Parser<'a> { if ix + 1 == self.re.len() { return Err(Error::InvalidClass); } - ix + 1 + codepoint_len(bytes[ix + 1]) + + // We support more escapes than regex, so parse it ourselves before delegating. + let (end, expr) = try!(self.parse_escape(ix)); + match expr { + Expr::Literal { val, .. } => { + class.push_str(&escape(&val)); + } + Expr::Delegate { inner, .. 
} => { + class.push_str(&inner); + } + _ => { + return Err(Error::InvalidClass); + } + } + end } b'[' => { nest += 1; + class.push('['); ix + 1 } b']' => { @@ -355,16 +376,20 @@ impl<'a> Parser<'a> { if nest == 0 { break; } + class.push(']'); ix + 1 } - b => ix + codepoint_len(b) + b => { + let end = ix + codepoint_len(b); + class.push_str(&self.re[ix..end]); + end + } }; - inner.push_str(&self.re[ix..end]); ix = end; } - inner.push(']'); + class.push(']'); let ix = ix + 1; // skip closing ']' - Ok((ix, Expr::Delegate { inner: inner, size: 1 })) + Ok((ix, Expr::Delegate { inner: class, size: 1 })) } fn parse_group(&mut self, ix: usize, depth: usize) -> Result<(usize, Expr)> { diff --git a/tests/matching.rs b/tests/matching.rs index 5947d94..2691458 100644 --- a/tests/matching.rs +++ b/tests/matching.rs @@ -14,6 +14,25 @@ fn control_character_escapes() { assert_matches(r"\v", "\x0B"); } +#[test] +fn character_class_escapes() { + assert_matches(r"[\[]", "["); + assert_matches(r"[\^]", "^"); + + // The regex crate would reject the following because it's not necessary to escape them. + // Other engines allow escaping any non-alphanumeric character. + assert_matches(r"[\<]", "<"); + assert_matches(r"[\>]", ">"); + assert_matches(r"[\.]", "."); + + // Character class escape + assert_matches(r"[\d]", "1"); + + // Control characters + assert_matches(r"[\e]", "\x1B"); + assert_matches(r"[\n]", "\x0A"); +} + fn assert_matches(re: &str, text: &str) { let parse_result = Regex::new(re);