Skip to content

Commit 54cc1a9

Browse files
committed
Let operator ? only wrap the error in Box if needed
This leads to fewer needless allocations, and makes the operator `?` usable in an no-alloc context.
1 parent e7795bd commit 54cc1a9

File tree

6 files changed

+216
-27
lines changed

6 files changed

+216
-27
lines changed

rinja/src/error.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use alloc::boxed::Box;
33
use core::convert::Infallible;
44
use core::error::Error as StdError;
55
use core::fmt;
6+
use core::marker::PhantomData;
67
#[cfg(feature = "std")]
78
use std::io;
89

@@ -119,7 +120,7 @@ impl From<Box<dyn StdError + Send + Sync>> for Error {
119120
impl From<io::Error> for Error {
120121
#[inline]
121122
fn from(err: io::Error) -> Self {
122-
from_from_io_error(err, MAX_ERROR_UNWRAP_COUNT)
123+
error_from_io_error(err, MAX_ERROR_UNWRAP_COUNT)
123124
}
124125
}
125126

@@ -140,7 +141,7 @@ fn error_from_stderror(err: Box<dyn StdError + Send + Sync>, unwraps: usize) ->
140141
Err(_) => Error::Fmt, // unreachable
141142
},
142143
ErrorKind::Io => match err.downcast() {
143-
Ok(err) => from_from_io_error(*err, unwraps),
144+
Ok(err) => error_from_io_error(*err, unwraps),
144145
Err(_) => Error::Fmt, // unreachable
145146
},
146147
ErrorKind::Rinja => match err.downcast() {
@@ -151,7 +152,7 @@ fn error_from_stderror(err: Box<dyn StdError + Send + Sync>, unwraps: usize) ->
151152
}
152153

153154
#[cfg(feature = "std")]
154-
fn from_from_io_error(err: io::Error, unwraps: usize) -> Error {
155+
fn error_from_io_error(err: io::Error, unwraps: usize) -> Error {
155156
let Some(inner) = err.get_ref() else {
156157
return Error::custom(err);
157158
};
@@ -177,7 +178,7 @@ fn from_from_io_error(err: io::Error, unwraps: usize) -> Error {
177178
None => Error::Fmt, // unreachable
178179
},
179180
ErrorKind::Io => match err.downcast() {
180-
Ok(inner) => from_from_io_error(inner, unwraps),
181+
Ok(inner) => error_from_io_error(inner, unwraps),
181182
Err(_) => Error::Fmt, // unreachable
182183
},
183184
}
@@ -239,3 +240,53 @@ const _: () = {
239240
trait AssertSendSyncStatic: Send + Sync + 'static {}
240241
impl AssertSendSyncStatic for Error {}
241242
};
243+
244+
/// Helper trait to convert a custom `?` call into a [`crate::Result`]
245+
pub trait ResultConverter {
246+
/// Okay Value type of the output
247+
type Value;
248+
/// Input type
249+
type Input;
250+
251+
/// Consume an interior mutable `self`, and turn it into a [`crate::Result`]
252+
fn rinja_conv_result(self, result: Self::Input) -> Result<Self::Value, Error>;
253+
}
254+
255+
/// Helper marker to be used with [`ResultConverter`]
256+
#[derive(Debug, Clone, Copy)]
257+
pub struct ErrorMarker<T>(PhantomData<Result<T>>);
258+
259+
impl<T> ErrorMarker<T> {
260+
/// Get marker for a [`Result`] type
261+
#[inline]
262+
pub fn of(_: &T) -> Self {
263+
Self(PhantomData)
264+
}
265+
}
266+
267+
#[cfg(feature = "alloc")]
268+
impl<T, E> ResultConverter for &ErrorMarker<Result<T, E>>
269+
where
270+
E: Into<Box<dyn StdError + Send + Sync>>,
271+
{
272+
type Value = T;
273+
type Input = Result<T, E>;
274+
275+
#[inline]
276+
fn rinja_conv_result(self, result: Self::Input) -> Result<Self::Value, Error> {
277+
result.map_err(Error::custom)
278+
}
279+
}
280+
281+
impl<T, E> ResultConverter for &&ErrorMarker<Result<T, E>>
282+
where
283+
E: Into<Error>,
284+
{
285+
type Value = T;
286+
type Input = Result<T, E>;
287+
288+
#[inline]
289+
fn rinja_conv_result(self, result: Self::Input) -> Result<Self::Value, Error> {
290+
result.map_err(Into::into)
291+
}
292+
}

rinja/src/helpers.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::iter::{Enumerate, Peekable};
1212
use std::ops::Deref;
1313
use std::pin::Pin;
1414

15+
pub use crate::error::{ErrorMarker, ResultConverter};
1516
use crate::filters::FastWritable;
1617

1718
pub struct TemplateLoop<I>
@@ -267,12 +268,3 @@ impl<L: FastWritable, R: FastWritable> FastWritable for Concat<L, R> {
267268
self.1.write_into(dest)
268269
}
269270
}
270-
271-
#[inline]
272-
#[cfg(feature = "alloc")]
273-
pub fn map_try<T, E>(result: Result<T, E>) -> Result<T, crate::Error>
274-
where
275-
E: Into<alloc::boxed::Box<dyn std::error::Error + Send + Sync>>,
276-
{
277-
result.map_err(crate::Error::custom)
278-
}

