@@ -114,6 +114,25 @@ struct IntermediateOutputASTFragments {
114
114
inline_wdf_fn_invocation : ExprCall ,
115
115
}
116
116
117
+ /// Struct to represent a file lock guard. This struct enforces RAII, ensuring
118
+ /// that the file lock is released when the guard goes out of scope.
119
+ struct FileLockGuard {
120
+ file : std:: fs:: File ,
121
+ }
122
+
123
+ impl FileLockGuard {
124
+ fn new ( file : std:: fs:: File , span : Span ) -> Result < Self > {
125
+ FileExt :: lock_exclusive ( & file) . to_syn_result ( span, "unable to obtain file lock" ) ?;
126
+ Ok ( Self { file } )
127
+ }
128
+ }
129
+
130
+ impl Drop for FileLockGuard {
131
+ fn drop ( & mut self ) {
132
+ let _ = FileExt :: unlock ( & self . file ) ;
133
+ }
134
+ }
135
+
117
136
impl StringExt for String {
118
137
fn to_snake_case ( & self ) -> String {
119
138
// There will be, at max, 2 characters unhandled by the 3-char windows. It is
@@ -406,42 +425,60 @@ fn get_wdf_function_info_map(
406
425
types_path : & LitStr ,
407
426
span : Span ,
408
427
) -> Result < BTreeMap < String , CachedFunctionInfo > > {
409
- let scratch_dir = scratch:: path ( concat ! ( env!( "CARGO_CRATE_NAME" ) , "_ast_fragments" ) ) ;
410
- let flock = std:: fs:: File :: create ( scratch_dir. join ( ".lock" ) )
411
- . to_syn_result ( span, "unable to create file" ) ?;
428
+ cfg_if:: cfg_if! {
429
+ if #[ cfg( test) ] {
430
+ let scratch_dir = scratch:: path( concat!( env!( "CARGO_CRATE_NAME" ) , "_ast_fragments_test" ) ) ;
431
+ } else {
432
+ let scratch_dir = scratch:: path( concat!( env!( "CARGO_CRATE_NAME" ) , "_ast_fragments" ) ) ;
433
+ }
434
+ }
412
435
413
436
let cached_function_info_map_path = scratch_dir. join ( "cached_function_info_map.json" ) ;
414
437
415
438
if !cached_function_info_map_path. exists ( ) {
416
- FileExt :: lock_exclusive ( & flock) . to_syn_result ( span, "unable to obtain file lock" ) ?;
439
+ let flock = std:: fs:: File :: create ( scratch_dir. join ( ".lock" ) )
440
+ . to_syn_result ( span, "unable to create file lock" ) ?;
441
+
442
+ // When _flock_guard goes out of scope, the file lock is released
443
+ let _flock_guard = FileLockGuard :: new ( flock, span)
444
+ . to_syn_result ( span, "unable to create file lock guard" ) ?;
417
445
418
446
// Before this thread acquires the lock, it's possible that a concurrent thread
419
447
// already created the cache. If so, this thread skips cache generation.
420
448
if !cached_function_info_map_path. exists ( ) {
421
- let generated_map = create_wdf_function_info_file_cache (
449
+ let function_info_map = create_wdf_function_info_file_cache (
422
450
types_path,
423
- & cached_function_info_map_path,
451
+ cached_function_info_map_path. as_path ( ) ,
424
452
span,
425
453
) ?;
426
-
427
- FileExt :: unlock ( & flock) . to_syn_result ( span, "unable to unlock file lock" ) ?;
428
- return Ok ( generated_map) ;
454
+ return Ok ( function_info_map) ;
429
455
}
430
- FileExt :: unlock ( & flock) . to_syn_result ( span, "unable to unlock file lock" ) ?;
431
456
}
457
+ let function_info_map =
458
+ read_wdf_function_info_file_cache ( cached_function_info_map_path. as_path ( ) , span) ?;
459
+ Ok ( function_info_map)
460
+ }
432
461
433
- let generated_map_string = std:: fs:: read_to_string ( & cached_function_info_map_path)
462
+ /// Reads the cache of function information, then deserializes it into a
463
+ /// `BTreeMap`.
464
+ fn read_wdf_function_info_file_cache (
465
+ cached_function_info_map_path : & std:: path:: Path ,
466
+ span : Span ,
467
+ ) -> Result < BTreeMap < String , CachedFunctionInfo > > {
468
+ let generated_map_string = std:: fs:: read_to_string ( cached_function_info_map_path)
434
469
. to_syn_result ( span, "unable to read cache to string" ) ?;
435
470
let map: BTreeMap < String , CachedFunctionInfo > = serde_json:: from_str ( & generated_map_string)
436
471
. to_syn_result ( span, "unable to parse cache to BTreeMap" ) ?;
437
472
Ok ( map)
438
473
}
439
474
440
- /// This function generates the cache of function information, then
475
+ /// Generates the cache of function information, then
441
476
/// serializes it into a JSON string and writes it to a designated location.
477
+ /// Must obtain an exclusive file lock prior to calling this function to prevent
478
+ /// concurrent threads from reading and writing to the same file.
442
479
fn create_wdf_function_info_file_cache (
443
480
types_path : & LitStr ,
444
- cached_function_info_map_path : & PathBuf ,
481
+ cached_function_info_map_path : & std :: path :: Path ,
445
482
span : Span ,
446
483
) -> Result < BTreeMap < String , CachedFunctionInfo > > {
447
484
let generated_map = generate_wdf_function_info_file_cache ( types_path, span) ?;
@@ -452,7 +489,7 @@ fn create_wdf_function_info_file_cache(
452
489
Ok ( generated_map)
453
490
}
454
491
455
- /// Parse file from `types_path` to generate a `BTreeMap` of
492
+ /// Parses file from `types_path` to generate a `BTreeMap` of
456
493
/// function information, where `key` is the function name and `value` is
457
494
/// the cached function table information.
458
495
fn generate_wdf_function_info_file_cache (
@@ -930,29 +967,31 @@ mod tests {
930
967
use super :: * ;
931
968
932
969
static SCRATCH_DIR : LazyLock < PathBuf > =
933
- LazyLock :: new ( || scratch:: path ( concat ! ( env!( "CARGO_CRATE_NAME" ) , "_ast_fragments " ) ) ) ;
970
+ LazyLock :: new ( || scratch:: path ( concat ! ( env!( "CARGO_CRATE_NAME" ) , "_ast_fragments_test " ) ) ) ;
934
971
const CACHE_FILE_NAME : & str = "cached_function_info_map.json" ;
935
972
936
- fn with_file_lock < F > ( f : F )
973
+ fn with_file_lock_clean_env < F > ( f : F )
937
974
where
938
975
F : FnOnce ( ) ,
939
976
{
940
- // test flock has to be different than the cache flock since `f` can call
941
- // functions that obtain the cache flock
942
- let test_flock = std:: fs:: File :: create ( SCRATCH_DIR . join ( "test.lock" ) ) . unwrap ( ) ;
943
- let cached_function_info_map_path = SCRATCH_DIR . join ( CACHE_FILE_NAME ) ;
944
-
977
+ let test_flock: std:: fs:: File =
978
+ std:: fs:: File :: create ( SCRATCH_DIR . join ( "test.lock" ) ) . unwrap ( ) ;
945
979
FileExt :: lock_exclusive ( & test_flock) . unwrap ( ) ;
946
980
947
- // make sure environment is clean
948
- pretty_assert_eq ! ( cached_function_info_map_path. exists( ) , false ) ;
981
+ let cached_function_info_map_path = SCRATCH_DIR . join ( CACHE_FILE_NAME ) ;
982
+
983
+ pretty_assert_eq ! (
984
+ cached_function_info_map_path. exists( ) ,
985
+ false ,
986
+ "could not remove file {}" ,
987
+ cached_function_info_map_path. display( )
988
+ ) ;
949
989
950
990
f ( ) ;
951
991
952
992
if cached_function_info_map_path. exists ( ) {
953
- std:: fs:: remove_file ( & cached_function_info_map_path) . unwrap ( ) ;
993
+ std:: fs:: remove_file ( cached_function_info_map_path) . unwrap ( ) ;
954
994
}
955
- pretty_assert_eq ! ( cached_function_info_map_path. exists( ) , false ) ;
956
995
957
996
FileExt :: unlock ( & test_flock) . unwrap ( ) ;
958
997
}
@@ -1126,7 +1165,7 @@ mod tests {
1126
1165
1127
1166
#[ test]
1128
1167
fn valid_input ( ) {
1129
- with_file_lock ( || {
1168
+ with_file_lock_clean_env ( || {
1130
1169
let inputs = Inputs {
1131
1170
types_path : parse_quote ! { "tests/unit-tests-input/generated-types.rs" } ,
1132
1171
wdf_function_identifier : format_ident ! ( "WdfDriverCreate" ) ,
@@ -1172,7 +1211,7 @@ mod tests {
1172
1211
1173
1212
#[ test]
1174
1213
fn valid_input_with_no_arguments ( ) {
1175
- with_file_lock ( || {
1214
+ with_file_lock_clean_env ( || {
1176
1215
let inputs = Inputs {
1177
1216
types_path : parse_quote ! { "tests/unit-tests-input/generated-types.rs" } ,
1178
1217
wdf_function_identifier : format_ident ! ( "WdfVerifierDbgBreakPoint" ) ,
@@ -1199,7 +1238,7 @@ mod tests {
1199
1238
1200
1239
#[ test]
1201
1240
fn valid_input_no_cache ( ) {
1202
- with_file_lock ( || {
1241
+ with_file_lock_clean_env ( || {
1203
1242
let inputs = Inputs {
1204
1243
types_path : parse_quote ! { "tests/unit-tests-input/generated-types.rs" } ,
1205
1244
wdf_function_identifier : format_ident ! ( "WdfVerifierDbgBreakPoint" ) ,
@@ -1241,7 +1280,7 @@ mod tests {
1241
1280
1242
1281
#[ test]
1243
1282
fn valid_input_cache_exists ( ) {
1244
- with_file_lock ( || {
1283
+ with_file_lock_clean_env ( || {
1245
1284
let inputs = Inputs {
1246
1285
types_path : parse_quote ! { "tests/unit-tests-input/generated-types.rs" } ,
1247
1286
wdf_function_identifier : format_ident ! ( "WdfVerifierDbgBreakPoint" ) ,
0 commit comments