diff --git a/core/compiler.cpp b/core/compiler.cpp
index bf128b714a..118ca7aa1c 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -11,7 +11,6 @@
 
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,179 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-void AddSegmentedBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& g,
-    partitioning::SegmentedBlock& seg,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  // old_to_new_g contains: original global graph value => new global graph value,
-  // mini_to_new_g: mini graph value -> new graph value
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
-  size_t input_idx = 0;
-  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
-    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-      auto self = g->insertInput(0, "self_1");
-      self->setType(seg.inputs()[0]->type());
-    }
-    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
-  }
-
-  for (auto& raw_input : seg.raw_inputs()) {
-    if (old_to_new_g.count(raw_input)) {
-      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
-    }
-  }
-
-  for (const auto n : seg.nodes()) {
-    util::cloneNode(n, g, mini_to_new_g);
-  }
-
-  // original graph value => new global graph value
-  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
-    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
-  }
-  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
-  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
-    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
-      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
-    }
-  }
-
-  return;
-}
-
-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
-    GraphAndMapping;
-
-void AddIfBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& new_g,
-    torch::jit::Node* if_node,
-    const std::vector<GraphAndMapping>& graph_and_mappings,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  torch::jit::IfView if_view(if_node);
-
-  // create a new if node in new_g and add corresponding inputs
-  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-  // iterate over all blocks and add them to new created prim::If
-  for (auto graph_and_mapping : graph_and_mappings) {
-    auto new_if_block = new_if->addBlock();
-    auto cur_block_graph = graph_and_mapping.first;
-    auto cur_block_mapping = graph_and_mapping.second;
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
-      // it's mini graph's input
-      if (old_to_new_g.count(i.first)) {
-        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-      }
-    }
-
-    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-    new_if_block->cloneFrom(cur_block_graph->block(), env);
-    if (cur_block_graph->inputs().size() &&
-        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        auto self = new_g->insertInput(0, "self_1");
-        self->setType(cur_block_graph->inputs()[0]->type());
-      }
-      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-    }
-    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-      new_if_block->eraseInput(i);
-    }
-  }
-  for (auto ov : if_view.outputs()) {
-    auto no = new_if->addOutput();
-    old_to_new_g[ov] = no;
-    no->copyMetadata(ov);
-  }
-  return;
-}
-
-GraphAndMapping ConstructFallbackGraph(
+partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
-  auto convert_cfg = cfg.convert_info;
-  auto partition_info = cfg.partition_info;
-
-  auto new_g = std::make_shared<torch::jit::Graph>();
-
-  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
-
-  // the mapping from lowering graph => fallback global graph
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-  for (auto input : block->inputs()) {
-    util::getOrAddInputForValue(input, new_g, old_to_new_g);
-  }
-
-  for (auto& seg_block : segmented_blocks) {
-    LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
-    std::ostringstream trt_engine_id;
-    trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-
-    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      auto shapes = seg_block.in_shapes();
-      auto types = seg_block.in_types();
-      std::vector<ir::Input> inputs;
-      for (size_t i = 0; i < shapes.size(); i++) {
-        auto in = ir::Input(shapes[i]);
-        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-        inputs.push_back(in);
-      }
-      // update the input ranges for each segments
-      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
-
-      // TODO mapping Inputs Ivalue to flatten one here
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
-      auto temp_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = convert_cfg.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-      seg_block.update_graph(temp_g);
-      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-    } else {
-      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto if_node = seg_block.raw_nodes()[0];
-
-        // convert the 2 blocks in prim::if and get the converted graph with mappings
-        std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : if_node->blocks()) {
-          graph_and_mappings.push_back(
-              ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
+    ir::CollectionTypeMap first_use_types) {
+  auto convert_info = cfg.convert_info;
+  auto partitioning_info = cfg.partitioning_info;
+
+  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
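+  // Generate random example inputs from the collection input specs; partitioning uses them for per-segment shape analysis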
+  auto collection_input_ivalues_map =
+      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
+
+  partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
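+  // partition() populates partitioning_ctx.partitioned_blocks: one PartitionedGraph (list of segments) per original block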
+
+  for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
+    partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+
+    for (auto& seg_block : segmented_blocks) {
+      LOG_INFO("Block segment:" << seg_block);
+      std::ostringstream trt_engine_id;
+      trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        auto shapes = seg_block.in_shapes();
+        auto types = seg_block.in_types();
+        std::vector<ir::Input> inputs;
+        for (size_t i = 0; i < shapes.size(); i++) {
+          auto in = ir::Input(shapes[i]);
+          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+          inputs.push_back(in);
         }
-        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+        // update the input ranges for each segment
+        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
-      } else {
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      }
-    }
-  }
+        // TODO: map the input IValues to the flattened ones here
+        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        auto device_spec = convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
 
-  if (block->outputs().size() > 1) {
-    std::vector<torch::jit::Value*> fallback_graph_vector;
-    for (auto& output : block->outputs()) {
-      if (old_to_new_g.count(output)) {
-        fallback_graph_vector.push_back(old_to_new_g[output]);
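+        // replace the segment's graph with the engine-wrapping graph so the stitched hybrid graph calls the TRT engine for this segment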
+        seg_block.update_graph(temp_g);
       }
     }
-    torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
-    auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
-    new_g->block()->appendNode(return_tuple_node);
-    // Set the output as the produced tuple
-    new_g->registerOutput(return_tuple_node->outputs()[0]);
-  } else {
-    if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
-      new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
-    }
   }
-  return {new_g, old_to_new_g};
+
+  return partitioning::stitch(&partitioning_ctx, block);
 }
 
 void MapInputsAndDetermineDTypes(
@@ -310,6 +184,8 @@ void MapInputsAndDetermineDTypes(
     ir::CollectionTypeMap& first_use_type_map) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+  cfg.partitioning_info.collection_input_spec_map =
+      ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
@@ -339,7 +215,7 @@ void MapInputsAndDetermineDTypes(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
         spec[i].dtype = nvinfer1::DataType::kFLOAT;
-      } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+      } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
           std::stringstream ss;
@@ -424,22 +300,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
       auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
       auto outputIsCollection = conversion::OutputIsCollection(g->block());
-      if (cfg.partition_info.enabled &&
+      if (cfg.partitioning_info.enabled &&
           (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-           cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
+           cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
           !outputIsCollection) {
         LOG_INFO("Skipping partitioning since model is fully supported");
       }
 
-      if (cfg.partition_info.enabled &&
+      if (cfg.partitioning_info.enabled &&
           (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-             cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+             cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
            outputIsCollection)) {
-        std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-        auto collection_input_ivalues_map =
-            partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
-        auto graph_and_mapping = ConstructFallbackGraph(
-            new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
+        auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
         new_g = graph_and_mapping.first;
         // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
         for (size_t i = 0; i < new_g->inputs().size(); ++i) {
diff --git a/core/compiler.h b/core/compiler.h
index c8dc85020b..1b7b3defe8 100644
--- a/core/compiler.h
+++ b/core/compiler.h
@@ -19,7 +19,7 @@ struct CompileSpec {
   ir::GraphInputs graph_inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
-  partitioning::PartitionInfo partition_info;
+  partitioning::PartitioningInfo partitioning_info;
 };
 
 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 8bbae296c3..5442440422 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -41,6 +41,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
     passes::MarkNodesForFallback(g, true);
   }
   passes::UnpackHardSwish(g);
+  passes::UnpackHardSigmoid(g);
   passes::EliminateExceptionOrPassPattern(g);
   passes::ReduceToOperation(g);
   passes::ReduceGelu(g);
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index 1f6a0cde8f..d5f3616f8d 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -30,6 +30,7 @@ cc_library(
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
+        "unpack_hardsigmoid.cpp",
         "unpack_hardswish.cpp",
         "unpack_log_softmax.cpp",
         "unpack_std.cpp",
diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt
index a8cda65e71..48e644a70d 100644
--- a/core/lowering/passes/CMakeLists.txt
+++ b/core/lowering/passes/CMakeLists.txt
@@ -17,6 +17,7 @@ target_sources(${lib_name}
             "${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index 73bd9f61d7..3b946593e2 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
 } // namespace lowering
diff --git a/core/lowering/passes/unpack_hardsigmoid.cpp b/core/lowering/passes/unpack_hardsigmoid.cpp
new file mode 100644
index 0000000000..876196215a
--- /dev/null
+++ b/core/lowering/passes/unpack_hardsigmoid.cpp
@@ -0,0 +1,43 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph) {
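+  // Rewrite aten::hardsigmoid / aten::hardsigmoid_ into primitive ops: hardsigmoid(x) = clamp(x / 6 + 0.5, 0, 1)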
+  std::string hardsigmoid_pattern = R"IR(
+        graph(%input):
+            %result = aten::hardsigmoid(%input)
+            return (%result))IR";
+
+  std::string hardsigmoid_pattern_inplace = R"IR(
+        graph(%input):
+            %result = aten::hardsigmoid_(%input)
+            return (%result))IR";
+
+  std::string new_pattern = R"IR(
+        graph(%x.1):
+            %22 : float = prim::Constant[value=0.5]()
+            %3 : int = prim::Constant[value=6]()
+            %5 : int = prim::Constant[value=1]()
+            %10 : int = prim::Constant[value=0]()
+            %4 : Tensor = aten::div(%x.1, %3)
+            %9 : Tensor = aten::add(%4, %22, %5)
+            %21 : Tensor = aten::clamp(%9, %10, %5)
+            return (%21))IR";
+
+  torch::jit::SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(hardsigmoid_pattern, new_pattern);
+  rewriter.RegisterRewritePattern(hardsigmoid_pattern_inplace, new_pattern);
+  rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("Post unpack hardsigmoid: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/partitioning/BUILD b/core/partitioning/BUILD
index fbc9eeac7a..4204939684 100644
--- a/core/partitioning/BUILD
+++ b/core/partitioning/BUILD
@@ -13,22 +13,21 @@ config_setting(
 cc_library(
     name = "partitioning",
     srcs = [
-        "PartitionInfo.cpp",
-        "SegmentedBlock.cpp",
         "partitioning.cpp",
         "shape_analysis.cpp",
+        "stitching.cpp",
     ],
     hdrs = [
-        "PartitionInfo.h",
-        "SegmentedBlock.h",
         "partitioning.h",
-        "shape_analysis.h",
     ],
     deps = [
         "//core/util:prelude",
         "//core/ir",
         "//core/conversion",
         "//core/lowering",
+        "//core/partitioning/partitioningctx",
+        "//core/partitioning/partitioninginfo",
+        "//core/partitioning/segmentedblock",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
@@ -39,10 +38,7 @@ cc_library(
 pkg_tar(
     name = "include",
     srcs = [
-        "PartitionInfo.h",
-        "SegmentedBlock.h",
         "partitioning.h",
-        "shape_analysis.h",
     ],
     package_dir = "core/partitioning/",
 )
diff --git a/core/partitioning/CMakeLists.txt b/core/partitioning/CMakeLists.txt
index 15784f638e..7f83b3d891 100644
--- a/core/partitioning/CMakeLists.txt
+++ b/core/partitioning/CMakeLists.txt
@@ -1,33 +1,39 @@
 set(lib_name "core_partitioning")
 add_library(${lib_name} OBJECT)
 
-target_sources(${lib_name}
-    PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.cpp"
-            "${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.cpp"
-            "${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp"
-            "${CMAKE_CURRENT_SOURCE_DIR}/PartitionInfo.cpp"
-            $<TARGET_OBJECTS:core_conversion>
-    PUBLIC  $<TARGET_OBJECTS:core_ir>
-            $<TARGET_OBJECTS:core_util>
+set(CXX_SRCS
+    "${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.cpp"
 )
 
 set(HEADER_FILES
-    "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.h"
-    "${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.h"
-    "${CMAKE_CURRENT_SOURCE_DIR}/PartitionInfo.h"
     "${CMAKE_CURRENT_SOURCE_DIR}/partitioning.h"
 )
 
-target_include_directories(${lib_name} PUBLIC "$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}>")
+target_sources(${lib_name}
+    PRIVATE
+        ${CXX_SRCS}
+    PUBLIC
+        $<TARGET_OBJECTS:core_conversion>
+        $<TARGET_OBJECTS:core_ir>
+        $<TARGET_OBJECTS:core_util>
+)
+
 target_link_libraries(${lib_name}
     PUBLIC
-        torch
         TensorRT::nvinfer
+        torch
         core_ir
         core_util
-    PRIVATE
         core_conversion
 )
 
-# Install headers
-install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/")
+target_include_directories(${lib_name}
+    PUBLIC "$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}>"
+)
+
+add_subdirectory(partitioningctx)
+add_subdirectory(partitioninginfo)
+add_subdirectory(segmentedblock)
+
+install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning")
\ No newline at end of file
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index c329b33ef6..eb8c86de50 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -1,9 +1,7 @@
-#include "partitioning.h"
-
+#include "core/partitioning/partitioning.h"
 #include <queue>
 #include "core/conversion/conversion.h"
 #include "core/conversion/evaluators/evaluators.h"
-#include "core/partitioning/shape_analysis.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 
@@ -30,6 +28,132 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
   return false;
 }
 
+// Check if the inputs and outputs of the graph are Tensors. If not, fall back the connected nodes
+void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
+  // fallback nodes that produce entire graph's nonTensor output
+  for (auto i : block->outputs()) {
+    if (!isTensor(i)) {
+      ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR);
+    }
+  }
+
+  // fallback nodes that consume entire graph's nonTensor input
+  for (auto i : block->inputs()) {
+    if (!isTensor(i)) {
+      for (auto use : i->uses()) {
+        ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR);
+      }
+    }
+  }
+}
+
+// Find and set all explicit fallback nodes (nodes that are unsupported or forced to fall back).
+// We use a map to record the reason why each node falls back to Torch.
+// Any node that is not an explicit fallback is set to run in TensorRT for now.
+void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
+  auto nodes = block->nodes();
+  const auto to_compile_sym = c10::Symbol::attr("to_compile");
+
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant) {
+      continue;
+    }
+
+    if (!conversion::OpSupported(n)) {
+      // If the op is not supported by the conversion phase it should run in PyTorch
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
+    } else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
+      // If the user specifies the op to run in Torch it should run in PyTorch
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kOPERATOR_FALLBACK);
+    } else if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
+      // If the user specifies the module containing this op to run in torch it should run in PyTorch
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kMODULE_FALLBACK);
+    } else {
+      // Set the remaining nodes to run in TensorRT
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
+    }
+  }
+  return;
+}
+
+// For a given set of fallback nodes, check their inputs/outputs. If any of those inputs/outputs are NonTensor,
+// then the nodes that produce/consume those values should also fall back
+void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) {
+  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+  std::queue<torch::jit::Node*> q;
+  for (auto& node : initial_fallback_nodes) {
+    q.push(node);
+  }
+
+  while (!q.empty()) {
+    auto cur_node = q.front();
+    q.pop();
+    // every node that produces this fallback node's NonTensor input should fall back too
+    for (auto input : cur_node->inputs()) {
+      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
+          ctx->shouldNodeRunInTensorRT(input->node())) {
+        ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR);
+        q.push(input->node());
+      }
+    }
+    // every node that consumes this fallback node's NonTensor output should fall back too
+    for (auto output : cur_node->outputs()) {
+      if (!isTensor(output)) {
+        for (auto use : output->uses()) {
+          auto node = use.user;
+          if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
+            ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
+            q.push(node);
+          }
+        }
+      }
+    }
+  }
+}
+
+// Sub-function that traverses the entire block and checks whether each TensorRT node sequence satisfies min_block_size
+std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) {
+  auto nodes = block->nodes();
+  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::vector<torch::jit::Node*> min_block_fallback_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant) {
+      continue;
+    }
+
+    // check whether the current node falls back to Torch or not
+    if (!ctx->shouldNodeRunInTorch(n)) {
+      cur_trt_nodes.push_back(n);
+    } else {
+      if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
+        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      }
+      cur_trt_nodes.clear();
+    }
+  }
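+  // handle the trailing TensorRT node sequence at the end of the block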
+  if (cur_trt_nodes.size() < ctx->settings.min_block_size) {
+    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+  }
+  return min_block_fallback_nodes;
+}
+
+// Set the nodes that fall back because of min_block_size
+void setMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
+  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+  auto min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block);
+
+  // keep falling back until all segments meet the min_block_size requirement
+  while (!min_block_fallback_nodes.empty()) {
+    for (const auto i : min_block_fallback_nodes) {
+      ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK);
+    }
+    // find additional fallback nodes caused by dependencies on the min_block_size fallback nodes
+    setNonTensorConnectedNodes(ctx, min_block_fallback_nodes);
+    // keep traversing the graph until no more nodes fall back because of min_block_size
+    min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block);
+  }
+}
+
 bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
   const torch::jit::FunctionSchema* schema = node->maybeSchema();
   if (!schema) {
@@ -96,91 +220,38 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-// check if the input and output of the graph is Tensor after collection is enabled. If it is, then fallback related
-// nodes
-void fallback_graph_nontensor_in_out(
-    torch::jit::Block* block,
-    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
-  // fallback nodes that produce entire graph's nonTensor output
-  for (auto i : block->outputs()) {
-    if (!isTensor(i)) {
-      global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR});
-    }
-  }
-
-  // fallback nodes that consume entire graph's nonTensor input
-  for (auto i : block->inputs()) {
-    if (!isTensor(i)) {
-      for (auto use : i->uses()) {
-        global_fallback_nodes.insert({use.user, FallbackNodeType::kNON_TENSOR});
-      }
-    }
-  }
-}
-
-void find_all_fallback_nodes(
-    std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
-    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
-  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
-  // global_fallback_nodes are the fallback nodes that we maintain globally
-  std::queue<torch::jit::Node*> q;
-  for (auto& node : initial_fallback_nodes) {
-    q.push(node.first);
-  }
-
-  std::unordered_set<torch::jit::Node*> visited_nodes;
-  while (!q.empty()) {
-    auto cur_node = q.front();
-    q.pop();
-    // for every node that produces this fallback node's NonTensor input, they should fallback too
-    for (auto input : cur_node->inputs()) {
-      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
-        q.push(input->node());
-      }
-    }
-    // for every node that consumes this fallback node's NonTensor output, they should fallback too
-    for (auto output : cur_node->outputs()) {
-      if (!isTensor(output)) {
-        for (auto use : output->uses()) {
-          auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant &&
-              global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
-            q.push(node);
-          }
-        }
-      }
-    }
-  }
-}
-
-void resolveTRTNonTensorInputs(PartitionedGraph& segmented_blocks) {
+void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   // if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine
   // because we have already found the interface between Torch and TRT in segmentation phase
   // what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs
-  for (size_t i = 0; i < segmented_blocks.size(); ++i) {
-    if (segmented_blocks[i].target() == SegmentedBlock::kTensorRT) {
+  PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block];
+  for (size_t i = 0; i < cur_partitioned_block.size(); ++i) {
+    if (cur_partitioned_block[i].target() == SegmentedBlock::kTensorRT) {
       std::vector<torch::jit::Value*> inputs_to_resolve;
-      for (auto input : segmented_blocks[i].raw_inputs()) {
+      for (auto input : cur_partitioned_block[i].raw_inputs()) {
         if (!isTensor(input)) {
           inputs_to_resolve.push_back(input);
         }
       }
       if (!inputs_to_resolve.empty()) {
-        std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve, segmented_blocks[i]);
+        std::vector<torch::jit::Node*> dependency_nodes =
+            getDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]);
         dependency_nodes.insert(
-            dependency_nodes.end(), segmented_blocks[i].raw_nodes().begin(), segmented_blocks[i].raw_nodes().end());
-        segmented_blocks[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
+            dependency_nodes.end(),
+            cur_partitioned_block[i].raw_nodes().begin(),
+            cur_partitioned_block[i].raw_nodes().end());
+        cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
       }
     }
   }
 }
 
-void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
+void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
+  PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block];
   auto cmp = [](torch::jit::Value* a, torch::jit::Value* b) { return a->unique() < b->unique(); };
   std::set<torch::jit::Value*, decltype(cmp)> input_values(cmp);
-  for (auto& seg_block : segmented_blocks) {
+  for (auto& seg_block : cur_partitioned_block) {
     for (auto& input : seg_block.raw_inputs()) {
       input_values.insert(input);
     }
@@ -193,7 +264,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
   // should be careful here because some in-place operations don't return any values, there is no output for this kind
   // of segment identify the output for each mini-graph by checking if any value in this graph is used later we
   // shouldn't register nonTensor output for TensorRT segments
-  for (auto& seg_block : segmented_blocks) {
+  for (auto& seg_block : cur_partitioned_block) {
     for (auto& mini_graph_input : input_values) {
       if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
               seg_block.raw_inputs().end() &&
@@ -222,20 +293,21 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
     }
   }
 
-  std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) {
+  std::for_each(cur_partitioned_block.begin(), cur_partitioned_block.end(), [](SegmentedBlock& seg_block) {
     torch::jit::EliminateDeadCode(seg_block.g());
   });
   // erase segments which still have no output
-  segmented_blocks.erase(
+  cur_partitioned_block.erase(
       std::remove_if(
-          segmented_blocks.begin(),
-          segmented_blocks.end(),
+          cur_partitioned_block.begin(),
+          cur_partitioned_block.end(),
           [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
-      segmented_blocks.end());
+      cur_partitioned_block.end());
 
   return;
 }
 
+// TODO: Check whether this makes sense; it might be a root cause of some over-aggressive fallback issues
 bool checkLoopEvaluatable(torch::jit::Node* n) {
   bool compile_to_trt = true;
   for (auto bn : n->blocks()[0]->nodes()) {
@@ -250,29 +322,7 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
   return compile_to_trt;
 }
 
-bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
-  if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
-      LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
-      LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
-      LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
-      LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n));
-    } else {
-      LOG_GRAPH(
-          "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
-          << util::node_info(n));
-    }
-    return false;
-  }
-
-  LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n));
-  return true;
-}
-
-void finalize_block(
+void finalizeNewBlock(
     PartitionedGraph& g,
     SegmentedBlock::SegmentedBlockTarget kind,
     std::vector<torch::jit::Node*>& nodes) {
@@ -282,110 +332,38 @@ void finalize_block(
   LOG_DEBUG(g.back());
 }
 
-// use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback)
-// we use a map to indicate the reason why it's fallback to torch
-void get_fallback_nodes(
-    torch::jit::Block* block,
-    const std::unordered_set<std::string>& forced_fallback_ops,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
-  auto nodes = block->nodes();
-  for (const auto n : nodes) {
-    if (n->kind() == torch::jit::prim::Constant) {
-      continue;
-    }
-
-    // If the op is not supported by the conversion phase it should run in PyTorch
-    if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
-    }
-
-    // If the user specifies the op to run in Torch it should run in PyTorch
-    if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
-    }
-
-    // If the user specifies the module containing this op to run in torch it should run in PyTorch
-    const auto to_compile_sym = c10::Symbol::attr("to_compile");
-    if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
-    }
-  }
-  return;
+void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) {
+  // First, find all the explicit fallback nodes that should run in Torch:
+  // 1. nodes that are unsupported
+  // 2. nodes that the user specifies to run in torch
+  // 3. nodes whose containing module the user specifies to run in torch
+  // At the same time, set all the remaining nodes to NodeExecutorDecision::kCONVERT
+  setExplicitFallbackNodes(ctx, block);
+
+  // Second, check if the block has any nonTensor inputs/outputs; if it does, fall back the nodes that
+  // consume/produce those nonTensor values
+  setInputsOutputsConnectedNodes(ctx, block);
+
+  // Third, for each fallback node, if it consumes any NonTensor inputs, then the nodes that produce those
+  // inputs should also fall back. Similarly, if it produces any NonTensor outputs, then the nodes
+  // that consume those outputs should also fall back
+  auto cur_fallback_nodes = ctx->getNodesRunInTorch();
+  setNonTensorConnectedNodes(ctx, cur_fallback_nodes);
+
+  // Finally, check that all current TensorRT blocks satisfy the min_block_size requirement.
+  // This requires traversing the whole graph multiple times
+  setMinBlockFallbackNodes(ctx, block);
 }
 
-std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
-    torch::jit::Block* block,
-    const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
-    size_t min_block_size) {
-  auto nodes = block->nodes();
-  std::vector<torch::jit::Node*> cur_trt_nodes;
-  std::vector<torch::jit::Node*> min_block_fallback_nodes;
-  for (const auto n : nodes) {
-    if (n->kind() == torch::jit::prim::Constant)
-      continue;
-
-    // check if current node fallback or not
-    if (!global_fallback_nodes.count(n)) {
-      // if this node is not in fallback nodes, then it's in trt segments
-      cur_trt_nodes.push_back(n);
-    } else {
-      if (cur_trt_nodes.size() < min_block_size) {
-        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
-      }
-      cur_trt_nodes.clear();
-    }
-  }
-  if (cur_trt_nodes.size() < min_block_size) {
-    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
-  }
-  return min_block_fallback_nodes;
-}
-
-void find_min_block_size_fallback_nodes(
-    torch::jit::Block* block,
-    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
-    size_t min_block_size) {
-  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
-  auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
-  std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
-
-  // keep fallback until all segments meet the min_block_size requirement
-  while (!min_block_fallback_nodes.empty()) {
-    for (const auto i : min_block_fallback_nodes) {
-      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
-    }
-    global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
-    // find the fallback nodes because of dependency with min_block_size caused fallback nodes
-    find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
-    // keep traverse the graph until there is no node fallback because of min_block_size
-    min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
-  }
-}
-
-PartitionedGraph segment_graph(
-    torch::jit::Block* block,
-    const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
-  auto min_block_size = partition_info.min_block_size;
-  std::unordered_set<std::string> forced_fallback_ops(
-      partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
-
-  // get the initial fallback nodes (nodes that are unsupported or forced fallback)
-  get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);
-
-  // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
-  // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
-  // that produces this input should also fallback
-  // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
-  find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
-
-  // find all fallback nodes because of the min_block_size requirement
-  find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
+void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
+  // Find all the fallback nodes and build the execution decision LUT for all nodes
+  setNodeExecutorLUT(ctx, block);
 
   auto nodes = block->nodes();
-  PartitionedGraph segmented_blocks;
 
   // segment the nodes
+  PartitionedGraph segmented_blocks;
+
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
   for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
@@ -393,22 +371,24 @@ PartitionedGraph segment_graph(
       continue;
     }
     // the outputs of trt subgraph shouldn't be collections
-    if (check_node_fallback(n, global_fallback_nodes)) {
+    if (ctx->shouldNodeRunInTensorRT(n)) {
       in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
       // block then segment and reset the active PyTorch block
-      if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) {
-        finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
       }
     } else {
       // If there is an active TRT block that is valid segment and reset the active TRT block
       // otherwise add it to the active PyTorch block and reset
-      if (in_prog_trt_blk_nodes.size() >= min_block_size) {
-        finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+      if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
+        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
       } else {
         LOG_DEBUG(
-            "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
+            "In progress TRT block does not meet minimum block size requirements ("
+            << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size
+            << "), therefore folding into in progress PyTorch block");
         in_prog_pyt_blk_nodes.insert(
             in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
       }
@@ -419,20 +399,20 @@ PartitionedGraph segment_graph(
         LOG_DEBUG(
             "Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
         if (!in_prog_pyt_blk_nodes.empty()) {
-          finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
         }
         auto cond_node = std::vector<torch::jit::Node*>{n};
-        finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node);
+        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
         continue;
       } else if (n->kind() == torch::jit::prim::Loop) {
         if (!in_prog_pyt_blk_nodes.empty()) {
-          finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
         }
         if (checkLoopEvaluatable(n)) {
           in_prog_trt_blk_nodes.push_back(n);
         } else {
           auto loop_node = std::vector<torch::jit::Node*>{n};
-          finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
+          finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
         }
         continue;
       }
@@ -442,60 +422,39 @@ PartitionedGraph segment_graph(
 
   // if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num <
   // min_block_size
-  if (in_prog_trt_blk_nodes.size() >= min_block_size) {
-    finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+  if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) {
+    finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
   }
 
   if (!in_prog_pyt_blk_nodes.empty() || !in_prog_trt_blk_nodes.empty()) {
     in_prog_pyt_blk_nodes.insert(
         in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
-    finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
+    finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
-  return segmented_blocks;
-}
-
-PartitionedGraph Partition(
-    torch::jit::Block* block,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
-    const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
-  LOG_DEBUG(partition_info);
-  // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor
-  // output
-  fallback_graph_nontensor_in_out(block, global_fallback_nodes);
-
-  // segment lowering global graph into blocks
-  LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);
 
-  // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
-
-  // resolve nonTensor inputs/outputs
-  resolveTRTNonTensorInputs(segmented_blocks);
-
-  // register input/output torch::jit::Value for segmented graphs
-  LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
-  registerSegmentsOutputs(segmented_blocks, block);
+  ctx->partitioned_blocks.insert({block, segmented_blocks});
+  return;
+}
 
-  // run shape analysis on each segmented block
-  runShapeAnalysis(segmented_blocks, example_tensor_map, partition_info);
+void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
+  LOG_DEBUG(ctx->settings);
 
-  for (uint64_t i = 0; i < segmented_blocks.size(); i++) {
-    segmented_blocks[i].update_id(i);
-  }
+  // Go through all the blocks to do the partitioning
+  for (torch::jit::Block* block : ctx->original_blocks) {
+    // segment lowering global graph into blocks
+    segmentGraph(ctx, block);
 
-  LOG_INFO(segmented_blocks);
+    // It's possible that some TensorRT blocks have nonTensor inputs/outputs because they are interleaved with Torch blocks
+    // resolve nonTensor inputs/outputs
+    resolveTRTNonTensorInputs(ctx, block);
 
-  return segmented_blocks;
-}
+    // register input/output torch::jit::Value for segmented graphs
+    LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
+    registerSegmentsOutputs(ctx, block);
 
-std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) {
-  os << "Partitioned Graph: [";
-  for (auto b : g) {
-    os << b;
+    // run shape analysis on each segmented block
+    runShapeAnalysis(ctx, block, example_tensor_map);
   }
-  os << "]";
-  return os;
 }
 
 } // namespace partitioning
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index f1eb38df8a..3038f6c52f 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -3,45 +3,30 @@
 #include <iostream>
 #include <vector>
 
+#include "torch/csrc/jit/ir/ir.h"
+
 #include "core/ir/ir.h"
-#include "core/partitioning/PartitionInfo.h"
-#include "core/partitioning/SegmentedBlock.h"
-#include "core/partitioning/shape_analysis.h"
+#include "core/partitioning/partitioningctx/PartitioningCtx.h"
 #include "core/util/prelude.h"
-#include "torch/csrc/jit/ir/ir.h"
 
 namespace torch_tensorrt {
 namespace core {
 namespace partitioning {
 
-typedef std::vector<SegmentedBlock> PartitionedGraph;
-
-enum FallbackNodeType {
-  /// Node is not supported by TensorRT
-  kUNSUPPORTED,
-  /// Node is explicitly forced to fallback to Pytorch due to operator fallback
-  kOPERATOR_FALLBACK,
-  /// Node is explicitly forced to fallback to Pytorch due to module fallback
-  kMODULE_FALLBACK,
-  /// This node is in a TRT segment which does not satisfy min_block_size
-  /// and hence is forced to fallback.
-  kMIN_BLOCK_FALLBACK,
-  /// This node produces/consumes non-tensor inputs
-  kNON_TENSOR,
-};
-
-PartitionedGraph segment_graph(
-    torch::jit::Block* block,
-    const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes);
-
-PartitionedGraph Partition(
-    torch::jit::Block* block,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
-    const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes);
-
-std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g);
+typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ExampleIValues;
+
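+// A stitched graph paired with the mapping from original graph Values to Values in the new graph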
+typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
+    GraphAndMapping;
+
+ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
+
+void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
+
+void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
+
+GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
+
+void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map);
 
 } // namespace partitioning
 } // namespace core
diff --git a/core/partitioning/partitioningctx/BUILD b/core/partitioning/partitioningctx/BUILD
new file mode 100644
index 0000000000..6895f8d451
--- /dev/null
+++ b/core/partitioning/partitioningctx/BUILD
@@ -0,0 +1,40 @@
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    },
+)
+
+cc_library(
+    name = "partitioningctx",
+    srcs = [
+        "PartitioningCtx.cpp",
+    ],
+    hdrs = [
+        "PartitioningCtx.h",
+    ],
+    deps = [
+        "//core/util:prelude",
+        "//core/ir",
+        "//core/conversion",
+        "//core/partitioning/segmentedblock",
+        "//core/partitioning/partitioninginfo",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    alwayslink = True,
+)
+
+pkg_tar(
+    name = "include",
+    srcs = [
+        "PartitioningCtx.h",
+    ],
+    package_dir = "core/partitioning/partitioningctx",
+)
diff --git a/core/partitioning/partitioningctx/CMakeLists.txt b/core/partitioning/partitioningctx/CMakeLists.txt
new file mode 100644
index 0000000000..090167f829
--- /dev/null
+++ b/core/partitioning/partitioningctx/CMakeLists.txt
@@ -0,0 +1,12 @@
+set(sub_lib_name "partitioningctx")
+
+target_sources(${lib_name}
+    PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningCtx.cpp"
+)
+
+set(HEADER_FILES
+    "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningCtx.h"
+)
+
+# Install headers
+install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/${sub_lib_name}")
diff --git a/core/partitioning/partitioningctx/PartitioningCtx.cpp b/core/partitioning/partitioningctx/PartitioningCtx.cpp
new file mode 100644
index 0000000000..7bcaaea120
--- /dev/null
+++ b/core/partitioning/partitioningctx/PartitioningCtx.cpp
@@ -0,0 +1,113 @@
+#include <queue>
+
+#include "core/partitioning/partitioningctx/PartitioningCtx.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace partitioning {
+
+PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
+    : settings(info),
+      forced_fallback_ops(info.forced_fallback_operators.begin(), info.forced_fallback_operators.end()) {
+  LOG_DEBUG(settings);
+  _load_nodes_into_decision_map(b);
+}
+
+void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
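+  // Record this block for partitioning unless it is the body of a prim::Loop; loop bodies are not partitioned on their own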
+  if (!b->owningNode() || b->owningNode()->kind() != torch::jit::prim::Loop) {
+    original_blocks.push_back(b);
+  }
+  for (const auto n : b->nodes()) {
+    if (n->kind() == torch::jit::prim::Constant) {
+      continue;
+    }
+    node_executor_decision_map[n] = NodeExecutorDecision::kUNKNOWN;
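+    // Recurse into nested blocks (e.g. prim::If branches) so their nodes also get decision entries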
+    for (const auto sub_b : n->blocks()) {
+      _load_nodes_into_decision_map(sub_b);
+    }
+  }
+}
+
+void PartitioningCtx::setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorDecision decision) {
+  auto iter = node_executor_decision_map.find(n);
+  auto prev_decision = NodeExecutorDecision::kUNKNOWN;
+  if (iter != node_executor_decision_map.end()) {
+    prev_decision = iter->second;
+  }
+  LOG_DEBUG("Setting node " << util::node_info(n) << " " << decision << " (previously was " << prev_decision << ")");
+
+  node_executor_decision_map[n] = decision;
+  return;
+}
+
+bool PartitioningCtx::shouldNodeRunInTorch(torch::jit::Node* n) {
+  auto iter = node_executor_decision_map.find(n);
+  auto decision = NodeExecutorDecision::kUNKNOWN;
+
+  if (iter != node_executor_decision_map.end()) {
+    decision = iter->second;
+  }
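+  // Nodes with no recorded decision (kUNKNOWN) are not treated as Torch nodes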
+  if (decision == NodeExecutorDecision::kCONVERT || decision == NodeExecutorDecision::kUNKNOWN) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+bool PartitioningCtx::shouldNodeRunInTensorRT(torch::jit::Node* n) {
+  auto iter = node_executor_decision_map.find(n);
+  auto decision = NodeExecutorDecision::kUNKNOWN;
+  if (iter != node_executor_decision_map.end()) {
+    decision = iter->second;
+  }
+
+  if (decision == NodeExecutorDecision::kCONVERT) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+std::vector<torch::jit::Node*> PartitioningCtx::getNodesRunInTorch() {
+  std::vector<torch::jit::Node*> nodes_run_in_torch;
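+  // Anything not explicitly marked kCONVERT (including kUNKNOWN) counts as running in Torch here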
+  for (auto i : node_executor_decision_map) {
+    if (i.second != NodeExecutorDecision::kCONVERT) {
+      nodes_run_in_torch.push_back(i.first);
+    }
+  }
+  return nodes_run_in_torch;
+}
+
+std::ostream& operator<<(std::ostream& os, const NodeExecutorDecision& format) {
+  switch (format) {
+    case NodeExecutorDecision::kUNSUPPORTED:
+      return os << "to run torch due to lack of converter support";
+    case NodeExecutorDecision::kOPERATOR_FALLBACK:
+      return os << "to run torch due to user explicitly requesting that this op kind runs in torch";
+    case NodeExecutorDecision::kMODULE_FALLBACK:
+      return os << "to run torch due to being a member of a module the user has requested to run in torch";
+    case NodeExecutorDecision::kMIN_BLOCK_FALLBACK:
+      return os << "to run torch due to owning block not being large enough to meet the user specified min_block_size";
+    case NodeExecutorDecision::kNON_TENSOR:
+      return os << "to run torch due to producing or consuming non-tensor values";
+    case NodeExecutorDecision::kCONVERT:
+      return os << "to run in tensorrt";
+    case NodeExecutorDecision::kUNKNOWN:
+    default:
+      return os << "unknown node executor decision";
+  }
+}
+
+std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) {
+  os << "Partitioned Graph: [";
+  for (auto b : g) {
+    os << b;
+  }
+  os << "]";
+  return os;
+}
+
+} // namespace partitioning
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/partitioning/partitioningctx/PartitioningCtx.h b/core/partitioning/partitioningctx/PartitioningCtx.h
new file mode 100644
index 0000000000..ed8e705be5
--- /dev/null
+++ b/core/partitioning/partitioningctx/PartitioningCtx.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "core/partitioning/partitioninginfo/PartitioningInfo.h"
+#include "core/partitioning/segmentedblock/SegmentedBlock.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace partitioning {
+
+enum NodeExecutorDecision {
+  /// Node is not supported by TensorRT
+  kUNSUPPORTED,
+  /// Node is explicitly forced to fallback to PyTorch due to operator fallback
+  kOPERATOR_FALLBACK,
+  /// Node is explicitly forced to fallback to PyTorch due to module fallback
+  kMODULE_FALLBACK,
+  /// This node is in a TRT segment which does not satisfy min_block_size
+  /// and hence is forced to fallback.
+  kMIN_BLOCK_FALLBACK,
+  /// This node produces/consumes non-tensor values
+  kNON_TENSOR,
+  /// This node is going to be converted
+  kCONVERT,
+  /// Sentinel
+  kUNKNOWN,
+};
+
+std::ostream& operator<<(std::ostream& os, const NodeExecutorDecision& format);
+
+typedef std::unordered_map<torch::jit::Node*, NodeExecutorDecision> NodeExecutorDecisionMap;
+
+typedef std::vector<SegmentedBlock> PartitionedGraph;
+
+std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g);
+
+struct UsageInfo {
+  size_t produce_id; // id of segmented block which contains a raw value of a given torch::jit::Value
+  std::vector<size_t> torch_use_id; // ids of segmented blocks which are of type PyTorch
+  std::vector<size_t> tensorrt_use_id; // ids of segmented blocks which are of type TensorRT
+};
+
+struct PartitioningCtx {
+  // TODO: Make the set a part of the settings rather than standalone
+  PartitioningInfo settings;
+  // records all of the original blocks in the module in topological order
+  std::vector<torch::jit::Block*> original_blocks;
+  // mapping: node => executor decision
+  NodeExecutorDecisionMap node_executor_decision_map;
+  // LUT of the segmented blocks for each block in the module
+  std::unordered_map<torch::jit::Block*, PartitionedGraph> partitioned_blocks;
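+  // operator kinds the user has explicitly requested to run in Torch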
+  std::unordered_set<std::string> forced_fallback_ops;
+
+  PartitioningCtx(torch::jit::Block* b, PartitioningInfo info);
+  void setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorDecision decision);
+  bool shouldNodeRunInTorch(torch::jit::Node* n);
+  bool shouldNodeRunInTensorRT(torch::jit::Node* n);
+  std::vector<torch::jit::Node*> getNodesRunInTorch();
+
+ private:
+  void _load_nodes_into_decision_map(torch::jit::Block* b);
+};
+
+std::ostream& operator<<(std::ostream& os, const PartitioningCtx& s);
+
+} // namespace partitioning
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/partitioning/partitioninginfo/BUILD b/core/partitioning/partitioninginfo/BUILD
new file mode 100644
index 0000000000..74e34d134b
--- /dev/null
+++ b/core/partitioning/partitioninginfo/BUILD
@@ -0,0 +1,39 @@
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    },
+)
+
+cc_library(
+    name = "partitioninginfo",
+    srcs = [
+        "PartitioningInfo.cpp",
+    ],
+    hdrs = [
+        "PartitioningInfo.h",
+    ],
+    deps = [
+        "//core/util:prelude",
+        "//core/ir",
+        "//core/conversion",
+        "//core/lowering",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    alwayslink = True,
+)
+
+pkg_tar(
+    name = "include",
+    srcs = [
+        "PartitioningInfo.h",
+    ],
+    package_dir = "core/partitioning/partitioninginfo",
+)
diff --git a/core/partitioning/partitioninginfo/CMakeLists.txt b/core/partitioning/partitioninginfo/CMakeLists.txt
new file mode 100644
index 0000000000..86c7388daf
--- /dev/null
+++ b/core/partitioning/partitioninginfo/CMakeLists.txt
@@ -0,0 +1,12 @@
+set(sub_lib_name "partitioninginfo")
+
+target_sources(${lib_name}
+    PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningInfo.cpp"
+)
+
+set(HEADER_FILES
+    "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningInfo.h"
+)
+
+# Install headers
+install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/${sub_lib_name}")
diff --git a/core/partitioning/PartitionInfo.cpp b/core/partitioning/partitioninginfo/PartitioningInfo.cpp
similarity index 82%
rename from core/partitioning/PartitionInfo.cpp
rename to core/partitioning/partitioninginfo/PartitioningInfo.cpp
index 59e29a9bf1..16bdd7b9a7 100644
--- a/core/partitioning/PartitionInfo.cpp
+++ b/core/partitioning/partitioninginfo/PartitioningInfo.cpp
@@ -2,13 +2,13 @@
 #include <sstream>
 #include <utility>
 
-#include "core/partitioning/PartitionInfo.h"
+#include "core/partitioning/partitioninginfo/PartitioningInfo.h"
 
 namespace torch_tensorrt {
 namespace core {
 namespace partitioning {
 // clang-format off
-std::ostream& operator<<(std::ostream& os, const PartitionInfo& s) {
+std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s) {
   os << "Settings requested for Torch Fallback:" \
      << "\n    \"enabled\": ";
   if (s.enabled) {
diff --git a/core/partitioning/PartitionInfo.h b/core/partitioning/partitioninginfo/PartitioningInfo.h
similarity index 67%
rename from core/partitioning/PartitionInfo.h
rename to core/partitioning/partitioninginfo/PartitioningInfo.h
index dc63597912..8eb052e0fa 100644
--- a/core/partitioning/PartitionInfo.h
+++ b/core/partitioning/partitioninginfo/PartitioningInfo.h
@@ -4,18 +4,21 @@
 #include <string>
 #include <vector>
 
+#include "core/ir/ir.h"
+
 namespace torch_tensorrt {
 namespace core {
 namespace partitioning {
 
-struct PartitionInfo {
+struct PartitioningInfo {
+  ir::CollectionInputSpecMap collection_input_spec_map;
   bool enabled = false;
   uint64_t min_block_size = 1;
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
 };
 
-std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);
+std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s);
 
 } // namespace partitioning
 } // namespace core
diff --git a/core/partitioning/segmentedblock/BUILD b/core/partitioning/segmentedblock/BUILD
new file mode 100644
index 0000000000..8efe1e6b0a
--- /dev/null
+++ b/core/partitioning/segmentedblock/BUILD
@@ -0,0 +1,39 @@
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    },
+)
+
+cc_library(
+    name = "segmentedblock",
+    srcs = [
+        "SegmentedBlock.cpp",
+    ],
+    hdrs = [
+        "SegmentedBlock.h",
+    ],
+    deps = [
+        "//core/util:prelude",
+        "//core/ir",
+        "//core/conversion",
+        "//core/lowering",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    alwayslink = True,
+)
+
+pkg_tar(
+    name = "include",
+    srcs = [
+        "SegmentedBlock.h",
+    ],
+    package_dir = "core/partitioning/segmentedblock",
+)
diff --git a/core/partitioning/segmentedblock/CMakeLists.txt b/core/partitioning/segmentedblock/CMakeLists.txt
new file mode 100644
index 0000000000..ad6d9ee875
--- /dev/null
+++ b/core/partitioning/segmentedblock/CMakeLists.txt
@@ -0,0 +1,12 @@
+set(sub_lib_name "segmentedblock")
+
+target_sources(${lib_name}
+    PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.cpp"
+)
+
+set(HEADER_FILES
+    "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.h"
+)
+
+# Install headers
+install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/${sub_lib_name}")
diff --git a/core/partitioning/SegmentedBlock.cpp b/core/partitioning/segmentedblock/SegmentedBlock.cpp
similarity index 100%
rename from core/partitioning/SegmentedBlock.cpp
rename to core/partitioning/segmentedblock/SegmentedBlock.cpp
diff --git a/core/partitioning/SegmentedBlock.h b/core/partitioning/segmentedblock/SegmentedBlock.h
similarity index 98%
rename from core/partitioning/SegmentedBlock.h
rename to core/partitioning/segmentedblock/SegmentedBlock.h
index f7d8a0b612..0e04237f63 100644
--- a/core/partitioning/SegmentedBlock.h
+++ b/core/partitioning/segmentedblock/SegmentedBlock.h
@@ -5,7 +5,6 @@
 
 #include "NvInfer.h"
 #include "core/ir/ir.h"
-#include "core/partitioning/PartitionInfo.h"
 #include "torch/csrc/jit/ir/ir.h"
 
 namespace torch_tensorrt {
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index f940c87751..514681a088 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -1,9 +1,10 @@
-#include "core/partitioning/shape_analysis.h"
-#include <ATen/ATen.h>
-#include "core/util/prelude.h"
+#include "ATen/ATen.h"
 #include "torch/csrc/jit/api/module.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
 
+#include "core/partitioning/partitioning.h"
+#include "core/util/prelude.h"
+
 namespace torch_tensorrt {
 namespace core {
 namespace partitioning {
@@ -61,7 +62,7 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
-    const PartitionInfo& partition_info) {
+    const PartitioningInfo& partitioning_info) {
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();
@@ -151,13 +152,13 @@ void getSegmentsOutputByRunning(
       // shape inference
       auto cur_ivalue = ivalues_maps[i];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-      if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
+      if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
         TORCHTRT_THROW_ERROR(
             "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
-      } else if (partition_info.truncate_long_and_double && t == at::kLong) {
+      } else if (partitioning_info.truncate_long_and_double && t == at::kLong) {
         cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
         LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
-      } else if (partition_info.truncate_long_and_double && t == at::kDouble) {
+      } else if (partitioning_info.truncate_long_and_double && t == at::kDouble) {
         cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
@@ -180,14 +181,11 @@ void getSegmentsOutputByRunning(
   seg_block.register_intypes(input_types);
 }
 
-void runShapeAnalysis(
-    std::vector<SegmentedBlock>& segmented_blocks,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
-    const PartitionInfo& partition_info) {
+void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
   // register every segment's input shape, and it's running output IValues
-  for (auto& seg_block : segmented_blocks) {
+  for (auto& seg_block : ctx->partitioned_blocks[block]) {
     torch::jit::ConstantPooling(seg_block.g());
-    getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info);
+    getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
   }
   return;
 }
diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h
deleted file mode 100644
index 780449d514..0000000000
--- a/core/partitioning/shape_analysis.h
+++ /dev/null
@@ -1,20 +0,0 @@
-#include "core/ir/ir.h"
-#include "core/partitioning/SegmentedBlock.h"
-#include "torch/csrc/jit/ir/ir.h"
-
-namespace torch_tensorrt {
-namespace core {
-namespace partitioning {
-
-std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
-    std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& input_ranges,
-    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& input_types);
-
-void runShapeAnalysis(
-    std::vector<SegmentedBlock>& segmented_blocks,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
-    const PartitionInfo& partition_info);
-
-} // namespace partitioning
-} // namespace core
-} // namespace torch_tensorrt
diff --git a/core/partitioning/stitching.cpp b/core/partitioning/stitching.cpp
new file mode 100644
index 0000000000..6ed5a27463
--- /dev/null
+++ b/core/partitioning/stitching.cpp
@@ -0,0 +1,151 @@
+#include "ATen/ATen.h"
+#include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/ir/ir_views.h"
+
+#include "core/partitioning/partitioning.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace partitioning {
+
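+// Clone the nodes of a segmented block into the stitched global graph g, remapping the
+// block's inputs and outputs through old_to_new_g.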
+void addSegmentedBlockToGraph(
+    std::shared_ptr<torch::jit::Graph>& g,
+    partitioning::SegmentedBlock& seg,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+  // old_to_new_g maps original global graph values => new global graph values;
+  // mini_to_new_g maps mini graph values => new graph values
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
+  size_t input_idx = 0;
+  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
+    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+      auto self = g->insertInput(0, "self_1");
+      self->setType(seg.inputs()[0]->type());
+    }
+    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
+  }
+
+  for (auto& raw_input : seg.raw_inputs()) {
+    if (old_to_new_g.count(raw_input)) {
+      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
+    }
+  }
+
+  for (const auto n : seg.nodes()) {
+    util::cloneNode(n, g, mini_to_new_g);
+  }
+
+  // original graph value => new global graph value
+  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
+    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
+  }
+  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
+  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
+    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
+      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
+    }
+  }
+
+  return;
+}
+
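+// Rebuild a prim::If node in new_g from the already-stitched branch graphs, cloning each
+// branch block and remapping its values into the global graph.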
+void addIfBlockToGraph(
+    std::shared_ptr<torch::jit::Graph>& new_g,
+    torch::jit::Node* if_node,
+    const std::vector<GraphAndMapping>& graph_and_mappings,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+  torch::jit::IfView if_view(if_node);
+
+  // create a new if node in new_g and add corresponding inputs
+  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
+  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
+
+  // iterate over all blocks and add them to the newly created prim::If
+  for (auto graph_and_mapping : graph_and_mappings) {
+    auto new_if_block = new_if->addBlock();
+    auto cur_block_graph = graph_and_mapping.first;
+    auto cur_block_mapping = graph_and_mapping.second;
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
+    for (auto& i : cur_block_mapping) {
+      // for every pair (old_value => mini graph value) in cur_block_mapping, if old_value also appears in
+      // old_to_new_g, then it is an input of the mini graph
+      if (old_to_new_g.count(i.first)) {
+        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
+      }
+    }
+
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
+    new_if_block->cloneFrom(cur_block_graph->block(), env);
+    if (cur_block_graph->inputs().size() &&
+        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+        auto self = new_g->insertInput(0, "self_1");
+        self->setType(cur_block_graph->inputs()[0]->type());
+      }
+      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
+    }
+    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
+      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
+      new_if_block->eraseInput(i);
+    }
+  }
+  for (auto ov : if_view.outputs()) {
+    auto no = new_if->addOutput();
+    old_to_new_g[ov] = no;
+    no->copyMetadata(ov);
+  }
+  return;
+}
+
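+// Stitch the partitioned segments of a block back into a single TorchScript graph,
+// recursing into prim::If nodes that fall back to Torch.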
+GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block) {
+  auto new_g = std::make_shared<torch::jit::Graph>();
+
+  // the mapping from the lowered graph => the stitched fallback global graph
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
+  for (auto input : block->inputs()) {
+    util::getOrAddInputForValue(input, new_g, old_to_new_g);
+  }
+
+  for (auto seg_block : ctx->partitioned_blocks[block]) {
+    LOG_INFO("Block segment:" << seg_block);
+    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      addSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+    } else {
+      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
+        auto if_node = seg_block.raw_nodes()[0];
+
+        // convert the two blocks of the prim::If and get the converted graphs with their mappings
+        std::vector<GraphAndMapping> graph_and_mappings;
+        for (auto cur_block : if_node->blocks()) {
+          graph_and_mappings.push_back(stitch(ctx, cur_block));
+        }
+        addIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+
+      } else {
+        addSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+      }
+    }
+  }
+
+  if (block->outputs().size() > 1) {
+    std::vector<torch::jit::Value*> fallback_graph_vector;
+    for (auto& output : block->outputs()) {
+      if (old_to_new_g.count(output)) {
+        fallback_graph_vector.push_back(old_to_new_g[output]);
+      }
+    }
+    torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
+    auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
+    new_g->block()->appendNode(return_tuple_node);
+    // Set the output as the produced tuple
+    new_g->registerOutput(return_tuple_node->outputs()[0]);
+  } else {
+    if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
+      new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
+    }
+  }
+  return {new_g, old_to_new_g};
+}
+} // namespace partitioning
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index cfbc228396..3d7d9b15d3 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -121,10 +121,10 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
       "require_full_compilation is enabled however the list of modules to run in torch is not empty (Found "
           << external.torch_executed_modules.size() << " modules)");
 
-  internal.partition_info.enabled = !external.require_full_compilation;
-  internal.partition_info.min_block_size = external.min_block_size;
-  internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops);
-  internal.partition_info.truncate_long_and_double = external.truncate_long_and_double;
+  internal.partitioning_info.enabled = !external.require_full_compilation;
+  internal.partitioning_info.min_block_size = external.min_block_size;
+  internal.partitioning_info.forced_fallback_operators = std::move(external.torch_executed_ops);
+  internal.partitioning_info.truncate_long_and_double = external.truncate_long_and_double;
   internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules);
 
   switch (external.device.device_type) {
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index 96fef793fd..1721ffd6c9 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -313,10 +313,10 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
-  info.partition_info.enabled = torch_fallback.enabled;
-  info.partition_info.min_block_size = torch_fallback.min_block_size;
-  info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
-  info.partition_info.truncate_long_and_double = truncate_long_and_double;
+  info.partitioning_info.enabled = torch_fallback.enabled;
+  info.partitioning_info.min_block_size = torch_fallback.min_block_size;
+  info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
+  info.partitioning_info.truncate_long_and_double = truncate_long_and_double;
   info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
   info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
 
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index b33685a647..75ae818905 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -75,6 +75,10 @@ lowering_test(
     name = "test_silu_to_sigmoid_multiplication",
 )
 
+lowering_test(
+    name = "test_unpack_hardsigmoid",
+)
+
 lowering_test(
     name = "test_unpack_hardswish",
 )
@@ -98,6 +102,7 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_remove_dropout_pass",
         ":test_remove_unnecessary_casts",
+        ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",
         ":test_view_to_reshape_pass",
diff --git a/tests/core/lowering/test_module_fallback_passes.cpp b/tests/core/lowering/test_module_fallback_passes.cpp
index e6eb098079..5f4ac5f0c2 100644
--- a/tests/core/lowering/test_module_fallback_passes.cpp
+++ b/tests/core/lowering/test_module_fallback_passes.cpp
@@ -100,7 +100,7 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
 
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 1, 16, 16})};
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
diff --git a/tests/core/lowering/test_unpack_hardsigmoid.cpp b/tests/core/lowering/test_unpack_hardsigmoid.cpp
new file mode 100644
index 0000000000..f8206511be
--- /dev/null
+++ b/tests/core/lowering/test_unpack_hardsigmoid.cpp
@@ -0,0 +1,87 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, UnpackHardSigmoid) {
+  std::string source_graph = R"IR(
+        graph(%input):
+            %result = aten::hardsigmoid(%input)
+            return (%result))IR";
+
+  std::string target_graph = R"IR(
+        graph(%x.1):
+            %22 : float = prim::Constant[value=0.5]()
+            %3 : int = prim::Constant[value=6]()
+            %5 : int = prim::Constant[value=1]()
+            %10 : int = prim::Constant[value=0]()
+            %4 : Tensor = aten::div(%x.1, %3)
+            %9 : Tensor = aten::add(%4, %22, %5)
+            %21 : Tensor = aten::clamp(%9, %10, %5)
+            return (%21))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  auto in = at::rand({10, 100}, {at::kCUDA});
+  auto sg_params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto sg_results = torch_tensorrt::tests::util::RunGraph(sg, sg_params, {in});
+
+  torch_tensorrt::core::lowering::passes::UnpackHardSigmoid(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+
+  in = at::clone(in);
+  auto tg_params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto tg_results = torch_tensorrt::tests::util::RunGraph(tg, tg_params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
+}
+
+TEST(LoweringPasses, UnpackHardSigmoidInPlace) {
+  std::string source_graph = R"IR(
+        graph(%input):
+            %result = aten::hardsigmoid_(%input)
+            return (%result))IR";
+
+  std::string target_graph = R"IR(
+        graph(%x.1):
+            %22 : float = prim::Constant[value=0.5]()
+            %3 : int = prim::Constant[value=6]()
+            %5 : int = prim::Constant[value=1]()
+            %10 : int = prim::Constant[value=0]()
+            %4 : Tensor = aten::div(%x.1, %3)
+            %9 : Tensor = aten::add(%4, %22, %5)
+            %21 : Tensor = aten::clamp(%9, %10, %5)
+            return (%21))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  auto in = at::rand({10, 100}, {at::kCUDA});
+  auto sg_params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto sg_results = torch_tensorrt::tests::util::RunGraph(sg, sg_params, {in});
+
+  torch_tensorrt::core::lowering::passes::UnpackHardSigmoid(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+
+  in = at::clone(in);
+  auto tg_params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto tg_results = torch_tensorrt::tests::util::RunGraph(tg, tg_params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
+}
diff --git a/tests/core/lowering/test_view_to_reshape_pass.cpp b/tests/core/lowering/test_view_to_reshape_pass.cpp
index d1f787bc10..a6254bccde 100644
--- a/tests/core/lowering/test_view_to_reshape_pass.cpp
+++ b/tests/core/lowering/test_view_to_reshape_pass.cpp
@@ -66,8 +66,8 @@ TEST(LoweringPasses, ViewToReshapeResultsCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({2, 3, 4, 5}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
-  cfg.partition_info.forced_fallback_operators.push_back("aten::permute");
+  cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::permute");
 
   torch::jit::script::Module mod(c10::QualifiedName("module"));
 
diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp
index 424fac86e0..ba336db663 100644
--- a/tests/core/partitioning/test_conditionals.cpp
+++ b/tests/core/partitioning/test_conditionals.cpp
@@ -34,7 +34,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({3, 3, 16, 16})};
   auto g = mod.get_method("forward").graph();
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
   auto new_g = new_mod.get_method("forward").graph();
 
@@ -65,8 +65,8 @@ TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
       torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})};
   auto g = mod.get_method("forward").graph();
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
-  cfg.partition_info.forced_fallback_operators.push_back("prim::ListConstruct");
+  cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("prim::ListConstruct");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
diff --git a/tests/core/partitioning/test_fallback_graph_output.cpp b/tests/core/partitioning/test_fallback_graph_output.cpp
index 3da717074a..f6ce657ae3 100644
--- a/tests/core/partitioning/test_fallback_graph_output.cpp
+++ b/tests/core/partitioning/test_fallback_graph_output.cpp
@@ -28,8 +28,8 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
 
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
-  cfg.partition_info.forced_fallback_operators.push_back("aten::add");
+  cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::add");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
@@ -58,8 +58,8 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
   auto g = mod.get_method("forward").graph();
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
-  cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
+  cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::hardtanh");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
diff --git a/tests/core/partitioning/test_loading_model.cpp b/tests/core/partitioning/test_loading_model.cpp
index 057aaff2d8..b42368fe3e 100644
--- a/tests/core/partitioning/test_loading_model.cpp
+++ b/tests/core/partitioning/test_loading_model.cpp
@@ -28,7 +28,7 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
 
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
diff --git a/tests/core/partitioning/test_loop_fallback.cpp b/tests/core/partitioning/test_loop_fallback.cpp
index 83556b5512..5f6bc2ae4d 100644
--- a/tests/core/partitioning/test_loop_fallback.cpp
+++ b/tests/core/partitioning/test_loop_fallback.cpp
@@ -25,7 +25,7 @@ TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) {
 
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 10})};
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
@@ -53,7 +53,7 @@ TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
 
   std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 10})};
   torch_tensorrt::core::CompileSpec cfg(input_ranges);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
index 30656a3d9e..950859e524 100644
--- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
+++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
@@ -60,10 +60,10 @@ TEST(Partitioning, ResolveNonTensorInputsForIFBlockCorrectly) {
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 4}));
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 4}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
-  cfg.partition_info.forced_fallback_operators.push_back("aten::sub");
+  cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::sub");
   cfg.convert_info.engine_settings.truncate_long_and_double = true;
-  cfg.partition_info.truncate_long_and_double = true;
+  cfg.partitioning_info.truncate_long_and_double = true;
 
   torch::jit::script::Module mod(c10::QualifiedName("module"));
 
@@ -109,8 +109,8 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({1, 3, 16, 16}));
   inputs.push_back(torch_tensorrt::core::ir::Input({16, 3, 3, 3}));
@@ -123,9 +123,10 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);
+      ctx.partitioned_blocks.begin()->second;
 
   int torch_block_cnt = 0, trt_block_cnt = 0;
   for (const auto& segmented_block : segmented_blocks) {
@@ -168,8 +169,8 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({1, 3, 16, 16}));
   inputs.push_back(torch_tensorrt::core::ir::Input({16, 6, 3, 3}));
@@ -182,9 +183,11 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);
+      ctx.partitioned_blocks.begin()->second;
 
   int torch_block_cnt = 0, trt_block_cnt = 0;
   for (const auto& segmented_block : segmented_blocks) {
@@ -244,7 +247,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({1, 3, 16, 16}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   torch::jit::script::Module mod(c10::QualifiedName("module"));
 
   auto self = g->insertInput(0, "self_1");
@@ -361,8 +364,8 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
   g->registerOutput(get_ins_node->output());
   g->registerOutput(get_outs_node->output());
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
   inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
@@ -374,9 +377,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  auto segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
 
   int torch_block_cnt = 0, trt_block_cnt = 0;
   for (const auto& segmented_block : segmented_blocks) {
diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp
index bf8a36d081..8d47af553e 100644
--- a/tests/core/partitioning/test_segmentation.cpp
+++ b/tests/core/partitioning/test_segmentation.cpp
@@ -6,9 +6,14 @@
 #include "torch/script.h"
 #include "torch_tensorrt/torch_tensorrt.h"
 
+namespace torch_tensorrt {
+namespace core {
+namespace partitioning {
+namespace tests {
+
 bool checkSegmentedBlockNumber(
-    torch_tensorrt::core::partitioning::PartitionedGraph& segmented_blocks,
-    torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget target,
+    PartitionedGraph& segmented_blocks,
+    SegmentedBlock::SegmentedBlockTarget target,
     int target_count) {
   int64_t cnt = 0;
   for (auto& seg_block : segmented_blocks) {
@@ -27,7 +32,7 @@ bool checkSegmentedBlockNumber(
 }
 
 bool checkSegmentedBlockNodesMapping(
-    std::vector<torch_tensorrt::core::partitioning::SegmentedBlock>& segmented_blocks,
+    std::vector<SegmentedBlock>& segmented_blocks,
     std::shared_ptr<torch::jit::Graph> g,
     std::vector<std::vector<int>> nodes_index) {
   std::vector<torch::jit::Node*> graph_nodes;
@@ -71,17 +76,15 @@ TEST(Partitioning, SegmentSequentialModelCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2));
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3}, {4}}));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3}, {4}}));
 }
 
 TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
@@ -106,18 +109,16 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  partition_info.min_block_size = 3;
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.min_block_size = 3;
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4}}));
 }
 
 TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
@@ -146,18 +147,16 @@ TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  partition_info.min_block_size = 3;
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.min_block_size = 3;
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
 }
 
 TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
@@ -182,18 +181,16 @@ TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  partition_info.forced_fallback_operators.push_back("aten::relu");
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3));
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 2));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0}, {1}, {2}, {3}, {4}}));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators.push_back("aten::relu");
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0}, {1}, {2}, {3}, {4}}));
 }
 
 TEST(Partitioning, SegmentBranchModelCorrectly) {
@@ -219,17 +216,15 @@ TEST(Partitioning, SegmentBranchModelCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2));
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3, 4, 5, 6}}));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3, 4, 5, 6}}));
 }
 
 TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) {
@@ -255,18 +250,16 @@ TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  partition_info.min_block_size = 3;
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4, 5, 6}}));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.min_block_size = 3;
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4, 5, 6}}));
 }
 
 TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) {
@@ -296,16 +289,21 @@ TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) {
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
+  LOG_GRAPH(*g);
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
-  partition_info.forced_fallback_operators.push_back("aten::relu");
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  torch_tensorrt::core::partitioning::PartitionedGraph segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
-  ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3));
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators.push_back("aten::relu");
+  PartitioningCtx ctx(g->block(), partitioning_info);
+
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3));
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2));
   ASSERT_TRUE(
-      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 2));
-  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3}, {4}, {5, 6}}));
+      checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3}, {4}, {5, 6}}));
 }
+
+} // namespace tests
+} // namespace partitioning
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp
index 98b375f121..87c42c0e47 100644
--- a/tests/core/partitioning/test_shape_analysis.cpp
+++ b/tests/core/partitioning/test_shape_analysis.cpp
@@ -48,8 +48,8 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16}));
   inputs.push_back(torch_tensorrt::core::ir::Input({32, 3, 3, 3}));
@@ -66,9 +66,10 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
 
   ASSERT_TRUE(checkSegmentedBlockInputShape(
       segmented_blocks,
@@ -101,8 +102,8 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
-  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
-  partition_info.enabled = true;
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16}));
   inputs.push_back(torch_tensorrt::core::ir::Input({32, 3, 3, 3}));
@@ -117,9 +118,10 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
 
   ASSERT_TRUE(checkSegmentedBlockInputShape(
       segmented_blocks,
diff --git a/tests/core/partitioning/test_stitched_graph.cpp b/tests/core/partitioning/test_stitched_graph.cpp
index 61c5b58552..4332668506 100644
--- a/tests/core/partitioning/test_stitched_graph.cpp
+++ b/tests/core/partitioning/test_stitched_graph.cpp
@@ -75,7 +75,7 @@ TEST(Partitioning, StitchSequentialModelSegmentedBlockCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
   auto fallback_g = new_mod.get_method("forward").graph();
   ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
@@ -133,7 +133,7 @@ TEST(Partitioning, StitchBranchModelSegmentedBlockCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
   auto fallback_g = new_mod.get_method("forward").graph();
   ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
diff --git a/tests/core/partitioning/test_tensorrt_conversion.cpp b/tests/core/partitioning/test_tensorrt_conversion.cpp
index 8b42f95e24..41431c76db 100644
--- a/tests/core/partitioning/test_tensorrt_conversion.cpp
+++ b/tests/core/partitioning/test_tensorrt_conversion.cpp
@@ -57,7 +57,7 @@ TEST(Partitioning, ConvertSequentialModelSegmentedBlockCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   torch::jit::script::Module mod(c10::QualifiedName("module"));
 
   auto self = g->insertInput(0, "self_1");
@@ -116,7 +116,7 @@ TEST(Partitioning, ConvertBranchModelSegmentedBlockCorrectly) {
   std::vector<torch_tensorrt::core::ir::Input> inputs;
   inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16}));
   torch_tensorrt::core::CompileSpec cfg(inputs);
-  cfg.partition_info.enabled = true;
+  cfg.partitioning_info.enabled = true;
   torch::jit::script::Module mod(c10::QualifiedName("module"));
 
   auto self = g->insertInput(0, "self_1");