rinja_derive/src/generator.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ impl<'a, 'h> Generator<'a, 'h> {
173173
RinjaW: rinja::helpers::core::fmt::Write + ?rinja::helpers::core::marker::Sized\
174174
{\
175175
use rinja::filters::{AutoEscape as _, WriteWritable as _};\
176+
use rinja::helpers::ResultConverter as _;
176177
use rinja::helpers::core::fmt::Write as _;",
177178
);
178179

@@ -1496,16 +1497,9 @@ impl<'a, 'h> Generator<'a, 'h> {
14961497
buf: &mut Buffer,
14971498
expr: &WithSpan<'_, Expr<'_>>,
14981499
) -> Result<DisplayWrap, CompileError> {
1499-
if !cfg!(feature = "alloc") {
1500-
return Err(ctx.generate_error(
1501-
"the `?` operator requires the `alloc` feature to be enabled",
1502-
expr.span(),
1503-
));
1504-
}
1505-
1506-
buf.write("rinja::helpers::map_try(");
1500+
buf.write("match (");
15071501
self.visit_expr(ctx, buf, expr)?;
1508-
buf.write(")?");
1502+
buf.write(") { res => (&&rinja::helpers::ErrorMarker::of(&res)).rinja_conv_result(res)? }");
15091503
Ok(DisplayWrap::Unwrapped)
15101504
}
15111505

rinja_derive/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ fn compare(jinja: &str, expected: &str, fields: &[(&str, &str)], size_hint: usiz
3232
RinjaW: rinja::helpers::core::fmt::Write + ?rinja::helpers::core::marker::Sized,
3333
{
3434
use rinja::filters::{AutoEscape as _, WriteWritable as _};
35+
use rinja::helpers::ResultConverter as _;
3536
use rinja::helpers::core::fmt::Write as _;
3637
#expected
3738
rinja::Result::Ok(())

testing/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ core = { package = "intentionally-empty", version = "1.0.0" }
2323
[dev-dependencies]
2424
rinja = { path = "../rinja", version = "0.3.5", features = ["code-in-doc", "serde_json"] }
2525

26+
assert_matches = "1.5.0"
2627
criterion = "0.5"
2728
phf = { version = "0.11", features = ["macros" ] }
2829
trybuild = "1.0.76"

testing/tests/try.rs

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::{fmt, io};
2+
3+
use assert_matches::assert_matches;
14
use rinja::Template;
25

36
#[test]
@@ -15,7 +18,7 @@ fn test_int_parser() {
1518
}
1619

1720
let template = IntParserTemplate { s: "💯" };
18-
assert!(matches!(template.render(), Err(rinja::Error::Custom(_))));
21+
assert_matches!(template.render(), Err(rinja::Error::Custom(_)));
1922
assert_eq!(
2023
format!("{}", &template.render().unwrap_err()),
2124
"invalid digit found in string"
@@ -34,17 +37,17 @@ fn fail_fmt() {
3437
}
3538

3639
impl FailFmt {
37-
fn value(&self) -> Result<&'static str, std::fmt::Error> {
40+
fn value(&self) -> Result<&'static str, fmt::Error> {
3841
if let Some(inner) = self.inner {
3942
Ok(inner)
4043
} else {
41-
Err(std::fmt::Error)
44+
Err(fmt::Error)
4245
}
4346
}
4447
}
4548

4649
let template = FailFmt { inner: None };
47-
assert!(matches!(template.render(), Err(rinja::Error::Custom(_))));
50+
assert_matches!(template.render(), Err(rinja::Error::Fmt));
4851
assert_eq!(
4952
format!("{}", &template.render().unwrap_err()),
5053
format!("{}", std::fmt::Error)
@@ -75,9 +78,156 @@ fn fail_str() {
7578
}
7679

7780
let template = FailStr { value: false };
78-
assert!(matches!(template.render(), Err(rinja::Error::Custom(_))));
81+
assert_matches!(template.render(), Err(rinja::Error::Custom(_)));
7982
assert_eq!(format!("{}", &template.render().unwrap_err()), "FAIL");
8083

8184
let template = FailStr { value: true };
8285
assert_eq!(template.render().unwrap(), "hello world");
8386
}
87+
88+
#[test]
89+
fn error_conversion_from_fmt() {
90+
#[derive(Template)]
91+
#[template(source = "{{ value()? }}", ext = "txt")]
92+
struct ResultTemplate {
93+
succeed: bool,
94+
}
95+
96+
impl ResultTemplate {
97+
fn value(&self) -> Result<&'static str, fmt::Error> {
98+
match self.succeed {
99+
true => Ok("hello"),
100+
false => Err(fmt::Error),
101+
}
102+
}
103+
}
104+
105+
assert_matches!(
106+
ResultTemplate { succeed: true }.render().as_deref(),
107+
Ok("hello")
108+
);
109+
assert_matches!(
110+
ResultTemplate { succeed: false }.render().as_deref(),
111+
Err(rinja::Error::Fmt)
112+
);
113+
}
114+
115+
#[test]
116+
fn error_conversion_from_rinja_custom() {
117+
#[derive(Template)]
118+
#[template(source = "{{ value()? }}", ext = "txt")]
119+
struct ResultTemplate {
120+
succeed: bool,
121+
}
122+
123+
impl ResultTemplate {
124+
fn value(&self) -> Result<&'static str, rinja::Error> {
125+
match self.succeed {
126+
true => Ok("hello"),
127+
false => Err(rinja::Error::custom(CustomError)),
128+
}
129+
}
130+
}
131+
132+
#[derive(Debug)]
133+
struct CustomError;
134+
135+
impl fmt::Display for CustomError {
136+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137+
f.write_str("custom")
138+
}
139+
}
140+
141+
impl std::error::Error for CustomError {}
142+
143+
assert_matches!(
144+
ResultTemplate { succeed: true }.render().as_deref(),
145+
Ok("hello")
146+
);
147+
148+
let err = match (ResultTemplate { succeed: false }.render().unwrap_err()) {
149+
rinja::Error::Custom(err) => err,
150+
err => panic!("Expected Error::Custom(_), got {err:#?}"),
151+
};
152+
assert!(err.is::<CustomError>());
153+
}
154+
155+
#[test]
156+
fn error_conversion_from_custom() {
157+
#[derive(Template)]
158+
#[template(source = "{{ value()? }}", ext = "txt")]
159+
struct ResultTemplate {
160+
succeed: bool,
161+
}
162+
163+
impl ResultTemplate {
164+
fn value(&self) -> Result<&'static str, CustomError> {
165+
match self.succeed {
166+
true => Ok("hello"),
167+
false => Err(CustomError),
168+
}
169+
}
170+
}
171+
172+
#[derive(Debug)]
173+
struct CustomError;
174+
175+
impl fmt::Display for CustomError {
176+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177+
f.write_str("custom")
178+
}
179+
}
180+
181+
impl std::error::Error for CustomError {}
182+
183+
assert_matches!(
184+
ResultTemplate { succeed: true }.render().as_deref(),
185+
Ok("hello")
186+
);
187+
188+
let err = match (ResultTemplate { succeed: false }.render().unwrap_err()) {
189+
rinja::Error::Custom(err) => err,
190+
err => panic!("Expected Error::Custom(_), got {err:#?}"),
191+
};
192+
assert!(err.is::<CustomError>());
193+
}
194+
195+
#[test]
196+
fn error_conversion_from_wrapped_in_io() {
197+
#[derive(Template)]
198+
#[template(source = "{{ value()? }}", ext = "txt")]
199+
struct ResultTemplate {
200+
succeed: bool,
201+
}
202+
203+
impl ResultTemplate {
204+
fn value(&self) -> Result<&'static str, io::Error> {
205+
match self.succeed {
206+
true => Ok("hello"),
207+
false => Err(io::Error::new(io::ErrorKind::InvalidData, CustomError)),
208+
}
209+
}
210+
}
211+
212+
#[derive(Debug)]
213+
struct CustomError;
214+
215+
impl fmt::Display for CustomError {
216+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217+
f.write_str("custom")
218+
}
219+
}
220+
221+
impl std::error::Error for CustomError {}
222+
223+
assert_matches!(
224+
ResultTemplate { succeed: true }.render().as_deref(),
225+
Ok("hello")
226+
);
227+
228+
let err = match (ResultTemplate { succeed: false }.render().unwrap_err()) {
229+
rinja::Error::Custom(err) => err,
230+
err => panic!("Expected Error::Custom(_), got {err:#?}"),
231+
};
232+
assert!(err.is::<CustomError>());
233+
}

0 commit comments

Comments
 (0)