diff --git a/configs/cliff/resnet50_pw3d_cache.py b/configs/cliff/resnet50_pw3d_cache.py index d0d8becd..ddf7ab1e 100644 --- a/configs/cliff/resnet50_pw3d_cache.py +++ b/configs/cliff/resnet50_pw3d_cache.py @@ -21,7 +21,7 @@ # dict(type='TensorboardLoggerHook') ]) -img_resolution = (192, 256) +img_res = (192, 256) # model settings model = dict( @@ -94,7 +94,7 @@ dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_54'), dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25), dict(type='GetBboxInfo'), - dict(type='MeshAffine', img_res=img_resolution), + dict(type='MeshAffine', img_res=img_res), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='ToTensor', keys=data_keys), @@ -111,7 +111,7 @@ dict(type='LoadImageFromFile'), dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0), dict(type='GetBboxInfo'), - dict(type='MeshAffine', img_res=img_resolution), + dict(type='MeshAffine', img_res=img_res), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='ToTensor', keys=data_keys), @@ -125,12 +125,21 @@ ] inference_pipeline = [ - dict(type='MeshAffine', img_res=img_resolution), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', img_res=img_res), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), + dict( + type='ToTensor', + keys=[ + 'bbox_info', 'img_h', 'img_w', 'center', 'scale', 'focal_length' + ]), dict( type='Collect', - keys=['img', 'sample_idx'], + keys=[ + 'img', 'sample_idx', 'bbox_info', 'img_h', 'img_w', 'center', + 'scale', 'focal_length' + ], meta_keys=['image_path', 'center', 'scale', 'rotation']) ] diff --git a/mmhuman3d/apis/inference.py b/mmhuman3d/apis/inference.py index ce19fbbc..da02962d 100644 --- a/mmhuman3d/apis/inference.py +++ b/mmhuman3d/apis/inference.py @@ -183,22 +183,20 @@ def inference_image_based_model( batch_data = collate(batch_data, samples_per_gpu=1) - if next(model.parameters()).is_cuda: - # scatter not work so just move image to cuda device - batch_data['img'] = batch_data['img'].to(device) - # get all img_metas of each bounding box batch_data['img_metas'] = [ img_metas[0] for img_metas in batch_data['img_metas'].data ] + if next(model.parameters()).is_cuda: + # scatter not work so just move image to cuda device + batch_data = dict( + map( + lambda item: item if not isinstance(item[1], torch.Tensor) else + (item[0], item[1].to(device)), batch_data.items())) # forward the model with torch.no_grad(): - results = model( - img=batch_data['img'], - img_metas=batch_data['img_metas'], - sample_idx=batch_data['sample_idx'], - ) + results = model(**batch_data) for idx in range(len(det_results)): mesh_result = det_results[idx].copy()