Adding error messaging for unsupported tensor shapes #2089

Merged
merged 4 commits on Apr 27, 2021
49 changes: 49 additions & 0 deletions monai/csrc/filtering/bilateral/bilateral.cpp
@@ -0,0 +1,49 @@
/*
Copyright 2020 - 2021 MONAI Consortium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <torch/extension.h>
#include <stdexcept>
#include <string>

#include "bilateral.h"
#include "utils/common_utils.h"

torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) {
  torch::Tensor (*filterFunction)(torch::Tensor, float, float);

#ifdef WITH_CUDA

  if (torch::cuda::is_available() && input.is_cuda()) {
    CHECK_CONTIGUOUS_CUDA(input);

    if (input.size(1) > BF_CUDA_MAX_CHANNELS) {
      throw std::runtime_error(
          "Bilateral filtering not implemented for channel count > " + std::to_string(BF_CUDA_MAX_CHANNELS));
    }

    if (input.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {
      throw std::runtime_error(
          "Bilateral filtering not implemented for spatial dimension > " +
          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));
    }

    filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda;
  } else {
    filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
  }
#else
  filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
#endif

  return filterFunction(input, spatial_sigma, color_sigma);
}
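
Reviewer note: the new checks assume the usual MONAI tensor layout of (batch, channel, spatial...), so input.size(1) is the channel count and input.dim() - 2 is the number of spatial dimensions. A minimal sketch with hypothetical shapes (not part of this PR), showing which inputs pass the new guards and which now throw:

// Hypothetical shapes illustrating the new guards; the (batch, channel, spatial...) layout is assumed.
#include <torch/torch.h>

void shape_examples() {
  // 2 batches, 3 channels, 64x64 spatial: size(1) == 3 <= 16 and dim() - 2 == 2 <= 3, so it is accepted.
  torch::Tensor ok = torch::rand({2, 3, 64, 64});

  // 32 channels exceeds BF_CUDA_MAX_CHANNELS (16): BilateralFilter now throws for CUDA inputs.
  torch::Tensor too_many_channels = torch::rand({2, 32, 64, 64});

  // Four spatial dimensions exceed BF_CUDA_MAX_SPATIAL_DIMENSION (3): also throws for CUDA inputs.
  torch::Tensor too_many_dims = torch::rand({2, 3, 8, 8, 8, 8});
}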
21 changes: 4 additions & 17 deletions monai/csrc/filtering/bilateral/bilateral.h
@@ -14,7 +14,9 @@ limitations under the License.
#pragma once

#include <torch/extension.h>
#include "utils/common_utils.h"

#define BF_CUDA_MAX_CHANNELS 16
#define BF_CUDA_MAX_SPATIAL_DIMENSION 3

torch::Tensor BilateralFilterCpu(torch::Tensor input, float spatial_sigma, float color_sigma);
torch::Tensor BilateralFilterPHLCpu(torch::Tensor input, float spatial_sigma, float color_sigma);
@@ -24,19 +26,4 @@ torch::Tensor BilateralFilterCuda(torch::Tensor input, float spatial_sigma, float color_sigma);
torch::Tensor BilateralFilterPHLCuda(torch::Tensor input, float spatial_sigma, float color_sigma);
#endif

torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) {
  torch::Tensor (*filterFunction)(torch::Tensor, float, float);

#ifdef WITH_CUDA
  if (torch::cuda::is_available() && input.is_cuda()) {
    CHECK_CONTIGUOUS_CUDA(input);
    filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda;
  } else {
    filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
  }
#else
  filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
#endif

  return filterFunction(input, spatial_sigma, color_sigma);
}
torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL);
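
Reviewer note: moving the BilateralFilter definition out of the header and into bilateral.cpp lets the .cu files include bilateral.h for the new limit macros without pulling in a function body. The limits themselves are compile-time by nature: SWITCH_AB selects among template instantiations generated for each supported (channel count, spatial dimension) pair, so out-of-range runtime values simply have no kernel to dispatch to. A rough, simplified sketch of that dispatch pattern (the real SWITCH_AB in utils/meta_macros.h may be implemented differently):

// Simplified illustration of compile-time dispatch over runtime shape values.
// The real SWITCH_AB macro in utils/meta_macros.h may differ from this sketch.
#include <cstdio>
#include <stdexcept>

template <int Channels, int Dims>
void FilterImpl() {
  std::printf("kernel specialised for %d channels, %dD\n", Channels, Dims);
}

template <int Channels>
void DispatchDims(int dims) {
  switch (dims) {
    case 1: FilterImpl<Channels, 1>(); break;
    case 2: FilterImpl<Channels, 2>(); break;
    case 3: FilterImpl<Channels, 3>(); break;  // up to BF_CUDA_MAX_SPATIAL_DIMENSION
    default: throw std::runtime_error("no kernel instantiated for this spatial dimension");
  }
}

void Dispatch(int channels, int dims) {
  switch (channels) {
    case 1: DispatchDims<1>(dims); break;
    case 2: DispatchDims<2>(dims); break;
    // ...cases continue up to BF_CUDA_MAX_CHANNELS (16)...
    default: throw std::runtime_error("no kernel instantiated for this channel count");
  }
}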
3 changes: 2 additions & 1 deletion monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
@@ -15,6 +15,7 @@ limitations under the License.
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "bilateral.h"
#include "utils/meta_macros.h"
#include "utils/tensor_description.h"

@@ -253,7 +254,7 @@ torch::Tensor BilateralFilterCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
torch::Tensor outputTensor = torch::zeros_like(inputTensor);

#define CASE(c, d) BilateralFilterCuda<c, d>(inputTensor, outputTensor, spatialSigma, colorSigma);
SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);

return outputTensor;
}
3 changes: 2 additions & 1 deletion monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
@@ -15,6 +15,7 @@ limitations under the License.
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "bilateral.h"
#include "filtering/permutohedral/permutohedral.h"
#include "utils/meta_macros.h"
#include "utils/tensor_description.h"
@@ -135,7 +136,7 @@ torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
inputTensor, outputTensor, spatialSigma, colorSigma); \
}));

SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);

return outputTensor;
}
15 changes: 14 additions & 1 deletion monai/csrc/filtering/permutohedral/permutohedral.cpp
@@ -11,6 +11,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include <stdexcept>
#include <string>

#include "utils/common_utils.h"
#include "utils/meta_macros.h"

@@ -46,6 +49,16 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {
  if (torch::cuda::is_available() && data.is_cuda()) {
    CHECK_CONTIGUOUS_CUDA(data);

    if (channelCount > PHL_CUDA_MAX_CHANNELS) {
      throw std::runtime_error(
          "PHL filtering not implemented for channel count > " + std::to_string(PHL_CUDA_MAX_CHANNELS));
    }

    if (featureCount > PHL_CUDA_MAX_FEATURES) {
      throw std::runtime_error(
          "PHL filtering not implemented for feature count > " + std::to_string(PHL_CUDA_MAX_FEATURES));
    }

#define CASE(dc, fc) \
AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \
for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \
@@ -55,7 +68,7 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {
PermutohedralCuda<scalar_t, dc, fc>(offsetData, offsetFeatures, elementCount, true); \
} \
}));
SWITCH_AB(CASE, 16, 19, channelCount, featureCount);
SWITCH_AB(CASE, PHL_CUDA_MAX_CHANNELS, PHL_CUDA_MAX_FEATURES, channelCount, featureCount);

} else {
#endif
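
Reviewer note: the permutohedral limits mirror the bilateral ones. One plausible reading, stated here as an assumption rather than something this diff confirms, is that the bilateral PHL path builds its feature vector from the spatial coordinates plus the color channels, which would make 16 channels + 3 spatial dimensions the source of the 19-feature ceiling:

// Hypothetical illustration of how PHL_CUDA_MAX_FEATURES could relate to the
// bilateral limits; the derivation is an assumption, not part of this diff.
#include <cstdio>

int main() {
  const int max_channels = 16;     // PHL_CUDA_MAX_CHANNELS / BF_CUDA_MAX_CHANNELS
  const int max_spatial_dims = 3;  // BF_CUDA_MAX_SPATIAL_DIMENSION
  // A bilateral PHL feature vector is typically (spatial coords..., color channels...).
  std::printf("max feature count: %d\n", max_channels + max_spatial_dims);  // prints 19
  return 0;
}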
6 changes: 5 additions & 1 deletion monai/csrc/filtering/permutohedral/permutohedral.h
@@ -11,9 +11,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#pragma once

#include <torch/extension.h>

#pragma once
#define PHL_CUDA_MAX_CHANNELS 16
#define PHL_CUDA_MAX_FEATURES 19

template <typename scalar_t>
void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
#ifdef WITH_CUDA
3 changes: 2 additions & 1 deletion monai/csrc/filtering/permutohedral/permutohedral_cuda.cu
@@ -38,7 +38,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#define BLOCK_SIZE 64
#define BLOCK_SIZE 32

#include <cuda.h>
#include <cuda_runtime.h>
@@ -47,6 +47,7 @@ SOFTWARE.
#include <THC/THCAtomics.cuh>

#include "hash_table.cuh"
#include "permutohedral.h"
#include "utils/meta_macros.h"

template <typename scalar_t>