diff --git a/mobile/examples/phi-3/ios/LocalLLM/IMG_1014.jpg b/mobile/examples/phi-3/ios/LocalLLM/IMG_1014.jpg new file mode 100644 index 000000000..a8f357d27 Binary files /dev/null and b/mobile/examples/phi-3/ios/LocalLLM/IMG_1014.jpg differ diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM.xcodeproj/project.pbxproj b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM.xcodeproj/project.pbxproj index d739f813d..944f6e931 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM.xcodeproj/project.pbxproj +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM.xcodeproj/project.pbxproj @@ -3,20 +3,20 @@ archiveVersion = 1; classes = { }; - objectVersion = 56; + objectVersion = 70; objects = { /* Begin PBXBuildFile section */ - 5156483D2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = 5156483C2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib */; }; - 5156483E2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib in Embed Libraries */ = {isa = PBXBuildFile; fileRef = 5156483C2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib */; settings = {ATTRIBUTES = (CodeSignOnCopy, ); }; }; 51D4C8D62BFD22D70029FCEA /* LocalLLMApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 51D4C8D52BFD22D70029FCEA /* LocalLLMApp.swift */; }; 51D4C8D82BFD22D70029FCEA /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 51D4C8D72BFD22D70029FCEA /* ContentView.swift */; }; 51D4C8DA2BFD22DB0029FCEA /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 51D4C8D92BFD22DB0029FCEA /* Assets.xcassets */; }; 51D4C8DD2BFD22DB0029FCEA /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 51D4C8DC2BFD22DB0029FCEA /* Preview Assets.xcassets */; }; - 51D4C9072BFD26150029FCEA /* libonnxruntime-genai.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = 51D4C9052BFD26150029FCEA /* libonnxruntime-genai.dylib */; }; - 51D4C9082BFD26EB0029FCEA /* libonnxruntime-genai.dylib in Embed Libraries */ = {isa = PBXBuildFile; fileRef = 
51D4C9052BFD26150029FCEA /* libonnxruntime-genai.dylib */; settings = {ATTRIBUTES = (CodeSignOnCopy, ); }; }; 51D4C90E2BFD28DD0029FCEA /* GenAIGenerator.mm in Sources */ = {isa = PBXBuildFile; fileRef = 51D4C90D2BFD28DD0029FCEA /* GenAIGenerator.mm */; }; 51D4C9232BFD507A0029FCEA /* SharedTokenUpdater.swift in Sources */ = {isa = PBXBuildFile; fileRef = 51D4C9222BFD50790029FCEA /* SharedTokenUpdater.swift */; }; + 8A4D13D82CE2B1AE002BD11A /* libonnxruntime-genai.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = 51D4C9052BFD26150029FCEA /* libonnxruntime-genai.dylib */; }; + 8A4D13DF2CE2B1BA002BD11A /* libonnxruntime-genai.dylib in Embed Libraries */ = {isa = PBXBuildFile; fileRef = 51D4C9052BFD26150029FCEA /* libonnxruntime-genai.dylib */; settings = {ATTRIBUTES = (CodeSignOnCopy, ); }; }; + 8A53DB0C2DAF08B3001D41D1 /* onnxruntime.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 8A53DB0B2DAF08B3001D41D1 /* onnxruntime.framework */; }; + 8A53DB0D2DAF08D0001D41D1 /* onnxruntime.framework in Embed Libraries */ = {isa = PBXBuildFile; fileRef = 8A53DB0B2DAF08B3001D41D1 /* onnxruntime.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -26,8 +26,8 @@ dstPath = ""; dstSubfolderSpec = 10; files = ( - 5156483E2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib in Embed Libraries */, - 51D4C9082BFD26EB0029FCEA /* libonnxruntime-genai.dylib in Embed Libraries */, + 8A53DB0D2DAF08D0001D41D1 /* onnxruntime.framework in Embed Libraries */, + 8A4D13DF2CE2B1BA002BD11A /* libonnxruntime-genai.dylib in Embed Libraries */, ); name = "Embed Libraries"; runOnlyForDeploymentPostprocessing = 0; @@ -44,7 +44,6 @@ /* End PBXCopyFilesBuildPhase section */ /* Begin PBXFileReference section */ - 5156483C2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; name = libonnxruntime.1.19.0.dylib; path = 
LocalLLM/lib/libonnxruntime.1.19.0.dylib; sourceTree = ""; }; 51D4C8D22BFD22D70029FCEA /* LocalLLM.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = LocalLLM.app; sourceTree = BUILT_PRODUCTS_DIR; }; 51D4C8D52BFD22D70029FCEA /* LocalLLMApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalLLMApp.swift; sourceTree = ""; }; 51D4C8D72BFD22D70029FCEA /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; @@ -57,25 +56,38 @@ 51D4C90B2BFD28BF0029FCEA /* GenAIGenerator.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = GenAIGenerator.h; sourceTree = ""; }; 51D4C90C2BFD28DD0029FCEA /* LocalLLM-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "LocalLLM-Bridging-Header.h"; sourceTree = ""; }; 51D4C90D2BFD28DD0029FCEA /* GenAIGenerator.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = GenAIGenerator.mm; sourceTree = ""; }; - 51D4C9102BFD483E0029FCEA /* tokenizer.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; name = tokenizer.json; path = ../GenAIApp/tokenizer.json; sourceTree = ""; }; - 51D4C9112BFD483E0029FCEA /* phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx */ = {isa = PBXFileReference; lastKnownFileType = file; name = "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx"; path = "../GenAIApp/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx"; sourceTree = ""; }; - 51D4C9122BFD483E0029FCEA /* special_tokens_map.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; name = special_tokens_map.json; path = ../GenAIApp/special_tokens_map.json; sourceTree = ""; }; - 51D4C9132BFD483E0029FCEA /* phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data */ = {isa = PBXFileReference; lastKnownFileType = file; name = 
"phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data"; path = "../GenAIApp/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data"; sourceTree = ""; }; - 51D4C9142BFD483E0029FCEA /* tokenizer_config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; name = tokenizer_config.json; path = ../GenAIApp/tokenizer_config.json; sourceTree = ""; }; - 51D4C91A2BFD48490029FCEA /* config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; name = config.json; path = ../GenAIApp/config.json; sourceTree = ""; }; - 51D4C91B2BFD48490029FCEA /* added_tokens.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; name = added_tokens.json; path = ../GenAIApp/added_tokens.json; sourceTree = ""; }; - 51D4C91C2BFD48490029FCEA /* configuration_phi3.py */ = {isa = PBXFileReference; lastKnownFileType = text.script.python; name = configuration_phi3.py; path = ../GenAIApp/configuration_phi3.py; sourceTree = ""; }; - 51D4C91D2BFD48490029FCEA /* genai_config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; name = genai_config.json; path = ../GenAIApp/genai_config.json; sourceTree = ""; }; 51D4C9222BFD50790029FCEA /* SharedTokenUpdater.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SharedTokenUpdater.swift; sourceTree = ""; }; + 8A53DB0B2DAF08B3001D41D1 /* onnxruntime.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; path = onnxruntime.framework; sourceTree = ""; }; + 8A869B212CDAD08600AE0604 /* LocalLLM.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = LocalLLM.entitlements; sourceTree = ""; }; + 8ABA12972CC1D15C006B3DDF /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = ""; }; /* End PBXFileReference section */ +/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */ + 8AC409CD2DADE0EC00388525 /* 
PBXFileSystemSynchronizedBuildFileExceptionSet */ = { + isa = PBXFileSystemSynchronizedBuildFileExceptionSet; + membershipExceptions = ( + "LLama-3.2-1B-int4-acc_4-gqa-webgpu-fp16/genai_config.json", + "LLama-3.2-1B-int4-acc_4-gqa-webgpu-fp16/model.onnx", + "LLama-3.2-1B-int4-acc_4-gqa-webgpu-fp16/model.onnx.data", + "LLama-3.2-1B-int4-acc_4-gqa-webgpu-fp16/special_tokens_map.json", + "LLama-3.2-1B-int4-acc_4-gqa-webgpu-fp16/tokenizer_config.json", + "LLama-3.2-1B-int4-acc_4-gqa-webgpu-fp16/tokenizer.json", + ); + target = 51D4C8D12BFD22D70029FCEA /* LocalLLM */; + }; +/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */ + +/* Begin PBXFileSystemSynchronizedRootGroup section */ + 8ABA12A22CC300ED006B3DDF /* model */ = {isa = PBXFileSystemSynchronizedRootGroup; exceptions = (8AC409CD2DADE0EC00388525 /* PBXFileSystemSynchronizedBuildFileExceptionSet */, ); explicitFileTypes = {}; explicitFolders = (); path = model; sourceTree = ""; }; +/* End PBXFileSystemSynchronizedRootGroup section */ + /* Begin PBXFrameworksBuildPhase section */ 51D4C8CF2BFD22D70029FCEA /* Frameworks */ = { isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - 5156483D2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib in Frameworks */, - 51D4C9072BFD26150029FCEA /* libonnxruntime-genai.dylib in Frameworks */, + 8A53DB0C2DAF08B3001D41D1 /* onnxruntime.framework in Frameworks */, + 8A4D13D82CE2B1AE002BD11A /* libonnxruntime-genai.dylib in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -85,7 +97,6 @@ 5156483B2BFDBB6E005CA50C /* Frameworks */ = { isa = PBXGroup; children = ( - 5156483C2BFDBB6F005CA50C /* libonnxruntime.1.19.0.dylib */, ); name = Frameworks; sourceTree = ""; @@ -93,15 +104,7 @@ 51D4C8C92BFD22D70029FCEA = { isa = PBXGroup; children = ( - 51D4C91B2BFD48490029FCEA /* added_tokens.json */, - 51D4C91A2BFD48490029FCEA /* config.json */, - 51D4C91C2BFD48490029FCEA /* configuration_phi3.py */, - 51D4C91D2BFD48490029FCEA /* genai_config.json */, - 
51D4C9112BFD483E0029FCEA /* phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx */, - 51D4C9132BFD483E0029FCEA /* phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data */, - 51D4C9122BFD483E0029FCEA /* special_tokens_map.json */, - 51D4C9142BFD483E0029FCEA /* tokenizer_config.json */, - 51D4C9102BFD483E0029FCEA /* tokenizer.json */, + 8ABA12A22CC300ED006B3DDF /* model */, 51D4C8D42BFD22D70029FCEA /* LocalLLM */, 51D4C8D32BFD22D70029FCEA /* Products */, 5156483B2BFDBB6E005CA50C /* Frameworks */, @@ -119,6 +122,8 @@ 51D4C8D42BFD22D70029FCEA /* LocalLLM */ = { isa = PBXGroup; children = ( + 8A869B212CDAD08600AE0604 /* LocalLLM.entitlements */, + 8ABA12972CC1D15C006B3DDF /* Info.plist */, 51D4C9032BFD25BA0029FCEA /* lib */, 51D4C8FF2BFD25890029FCEA /* header */, 51D4C8D52BFD22D70029FCEA /* LocalLLMApp.swift */, @@ -154,6 +159,7 @@ 51D4C9032BFD25BA0029FCEA /* lib */ = { isa = PBXGroup; children = ( + 8A53DB0B2DAF08B3001D41D1 /* onnxruntime.framework */, 51D4C9052BFD26150029FCEA /* libonnxruntime-genai.dylib */, ); path = lib; @@ -189,7 +195,7 @@ attributes = { BuildIndependentTargetsInParallel = 1; LastSwiftUpdateCheck = 1520; - LastUpgradeCheck = 1520; + LastUpgradeCheck = 1600; TargetAttributes = { 51D4C8D12BFD22D70029FCEA = { CreatedOnToolsVersion = 15.2; @@ -245,7 +251,6 @@ 51D4C8F42BFD22DC0029FCEA /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { - ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES; ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; CLANG_ANALYZER_NONNULL = YES; @@ -277,9 +282,10 @@ CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEFINES_MODULE = YES; EMBED_ASSET_PACKS_IN_PRODUCT_BUNDLE = YES; + ENABLE_MODULE_VERIFIER = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; @@ -299,7 +305,9 @@ 
GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 16.6; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + "MTL_ENABLE_DEBUG_INFO[sdk=iphoneos*]" = NO; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; SDKROOT = iphoneos; @@ -317,7 +325,6 @@ 51D4C8F52BFD22DC0029FCEA /* Release */ = { isa = XCBuildConfiguration; buildSettings = { - ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES; ALWAYS_SEARCH_USER_PATHS = NO; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; CLANG_ANALYZER_NONNULL = YES; @@ -352,6 +359,7 @@ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEFINES_MODULE = YES; EMBED_ASSET_PACKS_IN_PRODUCT_BUNDLE = YES; + ENABLE_MODULE_VERIFIER = YES; ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; @@ -365,6 +373,7 @@ GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 16.6; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; SDKROOT = iphoneos; @@ -385,12 +394,20 @@ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CLANG_ENABLE_MODULES = YES; + CODE_SIGN_ENTITLEMENTS = LocalLLM/LocalLLM.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEVELOPMENT_ASSET_PATHS = "\"LocalLLM/Preview Content\""; - DEVELOPMENT_TEAM = UBF8T346G9; + DEVELOPMENT_TEAM = AL4J766FY4; ENABLE_PREVIEWS = YES; + FRAMEWORK_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/LocalLLM/lib", + ); GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_FILE = LocalLLM/Info.plist; INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; INFOPLIST_KEY_UILaunchScreen_Generation = YES; @@ -407,6 +424,7 @@ 
MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = ai.onnxruntime.genai.demo.LocalLLM; PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_OBJC_BRIDGING_HEADER = "LocalLLM/LocalLLM-Bridging-Header.h"; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; @@ -421,12 +439,20 @@ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CLANG_ENABLE_MODULES = YES; + CODE_SIGN_ENTITLEMENTS = LocalLLM/LocalLLM.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEVELOPMENT_ASSET_PATHS = "\"LocalLLM/Preview Content\""; - DEVELOPMENT_TEAM = UBF8T346G9; + DEVELOPMENT_TEAM = AL4J766FY4; ENABLE_PREVIEWS = YES; + FRAMEWORK_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/LocalLLM/lib", + ); GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_FILE = LocalLLM/Info.plist; INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; INFOPLIST_KEY_UILaunchScreen_Generation = YES; @@ -441,8 +467,9 @@ "$(PROJECT_DIR)/LocalLLM/lib", ); MARKETING_VERSION = 1.0; - PRODUCT_BUNDLE_IDENTIFIER = ai.onnxruntime.genai.demo.LocalLLM; + PRODUCT_BUNDLE_IDENTIFIER = ai.onnxruntime.genai.demo.LocalLLM; PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_OBJC_BRIDGING_HEADER = "LocalLLM/LocalLLM-Bridging-Header.h"; SWIFT_VERSION = 5.0; diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift index 0777e2ff5..147425bd9 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift @@ -3,155 +3,180 @@ import SwiftUI - struct Message: Identifiable { - let id = UUID() - var text: String - let isUser: Bool + let id = UUID() + var text:
String + let isUser: Bool } struct ContentView: View { - @State private var userInput: String = "" - @State private var messages: [Message] = [] // Store chat messages locally - @State private var isGenerating: Bool = false // Track token generation state - @State private var stats: String = "" // token generation stats - @State private var showAlert: Bool = false - @State private var errorMessage: String = "" - - private let generator = GenAIGenerator() - - var body: some View { - VStack { - // ChatBubbles - ScrollView { - VStack(alignment: .leading, spacing: 20) { - ForEach(messages) { message in - ChatBubble(text: message.text, isUser: message.isUser) - .padding(.horizontal, 20) - } - if !stats.isEmpty { - Text(stats) - .font(.footnote) - .foregroundColor(.gray) - .padding(.horizontal, 20) - .padding(.top, 5) - .multilineTextAlignment(.center) - } - } - .padding(.top, 20) - } + @State private var userInput: String = "" + @State private var messages: [Message] = [] // Store chat messages locally + @State private var isGenerating: Bool = false // Track token generation state + @State private var stats: String = "" // token generation stats + @State private var showAlert: Bool = false + @State private var errorMessage: String = "" + @State private var showFolderPicker: Bool = false // State for folder picker sheet - - // User input - HStack { - TextField("Type your message...", text: $userInput) - .padding() - .background(Color(.systemGray6)) - .cornerRadius(20) - .padding(.horizontal) - - Button(action: { - // Check for non-empty input - guard !userInput.trimmingCharacters(in: .whitespaces).isEmpty else { return } - - messages.append(Message(text: userInput, isUser: true)) - messages.append(Message(text: "", isUser: false)) // Placeholder for AI response - - - // clear previously generated tokens - SharedTokenUpdater.shared.clearTokens() - - let prompt = userInput - userInput = "" - isGenerating = true - - - DispatchQueue.global(qos: .background).async { - 
generator.generate(prompt) - } - }) { - Image(systemName: "paperplane.fill") - .foregroundColor(.white) - .padding() - .background(isGenerating ? Color.gray : Color.pastelGreen) - .clipShape(Circle()) - .padding(.trailing, 10) - } - .disabled(isGenerating) - } - .padding(.bottom, 20) - } - .background(Color(.systemGroupedBackground)) - .edgesIgnoringSafeArea(.bottom) - .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationCompleted"))) { _ in - isGenerating = false // Re-enable the button when token generation is complete - } - .onReceive(SharedTokenUpdater.shared.$decodedTokens) { tokens in - // update model response - if let lastIndex = messages.lastIndex(where: { !$0.isUser }) { - let combinedText = tokens.joined(separator: "") - messages[lastIndex].text = combinedText - } + private let generator = GenAIGenerator() + + var body: some View { + VStack { + // ChatBubbles + ScrollView { + VStack(alignment: .leading, spacing: 20) { + ForEach(messages) { message in + ChatBubble(text: message.text, isUser: message.isUser) + .padding(.horizontal, 20) + } + if !stats.isEmpty { + Text(stats) + .font(.footnote) + .foregroundColor(.gray) + .padding(.horizontal, 20) + .padding(.top, 5) + .multilineTextAlignment(.center) + } } - .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in - if let userInfo = notification.userInfo, - let promptProcRate = userInfo["promptProcRate"] as? Double, - let tokenGenRate = userInfo["tokenGenRate"] as? Double { - stats = String(format: "Token generation rate: %.2f tokens/s. 
Prompt processing rate: %.2f tokens/s", tokenGenRate, promptProcRate) - } + .padding(.top, 20) + } + + HStack { + Button(action: { + showFolderPicker = true + }) { + HStack { + Image(systemName: "folder") + .resizable() + .scaledToFit() + .frame(width: 20, height: 20) + } + .padding() + .background(Color.pastelGreen) + .cornerRadius(10) + .shadow(radius: 2) + .padding(.leading, 10) } - .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationError"))) { notification in - if let userInfo = notification.userInfo, let error = userInfo["error"] as? String { - errorMessage = error - isGenerating = false - showAlert = true + .sheet(isPresented: $showFolderPicker) { + FolderPicker { folderURL in + if let folderURL = folderURL { + let folderPath = folderURL.path + print("Selected folder: \(folderPath)") + DispatchQueue.global(qos: .background).async { + generator.setModelFolderPath(folderPath) + } } + } + }.help("Select a folder to set the model path") + + TextField("Type your message...", text: $userInput) + .padding() + .background(Color(.systemGray6)) + .cornerRadius(20) + .padding(.horizontal) + + Button(action: { + // Check for non-empty input + guard !userInput.trimmingCharacters(in: .whitespaces).isEmpty else { return } + + messages.append(Message(text: userInput, isUser: true)) + messages.append(Message(text: "", isUser: false)) // Placeholder for AI response + + // clear previously generated tokens + SharedTokenUpdater.shared.clearTokens() + + let prompt = userInput + userInput = "" + isGenerating = true + + DispatchQueue.global(qos: .background).async { + generator.generate(prompt) + } + }) { + Image(systemName: "paperplane.fill") + .foregroundColor(.white) + .padding() + .background(isGenerating ? 
Color.gray : Color.pastelGreen) + .clipShape(Circle()) } - .alert(isPresented: $showAlert) { - Alert( - title: Text("Error"), - message: Text(errorMessage), - dismissButton: .default(Text("OK")) - ) - } - + .disabled(isGenerating) + } + .padding(.bottom, 20) + } + .background(Color(.systemGroupedBackground)) + .edgesIgnoringSafeArea(.bottom) + .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationCompleted"))) { _ in + isGenerating = false // Re-enable the button when token generation is complete + } + .onReceive(SharedTokenUpdater.shared.$decodedTokens) { tokens in + // update model response + if let lastIndex = messages.lastIndex(where: { !$0.isUser }) { + let combinedText = tokens.joined(separator: "") + messages[lastIndex].text = combinedText + } + } + .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in + if let userInfo = notification.userInfo, + let promptProcRate = userInfo["promptProcRate"] as? Double, + let tokenGenRate = userInfo["tokenGenRate"] as? Double + { + stats = String( + format: "Token generation rate: %.2f tokens/s. Prompt processing rate: %.2f tokens/s", tokenGenRate, + promptProcRate) + } + } + .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("GenAIError"))) { notification in + if let userInfo = notification.userInfo, let error = userInfo["error"] as? 
String { + errorMessage = error + isGenerating = false + showAlert = true + } } + .alert(isPresented: $showAlert) { + Alert( + title: Text("Error"), + message: Text(errorMessage), + dismissButton: .default(Text("OK")) + ) + } + + } } struct ChatBubble: View { - var text: String - var isUser: Bool - - var body: some View { - HStack { - if isUser { - Spacer() - Text(text) - .padding() - .background(Color.pastelGreen) - .foregroundColor(.white) - .cornerRadius(25) - .padding(.horizontal, 10) - } else { - Text(text) - .padding() - .background(Color(.systemGray5)) - .foregroundColor(.black) - .cornerRadius(25) - .padding(.horizontal, 10) - Spacer() - } - } + var text: String + var isUser: Bool + + var body: some View { + HStack { + if isUser { + Spacer() + Text(text) + .padding() + .background(Color.pastelGreen) + .foregroundColor(.white) + .cornerRadius(25) + .padding(.horizontal, 10) + } else { + Text(text) + .padding() + .background(Color(.systemGray5)) + .foregroundColor(.black) + .cornerRadius(25) + .padding(.horizontal, 10) + Spacer() + } } + } } struct ContentView_Previews: PreviewProvider { - static var previews: some View { - ContentView() - } + static var previews: some View { + ContentView() + } } // Extension for a pastel green color extension Color { - static let pastelGreen = Color(red: 0.6, green: 0.9, blue: 0.6) + static let pastelGreen = Color(red: 0.6, green: 0.9, blue: 0.6) } diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/FolderPicker.swift b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/FolderPicker.swift new file mode 100644 index 000000000..d607a2698 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/FolderPicker.swift @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import SwiftUI +import UIKit + +struct FolderPicker: UIViewControllerRepresentable { + var onPick: (URL?) 
-> Void + + func makeUIViewController(context: Context) -> UIDocumentPickerViewController { + let picker = UIDocumentPickerViewController(forOpeningContentTypes: [.folder]) + picker.allowsMultipleSelection = false + picker.delegate = context.coordinator + return picker + } + + func updateUIViewController(_ uiViewController: UIDocumentPickerViewController, context: Context) {} + + func makeCoordinator() -> Coordinator { + Coordinator(onPick: onPick) + } + + class Coordinator: NSObject, UIDocumentPickerDelegate { + let onPick: (URL?) -> Void + + init(onPick: @escaping (URL?) -> Void) { + self.onPick = onPick + } + + func documentPicker(_ controller: UIDocumentPickerViewController, didPickDocumentsAt urls: [URL]) { + onPick(urls.first) + } + + func documentPickerWasCancelled(_ controller: UIDocumentPickerViewController) { + onPick(nil) + } + } +} diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.h index 288c914d4..17c8ce516 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.h +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.h @@ -11,6 +11,7 @@ NS_ASSUME_NONNULL_BEGIN @interface GenAIGenerator : NSObject +- (void)setModelFolderPath:(nonnull NSString *)modelPath; - (void)generate:(NSString *)input_user_question; @end diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm index ddcf2b101..54eb4ee88 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm @@ -7,12 +7,12 @@ #include "ort_genai.h" #include "ort_genai_c.h" - const size_t kMaxTokens = 200; @interface GenAIGenerator () { std::unique_ptr model; std::unique_ptr tokenizer; + NSString* modelPath; } @end @@ -30,27 +30,48 @@ - (instancetype)init { return self; } +- (void)setModelFolderPath:(NSString*)modelPath { + 
@synchronized(self) { + self->modelPath = [modelPath copy]; + NSLog(@"Model folder path set to: %@", modelPath); + + try { + [self loadModelFromPath]; + } catch (const std::exception& e) { + NSString* errorMessage = [NSString stringWithUTF8String:e.what()]; + NSLog(@"Error loading model: %@", errorMessage); + + // Notify the UI about the error + NSDictionary* errorInfo = @{@"error" : errorMessage}; + dispatch_async(dispatch_get_main_queue(), ^{ + [[NSNotificationCenter defaultCenter] postNotificationName:@"GenAIError" object:nil userInfo:errorInfo]; + }); + } + } +} + +- (void)loadModelFromPath { + @synchronized(self) { + NSLog(@"Creating model..."); + self->model = OgaModel::Create(self->modelPath.UTF8String); // throws exception + NSLog(@"Creating tokenizer..."); + self->tokenizer = OgaTokenizer::Create(*self->model); // throws exception + } +} + - (void)generate:(nonnull NSString*)input_user_question { std::vector tokenTimes; // per-token generation times tokenTimes.reserve(kMaxTokens); - TimePoint startTime, firstTokenTime, tokenStartTime; try { - NSLog(@"Starting token generation..."); - - if (!self->model) { - NSLog(@"Creating model..."); - NSString* llmPath = [[NSBundle mainBundle] resourcePath]; - const char* modelPath = llmPath.cString; - self->model = OgaModel::Create(modelPath); // throws exception - } - - if (!self->tokenizer) { - NSLog(@"Creating tokenizer..."); - self->tokenizer = OgaTokenizer::Create(*self->model); // throws exception + if (!self->modelPath) { + self->modelPath = [[NSBundle mainBundle] resourcePath]; + NSLog(@"No folder path provided. 
Using the default folder path: %@", self->modelPath); + [self loadModelFromPath]; } + NSLog(@"Starting token generation..."); auto tokenizer_stream = OgaTokenizerStream::Create(*self->tokenizer); // Construct the prompt @@ -66,7 +87,6 @@ - (void)generate:(nonnull NSString*)input_user_question { NSLog(@"Setting generator parameters..."); auto params = OgaGeneratorParams::Create(*self->model); params->SetSearchOption("max_length", kMaxTokens); - params->SetInputSequences(*sequences); auto generator = OgaGenerator::Create(*self->model, *params); @@ -74,10 +94,9 @@ - (void)generate:(nonnull NSString*)input_user_question { NSLog(@"Starting token generation loop..."); startTime = Clock::now(); + generator->AppendTokenSequences(*sequences); while (!generator->IsDone()) { tokenStartTime = Clock::now(); - - generator->ComputeLogits(); generator->GenerateNextToken(); if (isFirstToken) { @@ -134,7 +153,7 @@ - (void)generate:(nonnull NSString*)input_user_question { // Send error to the UI NSDictionary* errorInfo = @{@"error" : errorMessage}; dispatch_async(dispatch_get_main_queue(), ^{ - [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError" object:nil userInfo:errorInfo]; + [[NSNotificationCenter defaultCenter] postNotificationName:@"GenAIError" object:nil userInfo:errorInfo]; }); } } diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md index a9438224b..3632e0e1d 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md @@ -10,12 +10,12 @@ For this application, the following prerequisites are preferred: 1. macOS 14+ -2. Xcode 15+ (latest Xcode version perferred.) +2. Xcode 15+ (latest Xcode version preferred) -3. iOS SDK 16.x + (iPhone 14 or iPhone 15 powered by a A16 or A17 preferred) +3. 
iOS SDK 16.x + (iPhone 14 or iPhone 15 powered by A16 or A17 preferred) **Note**: - The current Xcode project contains a built .dylib for ORT and ORT GenAI. The following steps `A, B, C` under `step 1.` for building from source for the libraries are optional. + The current Xcode project contains a built onnxruntime.framework and ORT GenAI library. The following steps `A, B, C` under `step 1.` for building from source for the libraries are optional. However if you want to build from source to include the latest updates, please use the `step 1.` as a reference. ### 1. Steps to build from source for ONNX Runtime and Generative AI libraries [Optional] @@ -33,66 +33,58 @@ For this application, the following prerequisites are preferred: #### **B. Compiling ONNX Runtime for iOS** -```bash - -git clone https://github.com/microsoft/onnxruntime.git -cd onnxruntime +***Notice*** -./build.sh --build_shared_lib --skip_tests --parallel --build_dir ./build_ios --ios --apple_sysroot iphoneos --osx_arch arm64 --apple_deploy_target 16.6 --cmake_generator Xcode --config Release + 1. Before compiling, you must ensure that Xcode is configured correctly and set it on the terminal +```bash +sudo xcode-select -switch /Applications/Xcode.app/Contents/Developer ``` -***Notice*** - - 1. Before compiling, you must ensure that Xcode is configured correctly and set it on the terminal + 2. Build a fat ONNX Runtime Framework for iOS and iOS simulator from `` using this command: ```bash +git clone https://github.com/microsoft/onnxruntime.git -sudo xcode-select -switch /Applications/Xcode.app/Contents/Developer +cd onnxruntime +python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_full_ios_framework_build_settings.json --config Release ``` +The build creates `Headers`, `LICENSE`, and `onnxruntime.xcframework` in `build/iOS_framework/framework_out` directory. + + 3. The `default_full_ios_framework_build_settings.json` file is a default build settings file. 
You can modify the build settings in this file to suit your needs. For example, you can set the `--apple_deploy_target` to 15.1 or higher or remove osx_arch that are not needed. - 2. ONNX Runtime needs to be compiled based on different platforms. For iOS, you can compile for arm64 or x86_64 based on needs. If you are running an iOS simulator on an Intel mac, compile for x86_64. Use arm64 for an ARM based mac to run the simulator, and to run on an iPhone. - - 3. It is recommended to directly use the latest iOS SDK for compilation. Of course, you can also lower the version to be compatible with past SDKs. #### **C. Compiling Generative AI with ONNX Runtime for iOS** ```bash - git clone https://github.com/microsoft/onnxruntime-genai cd onnxruntime-genai -python3 build.py --parallel --build_dir ./build_iphoneos --ios --apple_sysroot iphoneos --osx_arch arm64 --apple_deploy_target 16.6 --cmake_generator Xcode - +python3 build.py --parallel --build_dir ./build_iphoneos --ios --apple_sysroot iphoneos --osx_arch arm64 --apple_deploy_target 15.1 --cmake_generator Xcode --ort_home /path/to/framework_out ``` -#### **D. Copy over latest header files and required .dylibs built from source** +#### **D. 
Copy over latest header files and required framework files built from source** -If you build from source and get the latest .dylibs for ORT and ORT GenAI, please copy the .dylibs over to `mobile\examples\phi-3\ios\LocalLLM\LocalLLM\lib` and copy the latest header files over to `mobile\examples\phi-3\ios\LocalLLM\LocalLLM\header` +If you build from source and get the latest `onnxruntime.framework` and ORT GenAI library, please copy them over to `mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib` and copy the latest header files over to `mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header` -The build output path for libonnxruntime.dylib is `/build/intermediates/_///libonnxruntime.dylib` +The build output path for onnxruntime.framework is `/build/iOS_framework/framework_out/onnxruntime.xcframework/your_target_arch/onnxruntime.framework`. The build output path for libonnxruntime-genai.dylib is `/build//libonnxruntime-genai.dylib`. -For example: -- `onnxruntime/build/intermediates/iphoneos_arm64/Release/Release-iphoneos/libonnxruntime.1.19.0.dylib` -- `onnxruntime-genai/build/Release/Release-iphoneos/libonnxruntime-genai.dylib`. - -Note that you will need to build and copy the correct dylib for the target architecture you wish to run the app on. -e.g. -if you want to run on the iOS simulator on an Intel mac, you must build both onnxruntime and onnxruntime-genai for x86_64 and copy the dylibs to the app's `lib` directory. -if you want to run on an iPhone, you must build both onnxruntime and onnxruntime-genai for arm64 and copy the dylibs to the app's `lib` directory. +Note that you will need to build and copy the correct framework files for the target architecture you wish to run the app on. +For example: +- If you want to run on the iOS simulator on an Intel Mac, you must build both onnxruntime and onnxruntime-genai for x86_64 and copy the appropriate files to the app's `lib` directory. 
+- If you want to run on an iPhone, you must build both onnxruntime and onnxruntime-genai for arm64 and copy the appropriate files to the app's `lib` directory. The header files to copy are: -`/onnxruntime/core/session/onnxruntime_c_api.h`, `/src/ort_genai.h`, `/src/ort_genai_c.h`. ### 2. Create/Open the iOS application in Xcode -The app uses Objective-C/C++ since using Generative AI with ONNX Runtime C++ API, Objective-C has better compatiblility. +The app uses Objective-C/C++ since using Generative AI with ONNX Runtime C++ API, Objective-C has better compatibility. ### 3. Copy the ONNX quantized INT4 model to the App application project @@ -106,6 +98,7 @@ Upon app launching, Xcode will automatically copy and install the model files fr **Note**: The current app only sets up with a simple initial prompt question, you can adjust/try your own or refine the UI based on requirements. -***Notice:*** The current Xcode project runs on iOS 16.6, feel free to adjust latest iOS/build for lates iOS versions accordingly. +***Notice:*** The current Xcode project runs on iOS 16.6, feel free to adjust latest iOS/build for latest iOS versions accordingly. 
-![alt text]() \ No newline at end of file +![alt text]() +![alt text]() \ No newline at end of file diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/Screenshot2.jpg b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/Screenshot2.jpg new file mode 100644 index 000000000..a8f357d27 Binary files /dev/null and b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/Screenshot2.jpg differ diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/SharedTokenUpdater.swift b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/SharedTokenUpdater.swift index 260a9154b..468718363 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/SharedTokenUpdater.swift +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/SharedTokenUpdater.swift @@ -5,19 +5,19 @@ import Combine import Foundation @objc class SharedTokenUpdater: NSObject, ObservableObject { - @Published var decodedTokens: [String] = [] - - @objc static let shared = SharedTokenUpdater() - - @objc func addDecodedToken(_ token: String) { - DispatchQueue.main.async { - self.decodedTokens.append(token) - } + @Published var decodedTokens: [String] = [] + + @objc static let shared = SharedTokenUpdater() + + @objc func addDecodedToken(_ token: String) { + DispatchQueue.main.async { + self.decodedTokens.append(token) } + } - @objc func clearTokens() { - DispatchQueue.main.async { - self.decodedTokens.removeAll() - } + @objc func clearTokens() { + DispatchQueue.main.async { + self.decodedTokens.removeAll() } + } } diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai.h index fb863dae2..718ba1f9f 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai.h +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai.h @@ -5,9 +5,11 @@ #include #include +#include #if __cplusplus >= 202002L #include +#define OGA_USE_SPAN 1 #endif #include "ort_genai_c.h" @@ -25,13 +27,18 @@ * tokenizer->Encode("A great recipe for Kung Pao chicken is ", 
*sequences); * * auto params = OgaGeneratorParams::Create(*model); - * params->SetInputSequences(*sequences); * params->SetSearchOption("max_length", 200); + * params->SetSearchOption("batch_size", 1); * - * auto output_sequences = model->Generate(*params); - * auto out_string = tokenizer->Decode(output_sequences->Get(0)); + * auto generator = OgaGenerator::Create(*model, *params); + * generator->AppendTokenSequences(*sequences); + * while (!generator->IsDone()) { + * generator->GenerateNextToken(); + * } + * auto output_sequence = generator->GetSequenceData(0); + * auto output_string = tokenizer->Decode(output_sequence, generator->GetSequenceCount(0)); * - * std::cout << "Output: " << std::endl << out_string << std::endl; + * std::cout << "Output: " << std::endl << output_string << std::endl; */ // The types defined in this file are to give us zero overhead C++ style interfaces around an opaque C pointer. @@ -58,14 +65,121 @@ inline void OgaCheckResult(OgaResult* result) { } } -struct OgaLog { - void SetBool(const char* name, bool value) { - OgaCheckResult(OgaSetLogBool(name, value)); +struct OgaFloat16_t; +struct OgaBFloat16_t; + +// Variable templates to convert a C++ type into it's OgaElementType +template +inline constexpr OgaElementType OgaTypeToElementType = T::Unsupported_Type; // Force a compile error if hit, please add specialized version if type is valid +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_bool; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_int8; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_uint8; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_int16; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_uint16; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_int32; +template <> +inline constexpr OgaElementType OgaTypeToElementType 
= OgaElementType_uint32; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_int64; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_uint64; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_float32; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_float64; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_float16; +template <> +inline constexpr OgaElementType OgaTypeToElementType = OgaElementType_bfloat16; + +struct OgaString { + OgaString(const char* p) : p_{p} {} + ~OgaString() { OgaDestroyString(p_); } + + operator const char*() const { return p_; } + + const char* p_; +}; + +struct OgaStringArray { + std::unique_ptr Create() { + OgaStringArray* p; + OgaCheckResult(OgaCreateStringArray(&p)); + return std::unique_ptr(p); } - void SetString(const char* name, const char* value) { - OgaCheckResult(OgaSetLogString(name, value)); + std::unique_ptr Create(const char** strings, size_t count) { + OgaStringArray* p; + OgaCheckResult(OgaCreateStringArrayFromStrings(strings, count, &p)); + return std::unique_ptr(p); } + + void Add(const char* str) { + OgaCheckResult(OgaStringArrayAddString(this, str)); + } + + const char* Get(size_t index) const { + const char* p; + OgaCheckResult(OgaStringArrayGetString(this, index, &p)); + return p; + } + + size_t Count() const { + size_t count; + OgaCheckResult(OgaStringArrayGetCount(this, &count)); + return count; + } + + static void operator delete(void* p) { OgaDestroyStringArray(reinterpret_cast(p)); } +}; + +struct OgaRuntimeSettings : OgaAbstract { + static std::unique_ptr Create() { + OgaRuntimeSettings* p; + OgaCheckResult(OgaCreateRuntimeSettings(&p)); + return std::unique_ptr(p); + } + + void SetHandle(const char* name, void* handle) { + OgaCheckResult(OgaRuntimeSettingsSetHandle(this, name, handle)); + } + void SetHandle(const std::string& name, void* handle) 
{ + SetHandle(name.c_str(), handle); + } + + static void operator delete(void* p) { OgaDestroyRuntimeSettings(reinterpret_cast(p)); } +}; + +struct OgaConfig : OgaAbstract { + static std::unique_ptr Create(const char* config_path) { + OgaConfig* p; + OgaCheckResult(OgaCreateConfig(config_path, &p)); + return std::unique_ptr(p); + } + + void ClearProviders() { + OgaCheckResult(OgaConfigClearProviders(this)); + } + + void AppendProvider(const char* provider) { + OgaCheckResult(OgaConfigAppendProvider(this, provider)); + } + + void SetProviderOption(const char* provider, const char* name, const char* value) { + OgaCheckResult(OgaConfigSetProviderOption(this, provider, name, value)); + } + + void Overlay(const char* json) { + OgaCheckResult(OgaConfigOverlay(this, json)); + } + + static void operator delete(void* p) { OgaDestroyConfig(reinterpret_cast(p)); } }; struct OgaModel : OgaAbstract { @@ -74,23 +188,30 @@ struct OgaModel : OgaAbstract { OgaCheckResult(OgaCreateModel(config_path, &p)); return std::unique_ptr(p); } - - std::unique_ptr Generate(const OgaGeneratorParams& params) { - OgaSequences* p; - OgaCheckResult(OgaGenerate(this, ¶ms, &p)); - return std::unique_ptr(p); + static std::unique_ptr Create(const char* config_path, const OgaRuntimeSettings& settings) { + OgaModel* p; + OgaCheckResult(OgaCreateModelWithRuntimeSettings(config_path, &settings, &p)); + return std::unique_ptr(p); + } + static std::unique_ptr Create(const OgaConfig& config) { + OgaModel* p; + OgaCheckResult(OgaCreateModelFromConfig(&config, &p)); + return std::unique_ptr(p); } - static void operator delete(void* p) { OgaDestroyModel(reinterpret_cast(p)); } -}; - -struct OgaString { - OgaString(const char* p) : p_{p} {} - ~OgaString() { OgaDestroyString(p_); } + OgaString GetType() const { + const char* p; + OgaCheckResult(OgaModelGetType(this, &p)); + return p; + } - operator const char*() const { return p_; } + OgaString GetDeviceType() const { + const char* p; + 
OgaCheckResult(OgaModelGetDeviceType(this, &p)); + return p; + } - const char* p_; + static void operator delete(void* p) { OgaDestroyModel(reinterpret_cast(p)); } }; struct OgaSequences : OgaAbstract { @@ -112,10 +233,24 @@ struct OgaSequences : OgaAbstract { return OgaSequencesGetSequenceData(this, index); } -#if __cplusplus >= 202002L + void Append(const int32_t* tokens, size_t token_cnt) { + OgaCheckResult(OgaAppendTokenSequence(tokens, token_cnt, this)); + } + + void Append(int32_t token, size_t sequence_index) { + OgaCheckResult(OgaAppendTokenToSequence(token, this, sequence_index)); + } + +#if OGA_USE_SPAN std::span Get(size_t index) const { return {SequenceData(index), SequenceCount(index)}; } + void Append(std::span sequence) { + OgaCheckResult(OgaAppendTokenSequence(sequence.data(), sequence.size(), this)); + } + void Append(const std::vector& sequence) { + OgaCheckResult(OgaAppendTokenSequence(sequence.data(), sequence.size(), this)); + } #endif static void operator delete(void* p) { OgaDestroySequences(reinterpret_cast(p)); } @@ -132,13 +267,25 @@ struct OgaTokenizer : OgaAbstract { OgaCheckResult(OgaTokenizerEncode(this, str, &sequences)); } + std::unique_ptr EncodeBatch(const char** strings, size_t count) const { + OgaTensor* out; + OgaCheckResult(OgaTokenizerEncodeBatch(this, strings, count, &out)); + return std::unique_ptr(out); + } + + int32_t ToTokenId(const char* str) const { + int32_t token_id; + OgaCheckResult(OgaTokenizerToTokenId(this, str, &token_id)); + return token_id; + } + OgaString Decode(const int32_t* tokens_data, size_t tokens_length) const { const char* p; OgaCheckResult(OgaTokenizerDecode(this, tokens_data, tokens_length, &p)); return p; } -#if __cplusplus >= 202002L +#if OGA_USE_SPAN OgaString Decode(std::span tokens) const { const char* p; OgaCheckResult(OgaTokenizerDecode(this, tokens.data(), tokens.size(), &p)); @@ -146,6 +293,12 @@ struct OgaTokenizer : OgaAbstract { } #endif + std::unique_ptr DecodeBatch(const OgaTensor& 
tensor) const { + OgaStringArray* p; + OgaCheckResult(OgaTokenizerDecodeBatch(this, &tensor, &p)); + return std::unique_ptr(p); + } + static void operator delete(void* p) { OgaDestroyTokenizer(reinterpret_cast(p)); } }; @@ -156,6 +309,12 @@ struct OgaTokenizerStream : OgaAbstract { return std::unique_ptr(p); } + static std::unique_ptr Create(const OgaMultiModalProcessor& processor) { + OgaTokenizerStream* p; + OgaCheckResult(OgaCreateTokenizerStreamFromProcessor(&processor, &p)); + return std::unique_ptr(p); + } + /* * Decode a single token in the stream. If this results in a word being generated, it will be returned in 'out'. * The caller is responsible for concatenating each chunk together to generate the complete result. @@ -185,23 +344,23 @@ struct OgaGeneratorParams : OgaAbstract { OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value)); } - void SetInputIDs(const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { - OgaCheckResult(OgaGeneratorParamsSetInputIDs(this, input_ids, input_ids_count, sequence_length, batch_size)); + void SetModelInput(const char* name, OgaTensor& tensor) { + OgaCheckResult(OgaGeneratorParamsSetModelInput(this, name, &tensor)); } - void SetInputSequences(const OgaSequences& sequences) { - OgaCheckResult(OgaGeneratorParamsSetInputSequences(this, &sequences)); + void SetInputs(OgaNamedTensors& named_tensors) { + OgaCheckResult(OgaGeneratorParamsSetInputs(this, &named_tensors)); } void TryGraphCaptureWithMaxBatchSize(int max_batch_size) { - OgaCheckResult(OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(this, max_batch_size)); + printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n"); } static void operator delete(void* p) { OgaDestroyGeneratorParams(reinterpret_cast(p)); } }; struct OgaGenerator : OgaAbstract { - static std::unique_ptr Create(OgaModel& model, const OgaGeneratorParams& params) { + static std::unique_ptr Create(const OgaModel& 
model, OgaGeneratorParams& params) { OgaGenerator* p; OgaCheckResult(OgaCreateGenerator(&model, ¶ms, &p)); return std::unique_ptr(p); @@ -211,14 +370,45 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } - void ComputeLogits() { - OgaCheckResult(OgaGenerator_ComputeLogits(this)); + void AppendTokenSequences(const OgaSequences& sequences) { + OgaCheckResult(OgaGenerator_AppendTokenSequences(this, &sequences)); + } + + void AppendTokens(const int32_t* input_ids, size_t input_ids_count) { + OgaCheckResult(OgaGenerator_AppendTokens(this, input_ids, input_ids_count)); + } + +#if OGA_USE_SPAN + void AppendTokens(std::span input_ids) { + OgaCheckResult(OgaGenerator_AppendTokens(this, input_ids.data(), input_ids.size())); + } +#endif + + bool IsSessionTerminated() const { + return OgaGenerator_IsSessionTerminated(this); } void GenerateNextToken() { OgaCheckResult(OgaGenerator_GenerateNextToken(this)); } +#if OGA_USE_SPAN + std::span GetNextTokens() { + const int32_t* out; + size_t out_count; + OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count)); + return {out, out_count}; + } +#endif + + void RewindTo(size_t new_length) { + OgaCheckResult(OgaGenerator_RewindTo(this, new_length)); + } + + void SetRuntimeOption(const char* key, const char* value) { + OgaCheckResult(OgaGenerator_SetRuntimeOption(this, key, value)); + } + size_t GetSequenceCount(size_t index) const { return OgaGenerator_GetSequenceCount(this, index); } @@ -227,11 +417,263 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_GetSequenceData(this, index); } -#if __cplusplus >= 202002L + std::unique_ptr GetOutput(const char* name) { + OgaTensor* out; + OgaCheckResult(OgaGenerator_GetOutput(this, name, &out)); + return std::unique_ptr(out); + } + + std::unique_ptr GetLogits() { + OgaTensor* out; + OgaCheckResult(OgaGenerator_GetLogits(this, &out)); + return std::unique_ptr(out); + } + + void SetLogits(OgaTensor& tensor) { + OgaCheckResult(OgaGenerator_SetLogits(this, 
&tensor)); + } + +#if OGA_USE_SPAN std::span GetSequence(size_t index) const { return {GetSequenceData(index), GetSequenceCount(index)}; } #endif + void SetActiveAdapter(OgaAdapters& adapters, const char* adapter_name) { + OgaCheckResult(OgaSetActiveAdapter(this, &adapters, adapter_name)); + } + static void operator delete(void* p) { OgaDestroyGenerator(reinterpret_cast(p)); } }; + +struct OgaTensor : OgaAbstract { +#if OGA_USE_SPAN + template + static std::unique_ptr Create(T* data, std::span shape) { + OgaTensor* p; + OgaCheckResult(OgaCreateTensorFromBuffer(data, shape.data(), shape.size(), OgaTypeToElementType, &p)); + return std::unique_ptr(p); + } + + static std::unique_ptr Create(void* data, std::span shape, OgaElementType type) { + OgaTensor* p; + OgaCheckResult(OgaCreateTensorFromBuffer(data, shape.data(), shape.size(), type, &p)); + return std::unique_ptr(p); + } +#endif + + static std::unique_ptr Create(void* data, const int64_t* shape_dims, size_t shape_dims_count, OgaElementType element_type) { + OgaTensor* p; + OgaCheckResult(OgaCreateTensorFromBuffer(data, shape_dims, shape_dims_count, element_type, &p)); + return std::unique_ptr(p); + } + + OgaElementType Type() { + OgaElementType type; + OgaCheckResult(OgaTensorGetType(this, &type)); + return type; + } + + std::vector Shape() { + size_t size; + OgaCheckResult(OgaTensorGetShapeRank(this, &size)); + std::vector shape(size); + OgaCheckResult(OgaTensorGetShape(this, shape.data(), shape.size())); + return shape; + } + + void* Data() { + void* data; + OgaCheckResult(OgaTensorGetData(this, &data)); + return data; + } + + static void operator delete(void* p) { OgaDestroyTensor(reinterpret_cast(p)); } +}; + +struct OgaImages : OgaAbstract { + static std::unique_ptr Load(const std::vector& image_paths) { + OgaImages* p; + OgaStringArray* strs; + OgaCheckResult(OgaCreateStringArrayFromStrings(image_paths.data(), image_paths.size(), &strs)); + OgaCheckResult(OgaLoadImages(strs, &p)); + 
OgaDestroyStringArray(strs); + return std::unique_ptr(p); + } + +#if OGA_USE_SPAN + static std::unique_ptr Load(std::span image_paths) { + OgaImages* p; + OgaStringArray* strs; + OgaCheckResult(OgaCreateStringArrayFromStrings(image_paths.data(), image_paths.size(), &strs)); + OgaCheckResult(OgaLoadImages(strs, &p)); + OgaDestroyStringArray(strs); + return std::unique_ptr(p); + } +#endif + + static std::unique_ptr Load(const void** image_data, const size_t* image_data_sizes, size_t count) { + OgaImages* p; + OgaCheckResult(OgaLoadImagesFromBuffers(image_data, image_data_sizes, count, &p)); + return std::unique_ptr(p); + } + + static void operator delete(void* p) { OgaDestroyImages(reinterpret_cast(p)); } +}; + +struct OgaAudios : OgaAbstract { + static std::unique_ptr Load(const std::vector& audio_paths) { + OgaAudios* p; + OgaStringArray* strs; + OgaCheckResult(OgaCreateStringArrayFromStrings(audio_paths.data(), audio_paths.size(), &strs)); + OgaCheckResult(OgaLoadAudios(strs, &p)); + OgaDestroyStringArray(strs); + return std::unique_ptr(p); + } + +#if OGA_USE_SPAN + static std::unique_ptr Load(std::span audio_paths) { + OgaAudios* p; + OgaStringArray* strs; + OgaCheckResult(OgaCreateStringArrayFromStrings(audio_paths.data(), audio_paths.size(), &strs)); + OgaCheckResult(OgaLoadAudios(strs, &p)); + OgaDestroyStringArray(strs); + return std::unique_ptr(p); + } +#endif + + static std::unique_ptr Load(const void** audio_data, const size_t* audio_data_sizes, size_t count) { + OgaAudios* p; + OgaCheckResult(OgaLoadAudiosFromBuffers(audio_data, audio_data_sizes, count, &p)); + return std::unique_ptr(p); + } + + static void operator delete(void* p) { OgaDestroyAudios(reinterpret_cast(p)); } +}; + +struct OgaNamedTensors : OgaAbstract { + static std::unique_ptr Create() { + OgaNamedTensors* p; + OgaCheckResult(OgaCreateNamedTensors(&p)); + return std::unique_ptr(p); + } + + std::unique_ptr Get(const char* name) { + OgaTensor* p; + OgaCheckResult(OgaNamedTensorsGet(this, 
name, &p)); + return std::unique_ptr(p); + } + + void Set(const char* name, OgaTensor& tensor) { + OgaCheckResult(OgaNamedTensorsSet(this, name, &tensor)); + } + + void Delete(const char* name) { + OgaCheckResult(OgaNamedTensorsDelete(this, name)); + } + + size_t Count() const { + size_t count; + OgaCheckResult(OgaNamedTensorsCount(this, &count)); + return count; + } + + std::unique_ptr GetNames() const { + OgaStringArray* p; + OgaCheckResult(OgaNamedTensorsGetNames(this, &p)); + return std::unique_ptr(p); + } + + static void operator delete(void* p) { OgaDestroyNamedTensors(reinterpret_cast(p)); } +}; + +struct OgaMultiModalProcessor : OgaAbstract { + static std::unique_ptr Create(const OgaModel& model) { + OgaMultiModalProcessor* p; + OgaCheckResult(OgaCreateMultiModalProcessor(&model, &p)); + return std::unique_ptr(p); + } + + std::unique_ptr ProcessImages(const char* str, const OgaImages* images = nullptr) const { + OgaNamedTensors* p; + OgaCheckResult(OgaProcessorProcessImages(this, str, images, &p)); + return std::unique_ptr(p); + } + + std::unique_ptr ProcessAudios(const OgaAudios* audios) const { + OgaNamedTensors* p; + OgaCheckResult(OgaProcessorProcessAudios(this, audios, &p)); + return std::unique_ptr(p); + } + + std::unique_ptr ProcessImagesAndAudios(const char* str, const OgaImages* images = nullptr, const OgaAudios* audios = nullptr) const { + OgaNamedTensors* p; + OgaCheckResult(OgaProcessorProcessImagesAndAudios(this, str, images, audios, &p)); + return std::unique_ptr(p); + } + + OgaString Decode(const int32_t* tokens_data, size_t tokens_length) const { + const char* p; + OgaCheckResult(OgaProcessorDecode(this, tokens_data, tokens_length, &p)); + return p; + } + +#if OGA_USE_SPAN + OgaString Decode(std::span tokens) const { + const char* p; + OgaCheckResult(OgaProcessorDecode(this, tokens.data(), tokens.size(), &p)); + return p; + } +#endif + + static void operator delete(void* p) { OgaDestroyMultiModalProcessor(reinterpret_cast(p)); } +}; + 
+struct OgaAdapters : OgaAbstract { + static std::unique_ptr Create(const OgaModel& model) { + OgaAdapters* p; + OgaCheckResult(OgaCreateAdapters(&model, &p)); + return std::unique_ptr(p); + } + + void LoadAdapter(const char* adapter_file_path, + const char* adapter_name) { + OgaCheckResult(OgaLoadAdapter(this, adapter_file_path, adapter_name)); + } + + void UnloadAdapter(const char* adapter_name) { + OgaCheckResult(OgaUnloadAdapter(this, adapter_name)); + } + + static void operator delete(void* p) { OgaDestroyAdapters(reinterpret_cast(p)); } +}; + +struct OgaHandle { + OgaHandle() = default; + ~OgaHandle() noexcept { + OgaShutdown(); + } +}; + +// Global Oga functions +namespace Oga { + +inline void SetLogBool(const char* name, bool value) { + OgaCheckResult(OgaSetLogBool(name, value)); +} + +inline void SetLogString(const char* name, const char* value) { + OgaCheckResult(OgaSetLogString(name, value)); +} + +inline void SetCurrentGpuDeviceId(int device_id) { + OgaCheckResult(OgaSetCurrentGpuDeviceId(device_id)); +} + +inline int GetCurrentGpuDeviceId() { + int device_id; + OgaCheckResult(OgaGetCurrentGpuDeviceId(&device_id)); + return device_id; +} + +} // namespace Oga diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai_c.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai_c.h index 0939d2c36..71f8c211d 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai_c.h +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/ort_genai_c.h @@ -27,99 +27,282 @@ extern "C" { #define OGA_API_CALL #endif -// ONNX Runtime Generative AI C API -// This API is not thread safe. +/** \addtogroup Global + * ONNX Runtime Generative AI C API + * This API is not thread safe. 
+ * @{ + */ + +typedef enum OgaElementType { + OgaElementType_undefined, + OgaElementType_float32, // maps to c type float + OgaElementType_uint8, // maps to c type uint8_t + OgaElementType_int8, // maps to c type int8_t + OgaElementType_uint16, // maps to c type uint16_t + OgaElementType_int16, // maps to c type int16_t + OgaElementType_int32, // maps to c type int32_t + OgaElementType_int64, // maps to c type int64_t + OgaElementType_string, // string type (not currently supported by Oga) + OgaElementType_bool, // maps to c type bool + OgaElementType_float16, // IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction + OgaElementType_float64, // maps to c type double + OgaElementType_uint32, // maps to c type uint32_t + OgaElementType_uint64, // maps to c type uint64_t + OgaElementType_complex64, // complex with float32 real and imaginary components + OgaElementType_complex128, // complex with float64 real and imaginary components + OgaElementType_bfloat16, // Non-IEEE floating-point format based on IEEE754 single-precision +} OgaElementType; typedef struct OgaResult OgaResult; typedef struct OgaGeneratorParams OgaGeneratorParams; typedef struct OgaGenerator OgaGenerator; +typedef struct OgaRuntimeSettings OgaRuntimeSettings; +typedef struct OgaConfig OgaConfig; typedef struct OgaModel OgaModel; // OgaSequences is an array of token arrays where the number of token arrays can be obtained using // OgaSequencesCount and the number of tokens in each token array can be obtained using OgaSequencesGetSequenceCount.
typedef struct OgaSequences OgaSequences; typedef struct OgaTokenizer OgaTokenizer; typedef struct OgaTokenizerStream OgaTokenizerStream; +typedef struct OgaTensor OgaTensor; +typedef struct OgaImages OgaImages; +typedef struct OgaNamedTensors OgaNamedTensors; +typedef struct OgaMultiModalProcessor OgaMultiModalProcessor; +typedef struct OgaAudios OgaAudios; +typedef struct OgaStringArray OgaStringArray; +typedef struct OgaAdapters OgaAdapters; + +//! @} + +/** \addtogroup Global + * @{ + */ -/* \brief Call this on process exit to cleanly shutdown the genai library & its onnxruntime usage - * \return Error message contained in the OgaResult. The const char* is owned by the OgaResult - * and can will be freed when the OgaResult is destroyed. +/** + * \brief Call this on process exit to cleanly shutdown the genai library & its onnxruntime usage */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaShutdown(); +OGA_EXPORT void OGA_API_CALL OgaShutdown(); -/* +/** * \param[in] result OgaResult that contains the error message. * \return Error message contained in the OgaResult. The const char* is owned by the OgaResult * and can will be freed when the OgaResult is destroyed. */ OGA_EXPORT const char* OGA_API_CALL OgaResultGetError(const OgaResult* result); -/* - * \param[in] Set logging options, see logging.h 'struct LogItems' for the list of available options +/** + * \param[in] name logging option name, see logging.h 'struct LogItems' for the list of available options + * \param[in] value logging option value. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogBool(const char* name, bool value); OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogString(const char* name, const char* value); -/* +/** * \param[in] result OgaResult to be destroyed. 
*/ -OGA_EXPORT void OGA_API_CALL OgaDestroyResult(OgaResult*); +OGA_EXPORT void OGA_API_CALL OgaDestroyResult(OgaResult* result); OGA_EXPORT void OGA_API_CALL OgaDestroyString(const char*); +OGA_EXPORT void OGA_API_CALL OgaDestroyNamedTensors(OgaNamedTensors*); OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out); -/* +/** * \param[in] sequences OgaSequences to be destroyed. */ OGA_EXPORT void OGA_API_CALL OgaDestroySequences(OgaSequences* sequences); -/* +/** * \brief Returns the number of sequences in the OgaSequences * \param[in] sequences * \return The number of sequences in the OgaSequences */ OGA_EXPORT size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* sequences); -/* - * \brief Returns the number of tokens in the sequence at the given index - * \param[in] sequences +/** + * \brief Appends token_cnt number of tokens from token_ptr to sequence + * \param[in] token_ptr constant pointer to int32 tokens + * \param[in] token_cnt number of tokens to read from token_ptr + * \param[in] sequences OgaSequences object to append the tokens to + * \return OgaResult containing the error message when tokens could not be added, else nullptr. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaAppendTokenSequence(const int32_t* token_ptr, size_t token_cnt, OgaSequences* sequences); + +/** + * \brief Appends the given token to the sequence at the given index. + If the sequence at the given index does not exist, a new sequence is + created at the given index if sequence_index is equal to the current sequences count. + * \param[in] token token to append to the sequence + * \param[in] sequences OgaSequences object to append the token to + * \param[in] sequence_index index of the sequence to append the token to + * \return OgaResult containing the error message when tokens could not be added, else nullptr.
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaAppendTokenToSequence(int32_t token, OgaSequences* sequences, size_t sequence_index); + +/** + * \brief Returns the number of tokens in the sequence at the given index. + * \param[in] sequences OgaSequences to use. + * \param[in] sequence_index index of the sequence to use. * \return The number of tokens in the sequence at the given index */ OGA_EXPORT size_t OGA_API_CALL OgaSequencesGetSequenceCount(const OgaSequences* sequences, size_t sequence_index); -/* +/** * \brief Returns a pointer to the sequence data at the given index. The number of tokens in the sequence * is given by OgaSequencesGetSequenceCount - * \param[in] sequences + * \param[in] sequences OgaSequences to use. + * \param[in] sequence_index index of the sequence to use. * \return The pointer to the sequence data at the given index. The pointer is valid until the OgaSequences is destroyed. */ OGA_EXPORT const int32_t* OGA_API_CALL OgaSequencesGetSequenceData(const OgaSequences* sequences, size_t sequence_index); -/* - * \brief Creates a model from the given configuration directory and device type. +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadImage(const char* image_path, OgaImages** images); +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadImages(const OgaStringArray* image_paths, OgaImages** images); + +/** + * \brief Load multiple images from an array of byte buffers + * \param[in] image_data Array of byte buffers containing the image data. + * \param[in] image_data_sizes Array of sizes of the byte buffers. + * \param[in] count Number of images to load. + * \param[out] images The loaded images. + * \return OgaResult containing the error message if the loading of the images failed. 
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadImagesFromBuffers(const void** image_data, const size_t* image_data_sizes, size_t count, OgaImages** images); + +OGA_EXPORT void OGA_API_CALL OgaDestroyImages(OgaImages* images); + +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudio(const char* audio_path, OgaAudios** audios); + +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudios** audios); + +/** + * \brief Load multiple audios from an array of byte buffers + * \param[in] audio_data Array of byte buffers containing the audio data. + * \param[in] audio_data_sizes Array of sizes of the byte buffers. + * \param[in] count Number of audios to load. + * \param[out] audios The loaded audios. + * \return OgaResult containing the error message if the loading of the audios failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudiosFromBuffers(const void** audio_data, const size_t* audio_data_sizes, size_t count, OgaAudios** audios); + +OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios); + +/** + * \brief Creates a runtime settings instance to be used to create a model. + * \param[out] out The created runtime settings. + * \return OgaResult containing the error message if the creation of the runtime settings failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out); +/** + * \brief Destroys the given runtime settings. + * \param[in] settings The runtime settings to be destroyed. + */ +OGA_EXPORT void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* settings); + +/** + * \brief Sets a specific runtime handle for the runtime settings. + * \param[in] settings The runtime settings to set the handle on. + * \param[in] handle_name The name of the handle to set for the runtime settings. + * \param[in] handle The value of handle to set for the runtime settings. + * \return OgaResult containing the error message if the setting of the handle failed.
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle); + +/** + * \brief Creates an OgaConfig from the given configuration directory. + * \param[in] config_path The path to the configuration directory. The path is expected to be encoded in UTF-8. + * \param[out] out The created config. + * \return OgaResult containing the error message if the creation of the config failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateConfig(const char* config_path, OgaConfig** out); + +/** + * \brief Clear the list of providers in the given config + * \param[in] config The config to clear the providers from. + * \return OgaResult containing the error message if the clearing of the providers failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigClearProviders(OgaConfig* config); + +/** + * \brief Add the provider at the end of the list of providers in the given config if it doesn't already exist + * if it already exists, does nothing. + * \param[in] config The config to set the provider on. + * \param[in] provider The provider to set on the config. + * \return OgaResult containing the error message if the setting of the provider failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigAppendProvider(OgaConfig* config, const char* provider); + +/** + * \brief Set a provider option + * \param[in] config The config to set the provider option on. + * \param[in] provider The provider to set the option on. + * \param[in] key The key of the option to set. + * \param[in] value The value of the option to set. + * \return OgaResult containing the error message if the setting of the provider option failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigSetProviderOption(OgaConfig* config, const char* provider, const char* key, const char* value); + +/** + * \brief Overlay JSON on top of config file + * \param[in] config The config to overlay the JSON on. 
+ * \param[in] json The JSON to overlay on the config. + * \return OgaResult containing the error message if the overlaying of the JSON failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigOverlay(OgaConfig* config, const char* json); + +/** + * \brief Creates a model from the given configuration directory. * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8. - * \param[in] device_type The device type to use for the model. * \param[out] out The created model. * \return OgaResult containing the error message if the model creation failed. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out); -/* +/** + * \brief Creates a model from the given configuration. + * \param[in] config The configuration to use for the model. + * \param[out] out The created model. + * \return OgaResult containing the error message if the model creation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelFromConfig(const OgaConfig* config, OgaModel** out); + +/** + * \brief Creates a model from the given configuration directory and runtime settings. + * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8. + * \param[in] settings The runtime settings to use for the model. + * \param[out] out The created model. + * \return OgaResult containing the error message if the model creation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out); + +/** + * \brief Returns the type of the model. + * \param[in] model The model to get the type from. + * \param[out] out The type of the model. Must be destroyed with OgaDestroyString + * \return OgaResult containing the error message if the getting of the model type failed.
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGetType(const OgaModel* model, const char** out); + +/** + * \brief Returns the device type of the model. + * \param[in] model The model to get the device type from. + * \param[out] out The device type of the model. Must be destroyed with OgaDestroyString + * \return OgaResult containing the error message if the getting of the device type failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGetDeviceType(const OgaModel* model, const char** out); + +/** + * \brief Destroys the given config + * \param[in] config The config to be destroyed. + */ +OGA_EXPORT void OGA_API_CALL OgaDestroyConfig(OgaConfig* config); + +/** * \brief Destroys the given model. * \param[in] model The model to be destroyed. */ OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model); -/* - * \brief Generates an array of token arrays from the model execution based on the given generator params. - * \param[in] model The model to use for generation. - * \param[in] generator_params The parameters to use for generation. - * \param[out] out The generated sequences of tokens. The caller is responsible for freeing the sequences using OgaDestroySequences - * after it is done using the sequences. - * \return OgaResult containing the error message if the generation failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out); - -/* +/** * \brief Creates a OgaGeneratorParams from the given model. * \param[in] model The model to use for generation. * \param[out] out The created generator params. @@ -127,7 +310,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGenerat */ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out); -/* +/** * \brief Destroys the given generator params. * \param[in] generator_params The generator params to be destroyed. 
*/ @@ -137,70 +320,128 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGenerato OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* generator_params, const char* name, bool value); OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size); -/* - * \brief Sets the input ids for the generator params. The input ids are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] input_ids The input ids array of size input_ids_count = batch_size * sequence_length. - * \param[in] input_ids_count The total number of input ids. - * \param[in] sequence_length The sequence length of the input ids. - * \param[in] batch_size The batch size of the input ids. - * \return OgaResult containing the error message if the setting of the input ids failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, - size_t input_ids_count, size_t sequence_length, size_t batch_size); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* generator_params, const OgaNamedTensors* named_tensors); -/* - * \brief Sets the input id sequences for the generator params. The input id sequences are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] sequences The input id sequences. - * \return OgaResult containing the error message if the setting of the input id sequences failed. +/** + * \brief For additional model inputs that genai does not handle, this lets the user set their values. For example LoRA models handle + * fine tuning through model inputs. This lets the user supply the fine tuning inputs, while genai handles the standard inputs. 
+ * \param[in] generator_params The generator params to set the input on + * \param[in] name Name of the model input (this must match the model's input name) + * \param[in] tensor The OgaTensor of the input data */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetModelInput(OgaGeneratorParams* generator_params, const char* name, OgaTensor* tensor); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, const int32_t* inputs, size_t count); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperDecoderInputIDs(OgaGeneratorParams*, const int32_t* input_ids, size_t input_ids_count); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, OgaTensor* tensor); -/* +/** * \brief Creates a generator from the given model and generator params. * \param[in] model The model to use for generation. * \param[in] params The parameters to use for generation. * \param[out] out The created generator. * \return OgaResult containing the error message if the generator creation failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); -/* +/** * \brief Destroys the given generator. * \param[in] generator The generator to be destroyed. */ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); -/* +/** * \brief Returns true if the generator has finished generating all the sequences. * \param[in] generator The generator to check if it is done with generating all sequences. * \return True if the generator has finished generating all the sequences, false otherwise. 
*/ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); +OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator); -/* +/** + * \brief Adds the input ids to the generator. The input ids are used to seed the generation. + * \param[in] oga_generator The generator to add the input ids to. + * \param[in] p_sequences The input id sequences. + * \return OgaResult containing the error message if the setting of the input ids failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerator* oga_generator, const OgaSequences* p_sequences); + +/** + * \brief Adds the input ids to the generator. The input ids are used to seed the generation. + * \param[in] oga_generator The generator to add the input ids to. + * \param[in] input_ids The input ids to add. + * \param[in] input_ids_count The number of input ids to add (batch_size * sequence_length). + * \return OgaResult containing the error message if the setting of the input ids failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count); + +/** * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. * \param[in] generator The generator to compute the logits for. * \return OgaResult containing the error message if the computation of the logits failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); +/** + * \brief Returns a pointer to the next tokens generated by the model. The out_count will match the batch size + * \param[in] generator The generator to get the next tokens from. + * \param[out] out The pointer to the next tokens generated by the model. 
The pointer is valid until the next OgaGenerator call + * \param[out] out_count The number of tokens in the out array. + * \return OgaResult containing the error message if the getting of the next tokens failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetNextTokens(const OgaGenerator* generator, const int32_t** out, size_t* out_count); + +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value); + +/** + * \brief Rewinds the generator to the given length. This is useful when the user wants to rewind the generator to a specific length + * and continue generating from that point. + * \param[in] generator The generator to rewind to the given length. + * \param[in] new_length The desired length in tokens after rewinding. + * \return OgaResult containing the error message if the rewinding failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_RewindTo(OgaGenerator* generator, size_t new_length); + +/** + * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor + * and will be released when the OgaTensor is destroyed + * \param[in] generator The generator to run the GetOutput on the name provided and the out pointer to store the output. + * \param[in] name The name of the output tensor. + * \param[out] out The returned OgaTensor. + * \return OgaResult containing the error message if the computation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* generator, const char* name, OgaTensor** out); + +/** + * \brief Returns a copy of the logits from the model as an OgaTensor on CPU. 
The buffer is owned by returned OgaTensor + * and will be released when the OgaTensor is destroyed + * \param[in] generator The generator to get the logits from + * \param[out] out The OgaTensor containing the logits, it only contains the last token logits even in prompt processing + * \return OgaResult containing the error message if the computation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetLogits(OgaGenerator* generator, OgaTensor** out); + +/** + * \brief Sets the logits to the generator. This is useful when the user wants to set the logits to a specific value + * for example when doing guided generation. + * \param[in] generator The generator to set the logits on + * \param[in] tensor The OgaTensor containing the logits, it must have the same shape as the logits returned by GetLogits + * which is the last token logits. + * \return OgaResult containing the error message if the setting of the logits failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetLogits(OgaGenerator* generator, OgaTensor* tensor); + /* * \brief Returns the number of tokens in the sequence at the given index. * \param[in] generator The generator to get the count of the tokens for the sequence at the given index. + * \param[in] index The given index. * \return The number tokens in the sequence at the given index. */ OGA_EXPORT size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* generator, size_t index); -/* +/** * \brief Returns a pointer to the sequence data at the given index. The number of tokens in the sequence * is given by OgaGenerator_GetSequenceCount * \param[in] generator The generator to get the sequence data for the sequence at the given index. + * \param[in] index The given index. * \return The pointer to the sequence data at the given index. The sequence data is owned by the OgaGenerator * and will be freed when the OgaGenerator is destroyed.
The caller must copy the data if it needs to * be used after the OgaGenerator is destroyed. @@ -210,30 +451,216 @@ OGA_EXPORT const int32_t* OGA_API_CALL OgaGenerator_GetSequenceData(const OgaGen OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizer(const OgaModel* model, OgaTokenizer** out); OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizer(OgaTokenizer*); -/* Encodes a single string and adds the encoded sequence of tokens to the OgaSequences. The OgaSequences must be freed with OgaDestroySequences - when it is no longer needed. +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateMultiModalProcessor(const OgaModel* model, OgaMultiModalProcessor** out); + +OGA_EXPORT void OGA_API_CALL OgaDestroyMultiModalProcessor(OgaMultiModalProcessor* processor); + +/** + * Encodes a single string and adds the encoded sequence of tokens to the OgaSequences. The OgaSequences must be freed with OgaDestroySequences + * when it is no longer needed. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences); -/* Decode a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString +/** + * Batch encode an array of strings and return a single tensor output + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeBatch(const OgaTokenizer*, const char** strings, size_t count, OgaTensor** out); + +/** + * Batch decode a tensor of token ids and return an array of strings + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecodeBatch(const OgaTokenizer*, const OgaTensor* tensor, OgaStringArray** out); + +/** + * \brief Converts the given string to a single token id. + * \param[in] tokenizer The tokenizer to use to convert the string to a token id. + * \param[in] str The string to convert to a token id. + * \param[out] token_id The converted token id. + * \return OgaResult containing the error message if the conversion of the string to a token id failed.
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerToTokenId(const OgaTokenizer* tokenizer, const char* str, int32_t* token_id); + +OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImages(const OgaMultiModalProcessor*, const char* prompt, const OgaImages* images, OgaNamedTensors** input_tensors); + +OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessAudios(const OgaMultiModalProcessor*, const OgaAudios* audios, OgaNamedTensors** input_tensors); + +OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImagesAndAudios(const OgaMultiModalProcessor*, const char* prompt, const OgaImages* images, + const OgaAudios* audios, OgaNamedTensors** input_tensors); + +/** Decodes a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer*, const int32_t* tokens, size_t token_count, const char** out_string); +OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorDecode(const OgaMultiModalProcessor*, const int32_t* tokens, size_t token_count, const char** out_string); -/* OgaTokenizerStream is to decoded token strings incrementally, one token at a time. +/** OgaTokenizerStream is used to decode token strings incrementally, one token at a time. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizerStream(const OgaTokenizer*, OgaTokenizerStream** out); +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizerStreamFromProcessor(const OgaMultiModalProcessor*, OgaTokenizerStream** out); OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizerStream(OgaTokenizerStream*); -/* +/** * Decode a single token in the stream. If this results in a word being generated, it will be returned in 'out'. * The caller is responsible for concatenating each chunk together to generate the complete result.
* 'out' is valid until the next call to OgaTokenizerStreamDecode or when the OgaTokenizerStream is destroyed */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerStreamDecode(OgaTokenizerStream*, int32_t token, const char** out); +/** Create an OgaTensor from an optional user owned buffer. If a user owned buffer is supplied, the OgaTensor does + * not own the memory (as it has no way to free it) so the 'data' parameter must be valid for the lifetime of the OgaTensor. + * If the 'data' parameter is nullptr, the OgaTensor will allocate its own memory. + * + * \param[in] data User supplied memory pointer, if non nullptr it must remain valid for lifetime of the OgaTensor + * \param[in] shape_dims Pointer to array of int64_t values that define the tensor shape, example [1 20 30] would be equivalent to a C array of [1][20][30] + * \param[in] shape_dims_count Count of elements in the shape_dims array + * \param[in] element_type The data type that 'data' points to. + * \param[out] out Writes the newly created OgaTensor into this, must be destroyed with OgaDestroyTensor + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTensorFromBuffer(void* data, const int64_t* shape_dims, size_t shape_dims_count, OgaElementType element_type, OgaTensor** out); + +OGA_EXPORT void OGA_API_CALL OgaDestroyTensor(OgaTensor* tensor); + +/** Get the OgaElementType of the data stored in the OgaTensor + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTensorGetType(OgaTensor*, OgaElementType* out); + +/** Get the number of dimensions of the OgaTensor's shape, typically used to allocate a buffer of this size then calling OgaTensorGetShape with it + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTensorGetShapeRank(OgaTensor*, size_t* out); + +/** Copies the shape dimensions into the shape_dims parameters. 
shape_dims_count must match the value returned by OgaTensorGetShapeRank + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTensorGetShape(OgaTensor*, int64_t* shape_dims, size_t shape_dims_count); + +/** A pointer to the tensor data, it is typically cast into the actual data type of the tensor + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaTensorGetData(OgaTensor*, void** out); + +/** \brief Create an OgaNamedTensors + * \param[out] out The created OgaNamedTensors + * \return OgaResult containing the error message if the creation of the OgaNamedTensors failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateNamedTensors(OgaNamedTensors** out); + +/** \brief Lookup a tensor in a NamedTensor set by name + * \param[in] named_tensors The named tensors to search + * \param[in] name The name of the tensor to find + * \param[out] out The tensor with the given name + * \return OgaResult containing the error message if the tensor with the given name could not be found. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaNamedTensorsGet(OgaNamedTensors* named_tensors, const char* name, OgaTensor** out); + +/** \brief Set a tensor in a NamedTensor set by name + * \param[in] named_tensors The named tensors to set the tensor + * \param[in] name The name of the tensor to set + * \param[in] tensor The tensor to set + * \return OgaResult containing the error message if the tensor with the given name could not be set. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaNamedTensorsSet(OgaNamedTensors* named_tensors, const char* name, OgaTensor* tensor); + +/** \brief Delete a tensor in a NamedTensor set by name + * \param[in] named_tensors The named tensors to remove the tensor + * \param[in] name The name of the tensor to remove + * \return OgaResult containing the error message if the tensor with the given name could not be removed. 
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaNamedTensorsDelete(OgaNamedTensors* named_tensors, const char* name); + +/** \brief Get the number of tensors in the NamedTensors + * \param[in] named_tensors The named tensors to get the count of the tensors + * \param[out] out The number of tensors in the NamedTensors + * \return OgaResult containing the error message if the getting of the count of the tensors failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaNamedTensorsCount(const OgaNamedTensors* named_tensors, size_t* out); + +/** \brief Return an OgaStringArray of the names of the tensors in an OgaNamedTensors object + * \param[in] named_tensors The named tensors to get the names of the tensors + * \param[out] out The OgaStringArray containing the names of the tensors + * \return OgaResult containing the error message if the getting of the names of the tensors failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaNamedTensorsGetNames(const OgaNamedTensors* named_tensors, OgaStringArray** out); + OGA_EXPORT OgaResult* OGA_API_CALL OgaSetCurrentGpuDeviceId(int device_id); OGA_EXPORT OgaResult* OGA_API_CALL OgaGetCurrentGpuDeviceId(int* device_id); +/** + * \brief Creates an object of type OgaStringArray. + * \return The result of the operation. If the operation is successful, a nullptr is returned. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateStringArray(OgaStringArray** out); + +/** + * \brief Creates an object of type OgaStringArray from the given strings. + * \return The result of the operation. If the operation is successful, a nullptr is returned. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateStringArrayFromStrings(const char* const* strs, size_t count, OgaStringArray** out); + +/** + * \brief Destroys OgaStringArray. + */ +OGA_EXPORT void OGA_API_CALL OgaDestroyStringArray(OgaStringArray* string_array); + +/** + * \brief Adds the given string to the string_array. 
+ * \param[inout] string_array The string array to which the string is to be added + * \param[in] str The string to be added to the string_array. + * \return The result of the operation. If the operation is successful, a nullptr is returned. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaStringArrayAddString(OgaStringArray* string_array, const char* str); + +/** + * \brief Gets the number of strings in the string_array. + * \param[in] string_array The OgaStringArray object to get the count of the strings. + * \param[out] out The number of strings in the string_array. + * \return The result of the operation. If the operation is successful, a nullptr is returned. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaStringArrayGetCount(const OgaStringArray* string_array, size_t* out); + +/** + * \brief Get a string from a string_array + * \param[in] string_array The OgaStringArray object to get the string from. + * \param[in] index The index of the string to get. + * \return The result of the operation. If the operation is successful, a nullptr is returned and the string at the given index is written to out. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaStringArrayGetString(const OgaStringArray* string_array, size_t index, const char** out); + +/** + * \brief Creates the OgaAdapters object that manages the adapters. + - The OgaAdapters object is used to load all the model adapters. + - It is responsible for reference counting the loaded adapters. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateAdapters(const OgaModel* model, OgaAdapters** out); + +/** + * \brief Destroys the OgaAdapters object. + */ +OGA_EXPORT void OGA_API_CALL OgaDestroyAdapters(OgaAdapters* adapters); + +/** + * \brief Loads the model adapter from the given adapter file path and adapter name. + * \param[in] adapters The OgaAdapters object to load the adapter. + * \param[in] adapter_file_path The file path of the adapter to load. + * \param[in] adapter_name A unique identifier for the adapter chosen by the function invoker. + * This name is used for querying the adapter.
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAdapter(OgaAdapters* adapters, const char* adapter_file_path, + const char* adapter_name); + +/** + * \brief Unloads the adapter with the given identifier from the previously loaded adapters. + If the adapter is not found, or if it cannot be unloaded (when it is in use), an error is returned. + * \param[in] adapters The OgaAdapters object to unload the adapter. + * \param[in] adapter_name The name of the adapter to unload. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaUnloadAdapter(OgaAdapters* adapters, const char* adapter_name); + +/** + * \brief Sets the adapter with the given adapter name as active for the given OgaGenerator object. + * \param[in] generator The OgaGenerator object to set the active adapter. + * \param[in] adapters The OgaAdapters object that manages the model adapters. + * \param[in] adapter_name The name of the adapter to set as active. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaSetActiveAdapter(OgaGenerator* generator, OgaAdapters* adapters, + const char* adapter_name); #ifdef __cplusplus } #endif +//! @} \ No newline at end of file diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime-genai.dylib b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime-genai.dylib old mode 100644 new mode 100755 index 940394e67..17715d40c Binary files a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime-genai.dylib and b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime-genai.dylib differ diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/cpu_provider_factory.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/cpu_provider_factory.h new file mode 100644 index 000000000..292678692 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/cpu_provider_factory.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. + +#include "onnxruntime_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \param use_arena zero: false. non-zero: true. + */ +ORT_EXPORT +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) +ORT_ALL_ARGS_NONNULL; + +#ifdef __cplusplus +} +#endif diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/onnxruntime_c_api.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_c_api.h similarity index 83% rename from mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/onnxruntime_c_api.h rename to mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_c_api.h index de3013484..30284ee9e 100644 --- a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/header/onnxruntime_c_api.h +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_c_api.h @@ -38,7 +38,7 @@ * * This value is used by some API functions to behave as this version of the header expects. */ -#define ORT_API_VERSION 19 +#define ORT_API_VERSION 22 #ifdef __cplusplus extern "C" { @@ -46,7 +46,7 @@ extern "C" { //! @} // SAL2 Definitions -#ifndef _WIN32 +#ifndef _MSC_VER #define _In_ #define _In_z_ #define _In_opt_ @@ -196,7 +196,10 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + // Int4 types were introduced in ONNX 1.16. 
See https://onnx.ai/onnx/technical/int4.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof @@ -252,6 +255,7 @@ typedef enum OrtErrorCode { ORT_NOT_IMPLEMENTED, ORT_INVALID_GRAPH, ORT_EP_FAIL, + ORT_MODEL_LOAD_CANCELED, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -301,8 +305,14 @@ ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); - -#ifdef _WIN32 +ORT_RUNTIME_CLASS(LoraAdapter); +ORT_RUNTIME_CLASS(ValueInfo); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(Graph); +ORT_RUNTIME_CLASS(Model); +ORT_RUNTIME_CLASS(ModelCompilationOptions); + +#ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; #else typedef OrtStatus* OrtStatusPtr; @@ -470,13 +480,13 @@ typedef struct OrtCUDAProviderOptions { /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. - * This option can be overriden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. + * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. */ int tunable_op_enable; /** \brief Enable TunableOp for tuning. * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. - * This option can be overriden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. + * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. */ int tunable_op_tuning_enable; @@ -559,13 +569,13 @@ typedef struct OrtROCMProviderOptions { /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. - * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. 
+ * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. */ int tunable_op_enable; /** \brief Enable TunableOp for tuning. * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. - * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. + * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. */ int tunable_op_tuning_enable; @@ -614,11 +624,21 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name + int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true + const char* migraphx_save_model_path; // migraphx model path name + int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true + const char* migraphx_load_model_path; // migraphx model path name + bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + * \brief This Struct is frozen since ORT 1.13.0. Its maintained part of Legacy API for compatibility. + * \brief For latest OpenVINO Provider Options update to the ProviderOptions map. + * \brief Latest OpenVINO Provider Options are listed in the + * \htmlonly + * onnxruntime document. 
+ * \endhtmlonly + * \see OrtApi::SessionOptionsAppendExecutionProvider() */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus @@ -651,6 +671,12 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; +struct OrtModelEditorApi; +typedef struct OrtModelEditorApi OrtModelEditorApi; + +struct OrtCompileApi; +typedef struct OrtCompileApi OrtCompileApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -833,7 +859,8 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession @@ -1326,6 +1353,8 @@ struct OrtApi { * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. ReleaseValue won't release p_data. * + * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. + * * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param[in] p_data Pointer to the data buffer. * \param[in] p_data_len The number of bytes in the data buffer. @@ -1847,8 +1876,8 @@ struct OrtApi { * and not present, the function returns success and out is set to nullptr. * * \param[in] context ::OrtKernelContext instance - * \param[in] input index. See KernelContext_GetInputCount for boundaries check. - * \param[in, out] returns a ptr to OrtValue if the input is present + * \param[in] index See KernelContext_GetInputCount for boundaries check. 
+ * \param[out] out OrtValue if the input is present otherwise is set nullptr * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -1861,8 +1890,10 @@ struct OrtApi { * and not present, the function returns success and out is set to nullptr. * * \param[in] context ::OrtKernelContext instance - * \param[in] output index. See KernelContext_GetOutputCount for boundaries check. - * \param[in, out] returns a ptr to OrtValue if the output is present + * \param[in] index See KernelContext_GetOutputCount for boundaries check. + * \param[in] dim_values output dimensions + * \param[in] dim_count number of dimensions + * \param[out] out a ptr to OrtValue to output otherwise set to nullptr * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -1981,7 +2012,8 @@ struct OrtApi { /** \brief Get the value type from an ::OrtMapTypeInfo * * \param[in] map_type_info - * \param[out] type_info + * \param[out] type_info A copy of the OrtTypeInfo for the map value type. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -1996,7 +2028,8 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[in] sequence_type_info - * \param[out] type_info + * \param[out] type_info A copy of the OrtTypeInfo for the sequence element type. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2789,7 +2822,7 @@ struct OrtApi { * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. * "max_power_of_two_extend_bytes": The maximum enxtend size if arena strategy is `kNextPowerOfTwo`. - * It is not an allocation limit, it is only a limit for extention when requested byte is less than the limit. + * It is not an allocation limit, it is only a limit for extension when requested byte is less than the limit. 
* When requested bytes is more than the limit, allocator will still return as requested. * Use -1 to allow ORT to choose the default 1GB for max_power_of_two_extend_bytes. * Ultimately, the allocation size is determined by the allocation memory request. @@ -2871,7 +2904,8 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /** \brief Create session from memory with prepacked weights container @@ -2894,7 +2928,8 @@ struct OrtApi { */ ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /// @} @@ -2937,7 +2972,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="trt_max_workspace_size" and value="2147483648" * @@ -3433,7 +3468,7 @@ struct OrtApi { * * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options * to know the available keys and values. 
Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 - * and value should be its related range. + * and value should be its related range. Recreates the options and only sets the supplied values. * * For example, key="device_id" and value="0" * @@ -3616,49 +3651,96 @@ struct OrtApi { * that should be used to add it. * * QNN supported keys: - * "backend_path": file path to QNN backend library. - * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. + * "backend_type": Type of QNN backend. Specifies a backend path that is the associated QNN backend library file + * name. E.g., given backend type "htp", on Windows, the backend path would be "QnnHtp.dll", and on other + * platforms, it would be "libQnnHtp.so". Mutually exclusive with "backend_path". + * Available options: + * - "cpu" + * - "gpu" + * - "htp": Default. + * - "saver" + * "backend_path": File path to QNN backend library. Mutually exclusive with "backend_type". + * "profiling_level": QNN profiling level. + * Available options: + * - "off": Default. + * - "basic" + * - "detailed" * "profiling_file_path": QNN profiling file path if ETW not enabled. * "rpc_control_latency": QNN RPC control latency. * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). - * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", - * "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "htp_performance_mode": QNN performance mode. + * Available options: + * - "burst" + * - "balanced" + * - "default": Default. + * - "high_performance" + * - "high_power_saver" + * - "low_balanced" + * - "extreme_power_saver" + * - "low_power_saver" + * - "power_saver" + * - "sustained_high_performance" * "qnn_saver_path": File path to the QNN Saver backend library. 
If specified, QNN Saver will be enabled and will - * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and - * may alter model/EP partitioning. Use only for debugging. - * "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal". - * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options: - * - "0": Default. - * - "1": Faster preparation time, less optimal graph. - * - "2": Longer preparation time, more optimal graph. - * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. - * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). - * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: - * - "0": Default (none). - * - "68" - * - "69" - * - "73" - * - "75" + * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and + * may alter model/EP partitioning. Use only for debugging. + * "qnn_context_priority": QNN context priority. + * Available options: + * - "low" + * - "normal": Default. + * - "normal_high" + * - "high" + * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. + * Available options: + * - "0": Default. + * - "1": Faster preparation time, less optimal graph. + * - "2": Longer preparation time, more optimal graph. + * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific + * details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. + * Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. 
+ * Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Only used for float32 model. - Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": Default. With fp32 precision. - - "1": With fp16 precision. + * "enable_htp_fp16_precision": Used for float32 model for HTP backend. + * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + * - "0": With fp32 precision. + * - "1": Default. With fp16 precision. + * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another + * execution provider (typically CPU EP). + * - "0": Disabled. QNN EP will handle quantization and dequantization of graph I/O. + * - "1": Enabled. This is the default value. + * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context + * binary. + * - "0": Default. Disabled. + * - "1": Enabled. + * "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to + * be available. + * - "0": Default. Disabled. + * - "1": Enabled. + * "dump_json_qnn_graph": Set to "1" to dump QNN graphs generated by QNN EP as JSON files. Each graph partition + * assigned to QNN EP is dumped to a separate file. + * "json_qnn_graph_dir": Directory in which to dump QNN JSON graphs. If not specified, QNN graphs are dumped in the + * program's current working directory. Ignored if "dump_json_qnn_graph" is not set. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", - * "DSP", "DSP_FIXED8_TF", "AIP_FIXED_TF", "AIP_FIXED8_TF". 
- * Mapping to SNPE Runtime_t definition: CPU, CPU_FLOAT32 => zdl::DlSystem::Runtime_t::CPU; - * GPU, GPU_FLOAT32_16_HYBRID => zdl::DlSystem::Runtime_t::GPU; - * GPU_FLOAT16 => zdl::DlSystem::Runtime_t::GPU_FLOAT16; - * DSP, DSP_FIXED8_TF => zdl::DlSystem::Runtime_t::DSP. - * AIP_FIXED_TF, AIP_FIXED8_TF => zdl::DlSystem::Runtime_t::AIP_FIXED_TF. + * "DSP", "DSP_FIXED8_TF", "AIP_FIXED_TF", "AIP_FIXED8_TF". + * Mapping to SNPE Runtime_t definition: + * CPU, CPU_FLOAT32 => zdl::DlSystem::Runtime_t::CPU; + * GPU, GPU_FLOAT32_16_HYBRID => zdl::DlSystem::Runtime_t::GPU; + * GPU_FLOAT16 => zdl::DlSystem::Runtime_t::GPU_FLOAT16; + * DSP, DSP_FIXED8_TF => zdl::DlSystem::Runtime_t::DSP. + * AIP_FIXED_TF, AIP_FIXED8_TF => zdl::DlSystem::Runtime_t::AIP_FIXED_TF. * "priority": execution priority, options: "low", "normal". * "buffer_type": ITensor or user buffers, options: "ITENSOR", user buffer with different types - "TF8", "TF16", "UINT8", "FLOAT". * "ITENSOR" -- default, ITensor which is float only. * "TF8" -- quantized model required, "FLOAT" -- for both quantized or non-quantized model * "enable_init_cache": enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. - * If SNPE is not available (due to a non Snpe enabled build or its dependencies not being installed), this function will fail. * * XNNPACK supported keys: * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. @@ -3769,7 +3851,7 @@ struct OrtApi { /** \brief Release an OrtCANNProviderOptions * - * \param[in] the pointer of OrtCANNProviderOptions which will been deleted + * \param[in] input The pointer of OrtCANNProviderOptions which will been deleted * * \since Version 1.13. */ @@ -4259,8 +4341,8 @@ struct OrtApi { * specific type that is described by the returned ::OrtTypeInfo. * * \param[in] optional_type_info - * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. - * it is owned by OrtOptionalTypeInfo instance. 
+ * \param[out] out A copy of ::OrtTypeInfo for what the optional value could be. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -4458,13 +4540,14 @@ struct OrtApi { * E.g. a cuda stream or a cublas handle * * \param context - Kernel context - * \param resouce_version - Version of the resource + * \param resource_version - Version of the resource * \param resource_id - Type of resource * \param resource - A pointer to returned resource * * \since Version 1.16. */ - ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource); + ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, + _In_ int resource_id, _Outptr_ void** resource); /** \brief Set user logging function * @@ -4519,10 +4602,10 @@ struct OrtApi { ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr); /** - * Set type and shape info of an ouput + * Set type and shape info of an output * * \param[in] context - * \param[in] index The index of the ouput + * \param[in] index The index of the output * \param[out] info Type shape info of the output * * \since Version 1.17. @@ -4588,6 +4671,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_ OrtSessionOptions* options, @@ -4605,6 +4690,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. 
*/ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, @@ -4618,7 +4705,10 @@ struct OrtApi { * \param[in] mem_info OrtMemoryInfo instance * \param[in] count_or_bytes How many bytes is this scratch buffer * \param[out] out A pointer to the scrach buffer + * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); @@ -4629,6 +4719,8 @@ struct OrtApi { * \param[out] out A pointer to OrtAllocator * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); @@ -4650,12 +4742,187 @@ struct OrtApi { * \param[in] num_external_initializer_files Number of external files * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, size_t num_external_initializer_files); + + /** \brief Create an OrtLoraAdapter + * + * The function attempts to locate file specified by adapter_file_path, read it and create an OrtLoraAdapter + * instance. The adapter_file_path should be a valid path to a file that contains a valid Lora Adapter + * format. The function attempts to validate the format at load time. The file will always be memory mapped, unless + * the platform does not support memory mapping, in which case the file will be read into memory. + * + * \param[in] adapter_file_path adapter file path. 
+ * \param[in] allocator optional pointer to a device allocator. If specified + * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. + * The data would still be copied to device if required by the model at inference time. + * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with + * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, + _Outptr_ OrtLoraAdapter** out); + + /** \brief Create an OrtLoraAdapter + * + * The function copies the bytes from the array and creates an OrtLoraAdapter instance. + * + * + * \param[in] bytes pointer to a valid Lora Adapter format buffer. + * \param[in] num_bytes length of bytes buffer. + * \param[in] allocator optional pointer to a device allocator. If specified + * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. + * The data would still be copied to device if required by the model at inference time. + * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with + * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, + _Outptr_ OrtLoraAdapter** out); + + /** \brief Release an ::OrtLoraAdapter obtained from OrtApi::CreateLoraAdapter + */ + ORT_CLASS_RELEASE(LoraAdapter); + + /** \brief Add the Lora Adapter to the list of active adapters. + * + * The function adds the Lora Adapter to the list of active adapters. The Lora Adapter must be created with + * OrtApi::CreateLoraAdapter or FromArray. The Lora Adapter will be used by the session to run the model. 
+ * The instance of the OrtRunOptions can then be used to customize the Run() calls. + * More than one OrtLoraAdapter can be active at the same time. Lora Parameters that belong to different + * Lora adapters that will be active at the same time must not overlap. + * This setting does not affect RunWithBinding. + * + * \param[in] options OrtRunOptions instance + * \param[in] adapter OrtLoraAdapter instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); + + /// @} + /// \name OrtEpDynamicOptions + /// @{ + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] sess OrtSession + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, + _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + + /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(ValueInfo); + + /** \brief Release an OrtNode if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Node); + + /** \brief Release an OrtGraph. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Graph); + + /** \brief Release an OrtModel. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. 
+ */ + ORT_CLASS_RELEASE(Model); + + /** \brief Get the value name from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); + + /** \brief Get the type information from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + + /** \brief Get the Model Editor API instance + * + * Get the Model Editor API instance to create a new model or augment an existing model. + * + * \return Model Editor API struct + * + * \since Version 1.21. + */ + const OrtModelEditorApi*(ORT_API_CALL* GetModelEditorApi)(); + + /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. + * + * ORT will take ownership of the memory and free it using the provided deleter when no longer in use. + * + * \param[in] deleter OrtAllocator instance that will be used to free the memory. + * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. + * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. + * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. + * \param[in] p_data_len Length of the memory in bytes. + * \param[in] shape Dimensions of the Tensor. All values should be > 0. + * \param[in] shape_len Number of dimensions in the shape array. + * \param[in] type Data type of the Tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. 
+ */ + ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + + /** \brief sets load cancellation flag to abort session loading process. + * + * \param[in] options instance that was passed to the session at creation time. + * \param[in] cancel setting this to true after model loading process was initiated will + * attempt to cancel the loading process. If cancellation is successful, CreateSession() + * CreateSessionFromArray() or any other session creation API that take session options as an + * argument will return an OrtStatus indicating that session loading was canceled at user request, + * error code ORT_MODEL_LOAD_CANCELED. + * The APIs above would not return any valid Session instance. This is the best case effort and the result + * is not guaranteed. The session may have already been created and initialized + * before the cancellation request was issued. + * + * \snippet{doc} snippets.dox OrtStatus + * + */ + ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool cancel); + + const OrtCompileApi*(ORT_API_CALL* GetCompileApi)(); }; /* @@ -4770,6 +5037,561 @@ struct OrtCustomOp { void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; +/** + * ORT Model Editor API + */ + +/** + * \brief The OrtModelEditorApi struct provides functions to create or edit an ONNX model. + * + * See onnxruntime/test/shared_lib/test_model_editor_api.cc for example usage. + * + * \since Version 1.21. + */ +struct OrtModelEditorApi { + // Model building/editing requires a full build. We return nullptr from GetModelEditorApi if this is a minimal + // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an + // OrtModelEditorApi instance. 
We do however need a dummy field to avoid empty struct warning. +#if defined(ORT_MINIMAL_BUILD) + const bool not_defined_in_this_build; +#else + /** \brief Create an OrtTypeInfo instance for a Tensor. + * + * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a SparseTensor. + * + * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info SparseTensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Map. + * + * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Editor API. + * + * User can release `map_value_type` after creating the OrtTypeInfo. + * + * \param[in] map_key_type Key type for the map. + * \param[in] map_value_type Value type for the map. + * \param[out] TypeInfo instance for the map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. 
+ */ + ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Sequence. + * + * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Editor API. + * + * User can release `sequence_type` after creating the OrtTypeInfo. + * + * \param[in] sequence_type Sequence type and shape information. + * \param[out] type_info TypeInfo instance for the sequence. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for an Optional. + * + * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Editor API. + * + * User can release `contained_type` after creating the OrtTypeInfo. + * + * \param[in] contained_type Type information for the value contained in the Optional. + * \param[out] type_info TypeInfo instance for the Optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. + * + * \param[in] name The name of the input or output. + * \param[in] type_info The type information for the input or output. The provided value is copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + + /** \brief Create an OrtNode to add to an OrtGraph. + * + * Create an OrtNode. + * + * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. + * + * \param[in] operator_name The name of the operator.
+ * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. + * \param[in] node_name The name of the node. + * \param[in] input_names The names of the inputs. + * \param[in] input_names_len The number of input names. + * \param[in] output_names The names of the outputs. + * \param[in] output_names_len The number of output names. + * \param[in] attributes The optional attributes of the node. + * \param[in] attribs_len The number of attributes. May be zero. + * \param[out] node The OrtNode instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, + _Outptr_ OrtNode** node); + + /** \brief Create an OrtGraph + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); + + /** \brief Set the inputs for the OrtGraph. + * + * Set the graph inputs. This will replace any existing inputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] inputs The input OrtValueInfo instances. + * \param[in] inputs_len The number of input OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphInputs, _Inout_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); + + /** \brief Set the outputs for the OrtGraph. + * + * Set the graph outputs. 
This will replace any existing outputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] outputs The output OrtValueInfo instances. + * \param[in] outputs_len The number of output OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphOutputs, _Inout_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); + + /** \brief Add an initializer to the OrtGraph + * + * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. + * + * Two options: + * + * Allocated memory: + * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Set `data_is_external` to false. + * + * Pre-existing memory: + * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue + * with a tensor that contains a pointer to the existing data. + * Set `data_is_external` to true. + * + * The pointer must remain valid for the duration of the inference session. + * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session + * is released. + * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as + * soon as the OrtValue is no longer in use. + * + * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. + * For smaller tensors use CreateTensorAsOrtValue. + * + * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is + * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of + * memory, so this limit acts as a simple catch-all rule to avoid issues. + * e.g. 
Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] name The value name for the initializer. + * \param[in] tensor The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, + bool data_is_external); + + /** \brief Add an OrtNode to an OrtGraph + * + * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] node The OrtNode instance to add to the graph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddNodeToGraph, _Inout_ OrtGraph* graph, _In_ OrtNode* node); + + /** \brief Create an OrtModel. + * + * Create an OrtModel. + * + * This can be used to build a new model, or to augment an existing model. + * + * \param[in] domain_names The domain names for the model. + * If augmenting an existing model add additional domains if needed. + * \param[in] opset_versions The opset versions for the model. + * If augmenting an existing model add additional opset versions if needed. + * \param[in] opset_entries_len The number of domain_names and opset_versions entries. + * Domain and opset entries should be 1:1 + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); + + /** \brief Add an OrtGraph to an OrtModel. + * + * Add the graph to a model. 
This should be called once when creating a new model. + * + * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. + * + * \param[in] model The OrtModel instance to update. + * \param[in] graph The OrtGraph instance to add to the model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddGraphToModel, _Inout_ OrtModel* model, _In_ OrtGraph* graph); + + /** \brief Create an OrtSession using the OrtModel. + * + * Create an inference session using the OrtModel instance. + * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs + * and SetGraphOutputs must have been called. + * This will validate the model, run optimizers, and prepare the session for inferencing. + * + * ReleaseOrtModel must be called to free the OrtModel after session creation. + * + * \param[in] env The OrtEnv instance. + * \param[in] model The OrtModel instance. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The OrtSession instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. + * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. 
+ * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToModelEditorSession, and prepare the + * session for inferencing by calling FinalizeModelEditorSession. + * + * \param[in] env The OrtEnv instance. + * \param[in] model_path The path to the existing ONNX model to augment. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The created OrtSession instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelEditorSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. + * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToModelEditorSession, and prepare the + * session for inferencing by calling FinalizeModelEditorSession. + * + * \param[in] env The OrtEnv instance. + * \param[in] model_data The model data for the existing model to augment.
+ * \param[in] model_data_length The length of the model data. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The created OrtSession instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Query the session for the opset version of a domain. + * + * When using the Model Editor API to augment a model, any new nodes must conform to the opset version of the + * original model. To do that the user must be able to discover that opset version. + * + * \param[in] session OrtSession to query + * \param[in] domain Domain to query. The ONNX domain is an empty string. + * \param[out] opset The opset version of the domain. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. Returns an error if the domain is not used in the model. + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); + + /** \brief Apply changes to augment the ONNX model in a session created using CreateModelEditorSession[FromArray] + * + * Adds new nodes and updates graph inputs/outputs using `model` to augment the original ONNX model in the session. + * All changes will be validated. + * Call FinalizeModelEditorSession to prepare the session for inferencing. + * + * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. + * i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. + * + * ReleaseOrtModel must be called to free the OrtModel after it is applied to the session. + * + * \param[in] session OrtSession to update. Session must have been created using CreateModelEditorSession[FromArray].
+ * \param[in] model OrtModel containing new nodes, new initializers, and updated graph input and/or output info. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(ApplyModelToModelEditorSession, _Inout_ OrtSession* session, _In_ OrtModel* model); + + /** \brief Finalize the Model Editor session that was created using CreateModelEditorSession[FromArray]. + * + * Finalize the Model Editor session that augmented an ONNX model by adding new nodes. + * This will run optimizers and prepare the session for inferencing. + * + * \param[in] session OrtSession to finalize. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] options OrtSessionOptions to use for the session. + * \param[in] Optional prepacked_weights_container OrtPrepackedWeightsContainer to use for the session. + Set to nullptr if not used. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(FinalizeModelEditorSession, _Inout_ OrtSession* session, _In_ const OrtSessionOptions* options, + _In_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + +/** + * ORT Compile API + */ + +/** + * \brief The OrtCompileApi struct provides functions to compile ONNX models. + * + * Execution providers that support compilation fuse a subgraph into an EPContext node that wraps a provider-specific + * binary representation of the subgraph. + * See \href https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for EPContext details. + * + * \since Version 1.22. + */ +struct OrtCompileApi { + // Model compilation requires a full build. We return nullptr from GetCompileApi if this is a minimal + // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an + // OrtCompileApi instance. We do however need a dummy field to avoid empty struct warning. 
+#if defined(ORT_MINIMAL_BUILD) + const bool not_defined_in_this_build; +#else + /// @} + /// \name OrtModelCompilationOptions + /// @{ + ORT_CLASS_RELEASE(ModelCompilationOptions); + + /** \brief Creates an OrtModelCompilationOptions object from an existing OrtSessionOptions object. + * + * An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX model. + * The OrtSessionOptions object has the execution providers with which the model will be compiled. + * + * ReleaseModelCompilationOptions must be called to free the OrtModelCompilationOptions after calling + * CompileModel. + * + * \param[in] env OrtEnv object. + * \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions. + * \param[out] out The created OrtModelCompilationOptions instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateModelCompilationOptionsFromSessionOptions, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* session_options, _Outptr_ OrtModelCompilationOptions** out); + + /** \brief Sets the file path to the input ONNX model to compile. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModelPath, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* input_model_path); + + /** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] input_model_data Buffer containing the loaded ONNX model bytes. + * \param[in] input_model_data_size The number of bytes in the `input_model_data` buffer.
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModelFromBuffer, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const void* input_model_data, + size_t input_model_data_size); + + /** \brief Sets the file path for the output ONNX model generated by CompileModel. + * + * If the output model path is not specified and the output model is not to be stored in a buffer, + * ONNX Runtime will generate a path based on the input model's file path. + * Examples: + * /Path/my_model.onnx -> /Path/my_model_ctx.onnx + * /Path/my_model -> /Path/my_model_ctx.onnx + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] output_model_path Null terminated string of the path (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelPath, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_model_path); + + /** \brief Optionally sets the file that should store external initializers for the compiled ONNX model. + * If not set, initializers are stored within the model. + * + * Only initializers for nodes that were not compiled are stored in the external initializers file. + * Compiled nodes contain their initializer data within the `ep_cache_context` attribute of EPContext nodes. + * Refer to ModelCompilationOptions_SetEpContextEmbedMode. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] external_initializers_file_path Null terminated string of the path to the file. + * \param[in] external_initializer_size_threshold Initializers larger than this threshold are stored in the file. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22.
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelExternalInitializersFile, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* external_initializers_file_path, + size_t external_initializer_size_threshold); + + /** \brief Configures model compilation to store the output compiled ONNX model in a buffer. + * + * The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] allocator The allocator used to allocate the buffer for the compiled model. + * \param[out] output_model_buffer_ptr Pointer to the buffer that stores the compiled model. + * \param[out] output_model_buffer_size_ptr Pointer set to the size of output buffer in bytes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelBuffer, + _In_ OrtModelCompilationOptions* model_compile_options, + _Inout_ OrtAllocator* allocator, + _Outptr_ void** output_model_buffer_ptr, + _Out_ size_t* output_model_buffer_size_ptr); + + /** \brief Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute + * of EPContext nodes. Defaults to false. + * + * If enabled, the `ep_cache_context` attribute of EPContext nodes will store the context binary data, which may + * include weights for compiled subgraphs. + * + * If disabled, the `ep_cache_context` attribute of EPContext nodes will contain the path to the file containing the + * context binary data. The path is set by the execution provider creating the EPContext node. + * + * See \href https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for EPContext details. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. 
+ * \param[in] embed_ep_context_in_model True to embed EPContext binary data into the EPContext node + * `ep_cache_context` attributes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, + bool embed_ep_context_in_model); + + /** \brief Compiles an input ONNX model with the given compilation options. + * + * \param[in] env OrtEnv object. + * \param[in] model_compile_options The compilation options that defines compilation options for a model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); +#endif +}; /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_cxx_api.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_cxx_api.h new file mode 100644 index 000000000..b5d5855a4 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_cxx_api.h @@ -0,0 +1,2743 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Summary: The Ort C++ API is a header only wrapper around the Ort C API. +// +// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors +// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so +// all the resources follow RAII and do not leak memory. 
+// +// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. +// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them +// until you assign an instance that actually holds an underlying object. +// +// For Ort objects only move assignment between objects is allowed, there are no copy constructors. +// Some objects have explicit 'Clone' methods for this purpose. +// +// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments +// by value or by reference. ConstXXXX types are restricted to const only interfaces. +// +// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces. +// +// The lifetime of the corresponding owning object must exceed the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not +// have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
+ +#pragma once +#include "onnxruntime_c_api.h" +#include "onnxruntime_float16.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef ORT_NO_EXCEPTIONS +#include +#endif + +/** \brief All C++ Onnxruntime APIs are defined inside this namespace + * + */ +namespace Ort { + +/** \brief All C++ methods that can fail will throw an exception of this type + * + * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() + */ +struct Exception : std::exception { + Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} + + OrtErrorCode GetOrtErrorCode() const { return code_; } + const char* what() const noexcept override { return message_.c_str(); } + + private: + std::string message_; + OrtErrorCode code_; +}; + +#ifdef ORT_NO_EXCEPTIONS +// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors. +// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW +#ifndef ORT_CXX_API_THROW +#define ORT_CXX_API_THROW(string, code) \ + do { \ + std::cerr << Ort::Exception(string, code) \ + .what() \ + << std::endl; \ + abort(); \ + } while (false) +#endif +#else +#define ORT_CXX_API_THROW(string, code) \ + throw Ort::Exception(string, code) +#endif + +// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, +// it's in a template so that we can define a global variable in a header and make +// it transparent to the users of the API. +template +struct Global { + static const OrtApi* api_; +}; + +// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. 
+template +#ifdef ORT_API_MANUAL_INIT +const OrtApi* Global::api_{}; +inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } + +// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is +// required by C++ APIs. +// +// Example mycustomop.cc: +// +// #define ORT_API_MANUAL_INIT +// #include +// #undef ORT_API_MANUAL_INIT +// +// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { +// Ort::InitApi(api_base->GetApi(ORT_API_VERSION)); +// // ... +// } +// +inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } +#else +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. +// Please define ORT_API_MANUAL_INIT if it concerns you. +#pragma warning(disable : 26426) +#endif +const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif +#endif + +/// This returns a reference to the ORT C API. +inline const OrtApi& GetApi() noexcept { return *Global::api_; } + +/// +/// This function returns the onnxruntime version string +/// +/// version string major.minor.rev +std::string GetVersionString(); + +/// +/// This function returns the onnxruntime build information: including git branch, +/// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags. +/// +/// string +std::string GetBuildInfoString(); + +/// +/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and +/// returns a vector of strings representing the available execution providers. +/// +/// vector of strings +std::vector GetAvailableProviders(); + +/// +/// This returns a reference to the ORT C Model Editor API. Used if building or augmenting a model at runtime.
+/// +/// ORT C Model Editor API reference +inline const OrtModelEditorApi& GetModelEditorApi() { + auto* api = GetApi().GetModelEditorApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Model Editor API is not available in this build", ORT_FAIL); + } + + return *api; +} + +/// +/// This returns a reference to the ORT C Compile API. Used if compiling a model at runtime. +/// +/// ORT C Compile API reference +inline const OrtCompileApi& GetCompileApi() { + auto* api = GetApi().GetCompileApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Compile API is not available in this build", ORT_FAIL); + } + + return *api; +} + +/** \brief IEEE 754 half-precision floating point data type + * + * \details This struct is used for converting float to float16 and back + * so the user could feed inputs and fetch outputs using this type. + * + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. + * + * \code{.unparsed} + * // This example demonstrates conversion from float to float16 + * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; + * std::vector fp16_values; + * fp16_values.reserve(std::size(values)); + * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values), + * [](float value) { return Ort::Float16_t(value); }); + * + * \endcode + */ +struct Float16_t : onnxruntime_float16::Float16Impl { + private: + /// + /// Constructor from a 16-bit representation of a float16 value + /// No conversion is done here. + /// + /// 16-bit representation + constexpr explicit Float16_t(uint16_t v) noexcept { val = v; } + + public: + using Base = onnxruntime_float16::Float16Impl; + + /// + /// Default constructor + /// + Float16_t() = default; + + /// + /// Explicit conversion from uint16_t representation of float16.
+ /// + /// uint16_t bit representation of float16 + /// new instance of Float16_t + constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } + + /// + /// __ctor from float. Float is converted into float16 16-bit representation. + /// + /// float value + explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + /// + /// Checks if the value is negative + /// + /// true if negative + using Base::IsNegative; + + /// + /// Tests if the value is NaN + /// + /// true if NaN + using Base::IsNaN; + + /// + /// Tests if the value is finite + /// + /// true if finite + using Base::IsFinite; + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + using Base::IsPositiveInfinity; + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + using Base::IsNegativeInfinity; + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + using Base::IsInfinity; + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + using Base::IsNaNOrZero; + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + using Base::IsNormal; + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + using Base::IsSubnormal; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + using Base::Abs; + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + using Base::Negate; + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. 
They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + using Base::AreZero; + + /// + /// User defined conversion operator. Converts Float16_t to float. + /// + explicit operator float() const noexcept { return ToFloat(); } + + using Base::operator==; + using Base::operator!=; + using Base::operator<; +}; + +static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); + +/** \brief bfloat16 (Brain Floating Point) data type + * + * \details This struct is used for converting float to bfloat16 and back + * so the user could feed inputs and fetch outputs using these types. + * + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. + * + * \code{.unparsed} + * // This example demonstrates conversion from float to bfloat16 + * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; + * std::vector<Ort::BFloat16_t> bfp16_values; + * bfp16_values.reserve(std::size(values)); + * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values), + * [](float value) { return Ort::BFloat16_t(value); }); + * + * \endcode + */ +struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> { + private: + /// + /// Constructor from a uint16_t representation of bfloat16 + /// used in FromBits() to escape overload resolution issue with + /// constructor from float. + /// No conversion is done. + /// + /// 16-bit bfloat16 value + constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; } + + public: + using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>; + + BFloat16_t() = default; + + /// + /// Creates an instance from its uint16_t bit representation; no numeric conversion is performed. + /// + /// uint16_t bit representation of bfloat16 + /// new instance of BFloat16_t + static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } + + /// + /// Constructs from float.
Float is converted into bfloat16 16-bit representation. + /// + /// float value to convert + explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + // The members below are implemented by Base and re-exported here with using-declarations. + + /// + /// Checks if the value is negative + /// + /// true if negative + using Base::IsNegative; + + /// + /// Tests if the value is NaN + /// + /// true if NaN + using Base::IsNaN; + + /// + /// Tests if the value is finite + /// + /// true if finite + using Base::IsFinite; + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + using Base::IsPositiveInfinity; + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + using Base::IsNegativeInfinity; + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + using Base::IsInfinity; + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + using Base::IsNaNOrZero; + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + using Base::IsNormal; + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + using Base::IsSubnormal; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + using Base::Abs; + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + using Base::Negate; + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero.
+ /// + /// first value + /// second value + /// True if both arguments represent zero + using Base::AreZero; + + /// + /// User defined conversion operator. Converts BFloat16_t to float. + /// + explicit operator float() const noexcept { return ToFloat(); } + + // We do not have an inherited impl for the below operators + // as the internal class implements them a little differently. + // operator== and operator< are only declared here; operator!= is defined as the negation of operator==. + bool operator==(const BFloat16_t& rhs) const noexcept; + bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); } + bool operator<(const BFloat16_t& rhs) const noexcept; +}; + +static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); + +/** \brief float8e4m3fn (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API. + * The type only stores the raw 8-bit encoding and is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details. + */ +struct Float8E4M3FN_t { + uint8_t value; + constexpr Float8E4M3FN_t() noexcept : value(0) {} + constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // Comparison is on the stored byte, so NaN encodings are treated like any other value by operator== and operator!= + constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match"); + +/** \brief float8e4m3fnuz (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details.
+ */ +struct Float8E4M3FNUZ_t { + uint8_t value; + constexpr Float8E4M3FNUZ_t() noexcept : value(0) {} + constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // Comparison is on the stored byte, so NaN encodings are treated like any other value by operator== and operator!= + constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match"); + +/** \brief float8e5m2 (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API. + * The type only stores the raw 8-bit encoding and is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details. + */ +struct Float8E5M2_t { + uint8_t value; + constexpr Float8E5M2_t() noexcept : value(0) {} + constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // Comparison is on the stored byte, so NaN encodings are treated like any other value by operator== and operator!= + constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match"); + +/** \brief float8e5m2fnuz (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details.
+ */ +struct Float8E5M2FNUZ_t { + uint8_t value; + constexpr Float8E5M2FNUZ_t() noexcept : value(0) {} + constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // Comparison is on the stored byte, so NaN encodings are treated like any other value by operator== and operator!= + constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match"); + +namespace detail { +// This is used internally by the C++ API. This macro makes it easy to generate an overloaded OrtRelease() function +// for every Ort* type, each forwarding to the matching Release* function of the C API. +// This can't be done in the C API since C doesn't have function overloading. +#define ORT_DEFINE_RELEASE(NAME) \ + inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } + +// Same as ORT_DEFINE_RELEASE, but the Release* function is taken from the API struct returned by API_GETTER() +// instead of from GetApi(). +#define ORT_DEFINE_RELEASE_FROM_API_STRUCT(NAME, API_GETTER) \ + inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); } + +ORT_DEFINE_RELEASE(Allocator); +ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(CustomOpDomain); +ORT_DEFINE_RELEASE(ThreadingOptions); +ORT_DEFINE_RELEASE(Env); +ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(LoraAdapter); +ORT_DEFINE_RELEASE(Session); +ORT_DEFINE_RELEASE(SessionOptions); +ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); +ORT_DEFINE_RELEASE(SequenceTypeInfo); +ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(TypeInfo); +ORT_DEFINE_RELEASE(Value); +ORT_DEFINE_RELEASE(ModelMetadata); +ORT_DEFINE_RELEASE(IoBinding); +ORT_DEFINE_RELEASE(ArenaCfg); +ORT_DEFINE_RELEASE(Status); +ORT_DEFINE_RELEASE(OpAttr); +ORT_DEFINE_RELEASE(Op); +ORT_DEFINE_RELEASE(KernelInfo); +ORT_DEFINE_RELEASE(ValueInfo); +ORT_DEFINE_RELEASE(Node); +ORT_DEFINE_RELEASE(Graph); +ORT_DEFINE_RELEASE(Model); +#if !defined(ORT_MINIMAL_BUILD)
+ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); +#endif // !defined(ORT_MINIMAL_BUILD) + +// Both helper macros are implementation details of this header; undefine them so they do not leak to includers. +#undef ORT_DEFINE_RELEASE +#undef ORT_DEFINE_RELEASE_FROM_API_STRUCT + +/** \brief This is a tagging template type. Use it with Base to indicate that the C++ interface object + * has no ownership of the underlying C object. + */ +template <typename T> +struct Unowned { + using Type = T; +}; + +/** \brief Used internally by the C++ API. C++ wrapper types inherit from this. + * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. + * + * All of the C++ classes + * a) serve as containers for pointers to objects that are created by the underlying C API. + * Their size is just a pointer size, no need to dynamically allocate them. Use them by value. + * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects. + * they would release objects owned automatically when going out of scope, they are move-only. + * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers. + * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else + * such as Onnxruntime or instances of XXXX classes. + * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used + * in C++ code. + * + */ + +/// +/// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
+/// +template +struct Base { + using contained_type = T; + + constexpr Base() = default; + constexpr explicit Base(contained_type* p) noexcept : p_{p} {} + ~Base() { + OrtRelease(p_); + } + + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + Base& operator=(Base&& v) noexcept { + OrtRelease(p_); + p_ = v.release(); + return *this; + } + + constexpr operator contained_type*() const noexcept { return p_; } + + /// \brief Relinquishes ownership of the contained C object pointer + /// The underlying object is not destroyed + contained_type* release() { + T* p = p_; + p_ = nullptr; + return p; + } + + protected: + contained_type* p_{}; +}; + +// Undefined. For const types use Base> +template +struct Base; + +/// +/// Covers unowned pointers owned by either the ORT +/// or some other instance of CPP wrappers. +/// Used for ConstXXX and UnownedXXXX types that are copyable. +/// Also convenient to wrap raw OrtXX pointers . 
+/// +/// +template <typename T> +struct Base<Unowned<T>> { + using contained_type = typename Unowned<T>::Type; + + constexpr Base() = default; + constexpr explicit Base(contained_type* p) noexcept : p_{p} {} + + // Nothing is released on destruction: the pointer is not owned. + ~Base() = default; + + Base(const Base&) = default; + Base& operator=(const Base&) = default; + + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + Base& operator=(Base&& v) noexcept { + p_ = nullptr; + std::swap(p_, v.p_); + return *this; + } + + constexpr operator contained_type*() const noexcept { return p_; } + + protected: + contained_type* p_{}; +}; + +// Light functor to release memory with OrtAllocator. A null pointer is ignored. +struct AllocatedFree { + OrtAllocator* allocator_; + explicit AllocatedFree(OrtAllocator* allocator) + : allocator_(allocator) {} + void operator()(void* ptr) const { + if (ptr) allocator_->Free(allocator_, ptr); + } +}; + +} // namespace detail + +struct AllocatorWithDefaultOptions; +struct Env; +struct Graph; +struct Model; +struct Node; +struct ModelMetadata; +struct TypeInfo; +struct Value; +struct ValueInfo; + +/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators + * and release them at the end of the scope. The given allocator + * must outlive the AllocatedStringPtr instance. + */ +using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>; + +/** \brief The Status that holds ownership of OrtStatus received from C API + * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate + * constructors to construct an instance of a Status object from exceptions. + */ +struct Status : detail::Base<OrtStatus> { + using Base = detail::Base<OrtStatus>; + using Base::Base; + + explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used + explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
+ explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception + explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception + Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. + std::string GetErrorMessage() const; + OrtErrorCode GetErrorCode() const; + bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. +}; + +/** \brief The ThreadingOptions + * + * The ThreadingOptions used for set global threadpools' options of The Env. + */ +struct ThreadingOptions : detail::Base { + /// \brief Wraps OrtApi::CreateThreadingOptions + ThreadingOptions(); + + /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads + ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads); + + /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads + ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads); + + /// \brief Wraps OrtApi::SetGlobalSpinControl + ThreadingOptions& SetGlobalSpinControl(int allow_spinning); + + /// \brief Wraps OrtApi::SetGlobalDenormalAsZero + ThreadingOptions& SetGlobalDenormalAsZero(); + + /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn + ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions + ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options); + + /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn + ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); +}; + +/** \brief The Env (Environment) + * + * The Env holds the logging state used by all other objects. 
+ * Note: One Env must be created before using any other Onnxruntime functionality + */ +struct Env : detail::Base<OrtEnv> { + explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used + + /// \brief Wraps OrtApi::CreateEnv + Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + + /// \brief Wraps OrtApi::CreateEnvWithCustomLogger + Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); + + /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools + Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + + /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools + Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, + OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + + /// \brief C Interop Helper + explicit Env(OrtEnv* p) : Base{p} {} + + Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents + Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents + + Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel + + Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator + + Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 +}; + +/** \brief Custom Op Domain + * + */ +struct CustomOpDomain : detail::Base<OrtCustomOpDomain> { + using Base = detail::Base<OrtCustomOpDomain>; + using Base::Base; + + explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used + + /// \brief Wraps
OrtApi::CreateCustomOpDomain + explicit CustomOpDomain(const char* domain); + + // This does not take ownership of the op, simply registers it. + void Add(const OrtCustomOp* op); ///< Wraps OrtApi::CustomOpDomain_Add +}; + +/// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file +struct LoraAdapter : detail::Base<OrtLoraAdapter> { + using Base = detail::Base<OrtLoraAdapter>; + using Base::Base; + + explicit LoraAdapter(std::nullptr_t) {} ///< Create an empty LoraAdapter object, must be assigned a valid one to be used + /// \brief Wraps OrtApi::CreateLoraAdapter + /// + /// The function attempts to load the adapter from the specified file + /// \param adapter_path The path to the Lora adapter + /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still + /// be copied to device if required by the model at inference time. + static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path, + OrtAllocator* allocator); + + /// \brief Wraps OrtApi::CreateLoraAdapterFromArray + /// + /// The function attempts to load the adapter from the specified byte array. + /// \param bytes The byte array containing the LoraAdapter file format + /// \param num_bytes The number of bytes in the byte array + /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still + /// be copied to device if required by the model at inference time.
+ static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes, + OrtAllocator* allocator); +}; + +/** \brief RunOptions + * + */ +struct RunOptions : detail::Base { + explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used + RunOptions(); ///< Wraps OrtApi::CreateRunOptions + + RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel + int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel + + RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel + int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel + + RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag + const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag + + RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry + + /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance + * + * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error + * Wraps OrtApi::RunOptionsSetTerminate + */ + RunOptions& SetTerminate(); + + /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating + * + * Wraps OrtApi::RunOptionsUnsetTerminate + */ + RunOptions& UnsetTerminate(); + + /** \brief Add the LoraAdapter to the list of active adapters. + * The setting does not affect RunWithBinding() calls. 
+ * + * Wraps OrtApi::RunOptionsAddActiveLoraAdapter + * \param adapter The LoraAdapter to be used as the active adapter + */ + RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); +}; + +namespace detail { +// Utility function that returns a SessionOption config entry key for a specific custom operator. +// Ex: custom_op.[custom_op_name].[config] +std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); +} // namespace detail + +/// +/// Class that represents session configuration entries for one or more custom operators. +/// +/// Example: +/// Ort::CustomOpConfigs op_configs; +/// op_configs.AddConfig("my_custom_op", "device_type", "CPU"); +/// +/// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary. +/// +struct CustomOpConfigs { + CustomOpConfigs() = default; + ~CustomOpConfigs() = default; + CustomOpConfigs(const CustomOpConfigs&) = default; + CustomOpConfigs& operator=(const CustomOpConfigs&) = default; + CustomOpConfigs(CustomOpConfigs&& o) = default; + CustomOpConfigs& operator=(CustomOpConfigs&& o) = default; + + /** \brief Adds a session configuration entry/value for a specific custom operator. + * + * \param custom_op_name The name of the custom operator for which to add a configuration entry. + * Must match the name returned by the CustomOp's GetName() method. + * \param config_key The name of the configuration entry. + * \param config_value The value of the configuration entry. + * \return A reference to this object to enable call chaining. + */ + CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value); + + /** \brief Returns a flattened map of custom operator configuration entries and their values. + * + * The keys have been flattened to include both the custom operator name and the configuration entry key name. + * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair + * {"my_op.key", "value"}.
+ * + * \return An unordered map of flattened configurations. + */ + const std::unordered_map& GetFlattenedConfigs() const; + + private: + std::unordered_map flat_configs_; +}; + +/** \brief Options object used when creating a new Session object + * + * Wraps ::OrtSessionOptions object and methods + */ + +struct SessionOptions; + +namespace detail { +// we separate const-only methods because passing const ptr to non-const methods +// is only discovered when inline methods are compiled which is counter-intuitive +template +struct ConstSessionOptionsImpl : Base { + using B = Base; + using B::B; + + SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions + + std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry + bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry + std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); +}; + +template +struct SessionOptionsImpl : ConstSessionOptionsImpl { + using B = ConstSessionOptionsImpl; + using B::B; + + SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads + SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads + SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel + SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute + + SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena + SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena + + SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath + + SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps 
OrtApi::EnableProfiling + SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling + + SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps + + SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern + SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern + + SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode + + SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag + + SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId + SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel + + SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain + + SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads + + SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry + + SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer + SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers + SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names, + const std::vector<char*>& external_initializer_file_buffer_array, + const std::vector<size_t>& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory + + SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA + SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 +
SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM + SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2 + SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {}); + SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2 + SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN + SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl + SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
+ SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, + const std::unordered_map<std::string, std::string>& provider_options = {}); + + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn + SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions + SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn + + /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2. + /// The custom operator configurations are optional. If provided, custom operator configs are set via + /// OrtApi::AddSessionConfigEntry. + SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); + + SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {}); +}; +} // namespace detail + +using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>; +using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>; + +/** \brief Wrapper around ::OrtSessionOptions + * + */ +struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> { + explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used + SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions + explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl{p} {} ///< Used for interop with the C API + UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } + ConstSessionOptions GetConst() const { return
ConstSessionOptions{this->p_}; } +}; + +#if !defined(ORT_MINIMAL_BUILD) +/** \brief Options object used when compiling a model. + * + * Wraps ::OrtModelCompilationOptions object and methods + * + * NOTE(review): the "Wraps OrtApi::..." tags below presumably refer to functions reached through GetCompileApi() - confirm. + */ +struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> { + using Base = detail::Base<OrtModelCompilationOptions>; + using Base::Base; + + explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used. + explicit ModelCompilationOptions(OrtModelCompilationOptions* p) ///< Takes ownership of an OrtModelCompilationOptions + : detail::Base<OrtModelCompilationOptions>{p} {} + + ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions + ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions + + ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath + ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data, + size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer + ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode + ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath + ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, + size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, + size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer +}; + +/** \brief Compiles an input model to generate a model with
EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. + * + * \param env: ORT environment object. + * \param model_compilation_options: Compilation options for a model. + * \return A Status indicating success or failure. + */ +Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); +#endif // !defined(ORT_MINIMAL_BUILD) + +/** \brief Wrapper around ::OrtModelMetadata + * + */ +struct ModelMetadata : detail::Base<OrtModelMetadata> { + using Base = detail::Base<OrtModelMetadata>; + using Base::Base; + + explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used + + /** \brief Returns a copy of the producer name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return an instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName + + /** \brief Returns a copy of the graph name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return an instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName + + /** \brief Returns a copy of the domain name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return an instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain + + /** \brief Returns a copy of the description.
+ * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription + + /** \brief Returns a copy of the graph description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription + + /** \brief Returns a vector of copies of the custom metadata keys. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope. + * The OrtAllocator instance must be valid at the point of memory release. + */ + std::vector GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys + + /** \brief Looks up a value by a key in the Custom Metadata map + * + * \param key zero terminated string key to lookup + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * maybe nullptr if key is not found. + * + * The OrtAllocator instances must be valid at the point of memory release. 
+ */ + AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap + + int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion +}; + +struct IoBinding; + +namespace detail { + +// we separate const-only methods because passing const ptr to non-const methods +// is only discovered when inline methods are compiled which is counter-intuitive +template +struct ConstSessionImpl : Base { + using B = Base; + using B::B; + + size_t GetInputCount() const; ///< Returns the number of model inputs + size_t GetOutputCount() const; ///< Returns the number of model outputs + size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden + + std::vector GetInputNames() const; + std::vector GetOutputNames() const; + std::vector GetOverridableInitializerNames() const; + + /** \brief Returns a copy of input name at the specified index. + * + * \param index must less than the value returned by GetInputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \brief Returns a copy of output name at then specified index. + * + * \param index must less than the value returned by GetOutputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \brief Returns a copy of the overridable initializer name at then specified index. 
+ * + * \param index must less than the value returned by GetOverridableInitializerCount() + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName + + uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs + ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata + + TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo + TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo + TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo + + int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain + + // Will move before checkin if that's the case. + std::vector GetInputs() const; + std::vector GetOutputs() const; +}; + +template +struct SessionImpl : ConstSessionImpl { + using B = ConstSessionImpl; + using B::B; + + /** \brief Run the model returning results in an Ort allocated vector. + * + * Wraps OrtApi::Run + * + * The caller provides a list of inputs and a list of the desired outputs to return. + * + * See the output logs for more information on warnings/errors that occur while processing the model. + * Common errors are.. 
(TODO) + * + * \param[in] run_options + * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names + * \param[in] input_values Array of Value objects of length input_count that is the list of input values + * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) + * \param[in] output_names Array of C style strings of length output_count that is the list of output names + * \param[in] output_count Number of outputs (the size of the output_names array) + * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) + */ + std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count); + + /** \brief Run the model returning results in user provided outputs + * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) + */ + void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count); + + void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding + + /** \brief Run the model asynchronously in a thread owned by intra op thread pool + * + * Wraps OrtApi::RunAsync + * + * \param[in] run_options + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] input_values Array of Value objects of length input_count + * \param[in] input_count Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[out] output_values Array of provided Values to be filled with outputs. 
+ * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. + * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. + * Then, an OrtValue** pointer will be casted from output_values, and pass to the callback. + * NOTE: it is customer's duty to finally release output_values and each of its member, + * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. + * \param[in] output_count Number of elements in the output_names and outputs array + * \param[in] callback Callback function on model run completion + * \param[in] user_data User data that pass back to the callback + */ + void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data); + + /** \brief End profiling and return a copy of the profiling file name. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. 
+ */ + AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Wraps OrtApi::SetEpDynamicOptions + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + */ + void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); + + void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); +}; + +} // namespace detail + +using ConstSession = detail::ConstSessionImpl>; +using UnownedSession = detail::SessionImpl>; + +/** \brief Wrapper around ::OrtSession + * + */ +struct Session : detail::SessionImpl { + /// Create an empty Session object, must be assigned a valid one to be used. 
Wraps OrtApi::CreateSession + explicit Session(std::nullptr_t) {} + explicit Session(OrtSession* p) : SessionImpl{p} {} ///< C API Interop + + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); + + /// Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); + +#if !defined(ORT_MINIMAL_BUILD) + /// Wraps OrtModelEditorApi::CreateSessionFromModel + Session(const Env& env, const Model& model, const SessionOptions& options); + + /// Wraps OrtModelEditorApi::CreateModelEditorSession + static Session CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtModelEditorApi::CreateModelEditorSession + static Session CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options); +#endif // !defined(ORT_MINIMAL_BUILD) + + ConstSession GetConst() const { return ConstSession{this->p_}; } + UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } +}; + +namespace detail { +template +struct MemoryInfoImpl : Base { + using B = Base; + using B::B; + + std::string GetAllocatorName() const; + OrtAllocatorType GetAllocatorType() const; + int GetDeviceId() const; + OrtMemoryInfoDeviceType GetDeviceType() const; + OrtMemType GetMemoryType() const; + + template + bool operator==(const MemoryInfoImpl& o) const; +}; +} // namespace detail + +// Const object holder that does not own the 
underlying object +using ConstMemoryInfo = detail::MemoryInfoImpl>; + +/** \brief Wrapper around ::OrtMemoryInfo + * + */ +struct MemoryInfo : detail::MemoryInfoImpl { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); + ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } +}; + +namespace detail { +template +struct TensorTypeAndShapeInfoImpl : Base { + using B = Base; + using B::B; + + ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType + size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount + + size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount + + /** \deprecated use GetShape() returning std::vector + * [[deprecated]] + * This interface is unsafe to use + */ + [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions + + void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + std::vector GetSymbolicDimensions() const; + + std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape +}; + +} // namespace detail + +using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl>; + +/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo + * + */ +struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl { + using Base = detail::TensorTypeAndShapeInfoImpl; + using Base::Base; + + /// Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} + /// Used for interop with the C API + explicit 
TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} + + // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions + // symbolic_dims are optional, but should be 1:1 with dims. + // The value in symbolic_dims will be used for all entries in dims that are -1. + explicit TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims = nullptr); + + ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } +}; + +namespace detail { +template +struct SequenceTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType +}; + +} // namespace detail + +using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl>; + +/** \brief Wrapper around ::OrtSequenceTypeInfo + * + */ +struct SequenceTypeInfo : detail::SequenceTypeInfoImpl { + using Base = detail::SequenceTypeInfoImpl; + using Base::Base; + + explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used + explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API + ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } +}; + +namespace detail { +template +struct OptionalTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo +}; + +} // namespace detail + +// This is always owned by the TypeInfo and can only be obtained from it. 
+using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl>; + +namespace detail { +template +struct MapTypeInfoImpl : detail::Base { + using B = Base; + using B::B; + ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType + TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType +}; + +} // namespace detail + +using ConstMapTypeInfo = detail::MapTypeInfoImpl>; + +/** \brief Wrapper around ::OrtMapTypeInfo + * + */ +struct MapTypeInfo : detail::MapTypeInfoImpl { + using Base = detail::MapTypeInfoImpl; + using Base::Base; + + explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used + explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API + ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } +}; + +namespace detail { +template +struct TypeInfoImpl : detail::Base { + using B = Base; + using B::B; + + ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo + ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo + ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo + ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo + + ONNXType GetONNXType() const; +}; +} // namespace detail + +/// +/// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. +/// Provides access to const OrtTypeInfo APIs. +/// +using ConstTypeInfo = detail::TypeInfoImpl>; + +/// +/// Type information that may contain either TensorTypeAndShapeInfo or +/// the information about contained sequence or map depending on the ONNXType. 
+/// +struct TypeInfo : detail::TypeInfoImpl { + using Base = detail::TypeInfoImpl; + using Base::Base; + + /// Create an empty TypeInfo object, must be assigned a valid one to be used + explicit TypeInfo(std::nullptr_t) {} + explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop + +#if !defined(ORT_MINIMAL_BUILD) + static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info); + static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info); + static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type); + static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type); + static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type); +#endif // !defined(ORT_MINIMAL_BUILD) + + ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } +}; + +namespace detail { +// This structure is used to feed sparse tensor values +// information for use with FillSparseTensor() API +// if the data type for the sparse tensor values is numeric +// use data.p_data, otherwise, use data.str pointer to feed +// values. data.str is an array of const char* that are zero terminated. +// number of strings in the array must match shape size. +// For fully sparse tensors use shape {0} and set p_data/str +// to nullptr. 
+struct OrtSparseValuesParam { + const int64_t* values_shape; + size_t values_shape_len; + union { + const void* p_data; + const char** str; + } data; +}; + +// Provides a way to pass shape in a single +// argument +struct Shape { + const int64_t* shape; + size_t shape_len; +}; + +template +struct ConstValueImpl : Base { + using B = Base; + using B::B; + + /// + /// Obtains a pointer to a user defined data for experimental purposes + /// + template + void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue + + bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc + bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None + + size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements + Value GetValue(int index, OrtAllocator* allocator) const; + + /// + /// This API returns a full length of string data contained within either a tensor or a sparse Tensor. + /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful + /// for allocating necessary memory and calling GetStringTensorContent(). + /// + /// total length of UTF-8 encoded bytes contained. No zero terminators counted. + size_t GetStringTensorDataLength() const; + + /// + /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor + /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. + /// The user must also allocate offsets buffer with the number of entries equal to that of the contained + /// strings. + /// + /// Strings are always assumed to be on CPU, no X-device copy. 
+ /// + /// user allocated buffer + /// length in bytes of the allocated buffer + /// a pointer to the offsets user allocated buffer + /// count of offsets, must be equal to the number of strings contained. + /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() + /// for sparse tensors + void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; + + /// + /// Returns a const typed pointer to the tensor contained data. + /// No type checking is performed, the caller must ensure the type matches the tensor type. + /// + /// + /// const pointer to data, no copies made + template + const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// + + /// + /// Returns a non-typed pointer to a tensor contained data. + /// + /// const pointer to data, no copies made + const void* GetTensorRawData() const; + + /// + /// The API returns type information for data contained in a tensor. For sparse + /// tensors it returns type information for contained non-zero values. + /// It returns dense shape for sparse tensors. + /// + /// TypeInfo + TypeInfo GetTypeInfo() const; + + /// + /// The API returns type information for data contained in a tensor. For sparse + /// tensors it returns type information for contained non-zero values. + /// It returns dense shape for sparse tensors. + /// + /// TensorTypeAndShapeInfo + TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; + + /// + /// This API returns information about the memory allocation used to hold data. + /// + /// Non owning instance of MemoryInfo + ConstMemoryInfo GetTensorMemoryInfo() const; + + /// + /// The API copies UTF-8 encoded bytes for the requested string element + /// contained within a tensor or a sparse tensor into a provided buffer. + /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. 
+ /// + /// + /// + /// + void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; + + /// + /// Returns string tensor UTF-8 encoded string element. + /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer. + /// + /// + /// std::string + std::string GetStringTensorElement(size_t element_index) const; + + /// + /// The API returns a byte length of UTF-8 encoded string element + /// contained in either a tensor or a spare tensor values. + /// + /// + /// byte length for the specified string element + size_t GetStringTensorElementLength(size_t element_index) const; + +#if !defined(DISABLE_SPARSE_TENSORS) + /// + /// The API returns the sparse data format this OrtValue holds in a sparse tensor. + /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used + /// the value returned is ORT_SPARSE_UNDEFINED. + /// + /// Format enum + OrtSparseFormat GetSparseFormat() const; + + /// + /// The API returns type and shape information for stored non-zero values of the + /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. + /// + /// TensorTypeAndShapeInfo values information + TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; + + /// + /// The API returns type and shape information for the specified indices. Each supported + /// indices have their own enum values even if a give format has more than one kind of indices. + /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. + /// + /// enum requested + /// type and shape information + TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; + + /// + /// The API retrieves a pointer to the internal indices buffer. The API merely performs + /// a convenience data type casting on the return type pointer. 
Make sure you are requesting + /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); + /// + /// type to cast to + /// requested indices kind + /// number of indices entries + /// Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer. + template + const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; + + /// + /// Returns true if the OrtValue contains a sparse tensor + /// + /// + bool IsSparseTensor() const; + + /// + /// The API returns a pointer to an internal buffer of the sparse tensor + /// containing non-zero values. The API merely does casting. Make sure you + /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() + /// first. + /// + /// numeric data types only. Use GetStringTensor*() to retrieve strings. + /// a pointer to the internal values buffer. Do not free this pointer. + template + const R* GetSparseTensorValues() const; + +#endif +}; + +template +struct ValueImpl : ConstValueImpl { + using B = ConstValueImpl; + using B::B; + + /// + /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer + /// No type checking is performed, the caller must ensure the type matches the tensor type. + /// + /// non-const pointer to data, no copies made + template + R* GetTensorMutableData(); + + /// + /// Returns a non-typed non-const pointer to a tensor contained data. + /// + /// pointer to data, no copies made + void* GetTensorMutableRawData(); + + /// + // Obtain a reference to an element of data at the location specified + /// by the vector of dims. + /// + /// + /// [in] expressed by a vecotr of dimensions offsets + /// + template + R& At(const std::vector& location); + + /// + /// Set all strings at once in a string tensor + /// + /// [in] An array of strings. Each string in this array must be null terminated. 
+ /// [in] Count of strings in s (Must match the size of \p value's tensor shape) + void FillStringTensor(const char* const* s, size_t s_len); + + /// + /// Set a single string in a string tensor + /// + /// [in] A null terminated UTF-8 encoded string + /// [in] Index of the string in the tensor to set + void FillStringTensorElement(const char* s, size_t index); + + /// + /// Allocate if necessary and obtain a pointer to a UTF-8 + /// encoded string element buffer indexed by the flat element index, + /// of the specified length. + /// + /// This API is for advanced usage. It avoids a need to construct + /// an auxiliary array of string pointers, and allows to write data directly + /// (do not zero terminate). + /// + /// + /// + /// a pointer to a writable buffer + char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length); + +#if !defined(DISABLE_SPARSE_TENSORS) + /// + /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. + /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user + /// allocated buffers lifespan must eclipse that of the OrtValue. + /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. + /// + /// pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors. + /// number of indices entries. Use 0 for fully sparse tensors + void UseCooIndices(int64_t* indices_data, size_t indices_num); + + /// + /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor. + /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user + /// allocated buffers lifespan must eclipse that of the OrtValue. + /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. 
+ /// + /// pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors + /// number of csr inner indices or 0 for fully sparse tensors + /// pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors + /// number of csr outer indices or 0 for fully sparse tensors + void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num); + + /// + /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor. + /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user + /// allocated buffers lifespan must eclipse that of the OrtValue. + /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. + /// + /// indices shape or a {0} for fully sparse + /// user allocated buffer with indices or nullptr for fully spare tensors + void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); + + /// + /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API + /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located + /// at difference device than the allocator, a X-device copy will be performed if possible. + /// + /// specified buffer memory description + /// values buffer information. + /// coo indices buffer or nullptr for fully sparse data + /// number of COO indices or 0 for fully sparse data + void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param, + const int64_t* indices_data, size_t indices_num); + + /// + /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API + /// and copy the values and CSR indices into it. 
If data_mem_info specifies that the data is located + /// at difference device than the allocator, a X-device copy will be performed if possible. + /// + /// specified buffer memory description + /// values buffer information + /// csr inner indices pointer or nullptr for fully sparse tensors + /// number of csr inner indices or 0 for fully sparse tensors + /// pointer to csr indices data or nullptr for fully sparse tensors + /// number of csr outer indices or 0 + void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const int64_t* inner_indices_data, size_t inner_indices_num, + const int64_t* outer_indices_data, size_t outer_indices_num); + + /// + /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API + /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located + /// at difference device than the allocator, a X-device copy will be performed if possible. + /// + /// specified buffer memory description + /// values buffer information + /// indices shape. 
use {0} for fully sparse tensors + /// pointer to indices data or nullptr for fully sparse tensors + void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const Shape& indices_shape, + const int32_t* indices_data); + +#endif +}; + +} // namespace detail + +using ConstValue = detail::ConstValueImpl>; +using UnownedValue = detail::ValueImpl>; + +/** \brief Wrapper around ::OrtValue + * + */ +struct Value : detail::ValueImpl { + using Base = detail::ValueImpl; + using Base::Base; + using OrtSparseValuesParam = detail::OrtSparseValuesParam; + using Shape = detail::Shape; + + explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used + Value(Value&&) = default; + Value& operator=(Value&&) = default; + + ConstValue GetConst() const { return ConstValue{this->p_}; } + UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. + * \tparam T The numeric datatype. This API is not suitable for strings. + * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param p_data Pointer to the data buffer. + * \param p_data_element_count The number of elements in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + */ + template + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len); + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. + * + * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. 
+ * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue. + * + * \param deleter OrtAllocator that will be used to free the buffer when no longer required. + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. + * This overload will allocate the buffer for the tensor according to the supplied shape and data type. + * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. + * The input data would need to be copied into the allocated buffer. + * This API is not suitable for strings. + * + * \tparam T The numeric datatype. This API is not suitable for strings. + * \param allocator The allocator to use. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + */ + template + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); + + /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator. + * Wraps OrtApi::CreateTensorAsOrtValue. + * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. 
+ * The input data would need to be copied into the allocated buffer. + * This API is not suitable for strings. + * + * \param allocator The allocator to use. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates an OrtValue with a Map Onnx type representation. + * The API would ref-count the supplied OrtValues and they will be released + * when the returned OrtValue is released. The caller may release keys and values after the call + * returns. + * + * \param keys an OrtValue containing a tensor with primitive data type keys. + * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values. + */ + static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue + + /** \brief Creates an OrtValue with a Sequence Onnx type representation. + * The API would ref-count the supplied OrtValues and they will be released + * when the returned OrtValue is released. The caller may release the values after the call + * returns. + * + * \param values a vector of OrtValues that must have the same Onnx value type. + */ + static Value CreateSequence(const std::vector& values); ///< Wraps OrtApi::CreateValue + + /** \brief Creates an OrtValue wrapping an Opaque type. + * This is used for experimental support of non-tensor types. + * + * \tparam T - the type of the value. + * \param domain - zero terminated utf-8 string. Domain of the type. + * \param type_name - zero terminated utf-8 string. Name of the type. + * \param value - the value to be wrapped. 
+ */ + template + static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue + +#if !defined(DISABLE_SPARSE_TENSORS) + /// + /// This is a simple forwarding method to the other overload that helps deducing + /// data type enum value from the type of the buffer. + /// + /// numeric datatype. This API is not suitable for strings. + /// Memory description where the user buffers reside (CPU vs GPU etc) + /// pointer to the user supplied buffer, use nullptr for fully sparse tensors + /// a would be dense shape of the tensor + /// non zero values shape. Use a single 0 shape for fully sparse tensors. + /// + template + static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, + const Shape& values_shape); + + /// + /// Creates an OrtValue instance containing SparseTensor. This constructs + /// a sparse tensor that makes use of user allocated buffers. It does not make copies + /// of the user provided data and does not modify it. The lifespan of user provided buffers should + /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain + /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below + /// to supply a sparse format specific indices. + /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings + /// can be properly copied into the allocated buffer. + /// + /// Memory description where the user buffers reside (CPU vs GPU etc) + /// pointer to the user supplied buffer, use nullptr for fully sparse tensors + /// a would be dense shape of the tensor + /// non zero values shape. Use a single 0 shape for fully sparse tensors. 
+ /// data type + /// Ort::Value instance containing SparseTensor + static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, + const Shape& values_shape, ONNXTensorElementDataType type); + + /// + /// This is a simple forwarding method to the below CreateSparseTensor. + /// This helps to specify data type enum in terms of C++ data type. + /// Use CreateSparseTensor + /// + /// numeric data type only. String data enum must be specified explicitly. + /// allocator to use + /// a would be dense shape of the tensor + /// Ort::Value + template + static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); + + /// + /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. + /// The data must be supplied by one of the FillSparseTensor() methods that take both non-zero values + /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. + /// Use this API to create OrtValues that contain sparse tensors with all supported data types including + /// strings. + /// + /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue + /// a would be dense shape of the tensor + /// data type + /// an instance of Ort::Value + static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); + +#endif // !defined(DISABLE_SPARSE_TENSORS) +}; + +/// +/// Represents native memory allocation coming from one of the +/// OrtAllocators registered with OnnxRuntime. +/// Use it to wrap an allocation made by an allocator +/// so it can be automatically released when no longer needed.
+/// +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&) noexcept; + MemoryAllocation& operator=(MemoryAllocation&&) noexcept; + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +namespace detail { +template +struct AllocatorImpl : Base { + using B = Base; + using B::B; + + void* Alloc(size_t size); + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + ConstMemoryInfo GetInfo() const; +}; + +} // namespace detail + +/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime + * + */ +struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { + explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + AllocatorWithDefaultOptions(); +}; + +/** \brief Wrapper around ::OrtAllocator + * + */ +struct Allocator : detail::AllocatorImpl { + explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + Allocator(const Session& session, const OrtMemoryInfo*); +}; + +using UnownedAllocator = detail::AllocatorImpl>; + +namespace detail { +namespace binding_utils { +// Bring these out of template +std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*); +std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*); +} // namespace binding_utils + +template +struct ConstIoBindingImpl : Base { + using B = Base; + using B::B; + + std::vector GetOutputNames() const; + std::vector GetOutputNames(OrtAllocator*) const; + std::vector GetOutputValues() const; + std::vector GetOutputValues(OrtAllocator*) const; +}; + +template +struct IoBindingImpl : 
ConstIoBindingImpl { + using B = ConstIoBindingImpl; + using B::B; + + void BindInput(const char* name, const Value&); + void BindOutput(const char* name, const Value&); + void BindOutput(const char* name, const OrtMemoryInfo*); + void ClearBoundInputs(); + void ClearBoundOutputs(); + void SynchronizeInputs(); + void SynchronizeOutputs(); +}; + +} // namespace detail + +using ConstIoBinding = detail::ConstIoBindingImpl>; +using UnownedIoBinding = detail::IoBindingImpl>; + +/** \brief Wrapper around ::OrtIoBinding + * + */ +struct IoBinding : detail::IoBindingImpl { + explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later. + explicit IoBinding(Session& session); + ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } + UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } +}; + +/*! \struct Ort::ArenaCfg + * \brief it is a structure that represents the configuration of an arena based allocator + * \details Please see docs/C_API.md for details + */ +struct ArenaCfg : detail::Base { + explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used + /** + * Wraps OrtApi::CreateArenaCfg + * \param max_mem - use 0 to allow ORT to choose the default + * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default + * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default + * See docs/C_API.md for details on what the following parameters mean and how to choose these values + */ + ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); +}; + +// +// Custom OPs (only needed to implement custom OPs) +// + +/// +/// This struct provides life time management for custom op attribute +/// +struct OpAttr : 
detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit OpAttr(std::nullptr_t) {} + OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); +}; + +/** + * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails. + * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param message A null-terminated UTF-8 message to log. + */ +#define ORT_CXX_LOG(logger, message_severity, message) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), message)); \ + } \ + } while (false) + +/** + * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored. + * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param message A null-terminated UTF-8 message to log. + */ +#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + static_cast(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), message)); \ + } \ + } while (false) + +/** + * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if + * OrtApi::Logger_LogMessage fails or if a formatting error occurs. + * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. 
+ * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. + * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. + * \param ... Zero or more variadic arguments referenced by the format string. + */ +#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), __VA_ARGS__)); \ + } \ + } while (false) + +/** + * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors + * are silently ignored. + * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. + * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. + * \param ... Zero or more variadic arguments referenced by the format string. + */ +#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + static_cast(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), __VA_ARGS__)); \ + } \ + } while (false) + +/// +/// This class represents an ONNX Runtime logger that can be used to log information with an +/// associated severity level and source code location (file path, line number, function name). +/// +/// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger(). +/// Instances of Ort::Logger are the size of two pointers and can be passed by value. 
+/// +/// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite +/// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API. +/// +struct Logger { + /** + * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. + */ + Logger() = default; + + /** + * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. + */ + explicit Logger(std::nullptr_t) {} + + /** + * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling + * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails. + * + * \param logger The ::OrtLogger to wrap. + */ + explicit Logger(const OrtLogger* logger); + + ~Logger() = default; + + Logger(const Logger&) = default; + Logger& operator=(const Logger&) = default; + + Logger(Logger&& v) noexcept = default; + Logger& operator=(Logger&& v) noexcept = default; + + /** + * Returns the logger's current severity level from the cached member. + * + * \return The current ::OrtLoggingLevel. + */ + OrtLoggingLevel GetLoggingSeverityLevel() const noexcept; + + /** + * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT + * macros to properly set the source code location and to use the cached severity level to potentially bypass + * calls to the underlying C API. + * + * \param log_severity_level The message's logging severity level. + * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. + * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. + * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. + * \param message The message to log. + * \return A Ort::Status value to indicate error or success. 
+ */ + Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, + const char* func_name, const char* message) const noexcept; + + /** + * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT + * macros to properly set the source code location and to use the cached severity level to potentially bypass + * calls to the underlying C API. Returns an error status if a formatting error occurs. + * + * \param log_severity_level The message's logging severity level. + * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. + * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. + * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. + * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. + * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. + * \param args Zero or more variadic arguments referenced by the format string. + * \return A Ort::Status value to indicate error or success. + */ + template + Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, + const char* func_name, const char* format, Args&&... args) const noexcept; + + private: + const OrtLogger* logger_{}; + OrtLoggingLevel cached_severity_level_{}; +}; + +/// +/// This class wraps a raw pointer OrtKernelContext* that is being passed +/// to the custom kernel Compute() method. Use it to safely access context +/// attributes, input and output parameters with exception safety guarantees. 
+/// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +/// +struct KernelContext { + explicit KernelContext(OrtKernelContext* context); + size_t GetInputCount() const; + size_t GetOutputCount() const; + // If input is optional and is not present, the method returns an empty ConstValue + // which can be compared to nullptr. + ConstValue GetInput(size_t index) const; + // If output is optional and is not present, the method returns an empty UnownedValue + // which can be compared to nullptr. + UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; + UnownedValue GetOutput(size_t index, const std::vector& dims) const; + void* GetGPUComputeStream() const; + Logger GetLogger() const; + OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; + OrtKernelContext* GetOrtKernelContext() const { return ctx_; } + void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; + + private: + OrtKernelContext* ctx_; +}; + +struct KernelInfo; + +namespace detail { +namespace attr_utils { +void GetAttr(const OrtKernelInfo* p, const char* name, float&); +void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); +void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +} // namespace attr_utils + +template +struct KernelInfoImpl : Base { + using B = Base; + using B::B; + + KernelInfo Copy() const; + + template // R is only implemented for float, int64_t, and string + R GetAttribute(const char* name) const { + R val; + attr_utils::GetAttr(this->p_, name, val); + return val; + } + + template // R is only implemented for std::vector, std::vector + std::vector GetAttributes(const char* name) const { + std::vector result; + attr_utils::GetAttrs(this->p_, name, result); + return result; + } + + Value 
GetTensorAttribute(const char* name, OrtAllocator* allocator) const; + + size_t GetInputCount() const; + size_t GetOutputCount() const; + + std::string GetInputName(size_t index) const; + std::string GetOutputName(size_t index) const; + + TypeInfo GetInputTypeInfo(size_t index) const; + TypeInfo GetOutputTypeInfo(size_t index) const; + + ConstValue GetTensorConstantInput(size_t index, int* is_constant) const; + + std::string GetNodeName() const; + Logger GetLogger() const; +}; + +} // namespace detail + +using ConstKernelInfo = detail::KernelInfoImpl>; + +/// +/// This struct owns the OrtKernelInfo* pointer when a copy is made. +/// For convenient wrapping of OrtKernelInfo* passed to kernel constructor +/// and query attributes, wrap the pointer with Ort::Unowned instance +/// so it does not destroy the pointer the kernel does not own. +/// +struct KernelInfo : detail::KernelInfoImpl { + using Base = detail::KernelInfoImpl; + using Base::Base; + explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later + explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance + ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } +}; + +/// +/// Create and own custom defined operation.
+/// +struct Op : detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used + + explicit Op(OrtOp*); ///< Take ownership of the OrtOp + + static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, + int version, const char** type_constraint_names, + const ONNXTensorElementDataType* type_constraint_values, + size_t type_constraint_count, + const OpAttr* attr_values, + size_t attr_count, + size_t input_count, size_t output_count); + + void Invoke(const OrtKernelContext* context, + const Value* input_values, + size_t input_count, + Value* output_values, + size_t output_count); + + // For easier refactoring + void Invoke(const OrtKernelContext* context, + const OrtValue* const* input_values, + size_t input_count, + OrtValue* const* output_values, + size_t output_count); +}; + +/// +/// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. 
+/// +struct ShapeInferContext { + struct SymbolicInteger { + SymbolicInteger(int64_t i) : i_(i), is_int_(true) {}; + SymbolicInteger(const char* s) : s_(s), is_int_(false) {}; + SymbolicInteger(const SymbolicInteger&) = default; + SymbolicInteger(SymbolicInteger&&) = default; + + SymbolicInteger& operator=(const SymbolicInteger&) = default; + SymbolicInteger& operator=(SymbolicInteger&&) = default; + + bool operator==(const SymbolicInteger& dim) const { + if (is_int_ == dim.is_int_) { + if (is_int_) { + return i_ == dim.i_; + } else { + return std::string{s_} == std::string{dim.s_}; + } + } + return false; + } + + bool IsInt() const { return is_int_; } + int64_t AsInt() const { return i_; } + const char* AsSym() const { return s_; } + + static constexpr int INVALID_INT_DIM = -2; + + private: + union { + int64_t i_; + const char* s_; + }; + bool is_int_; + }; + + using Shape = std::vector; + + ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); + + const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } + + size_t GetInputCount() const { return input_shapes_.size(); } + + Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + int64_t GetAttrInt(const char* attr_name); + + using Ints = std::vector; + Ints GetAttrInts(const char* attr_name); + + float GetAttrFloat(const char* attr_name); + + using Floats = std::vector; + Floats GetAttrFloats(const char* attr_name); + + std::string GetAttrString(const char* attr_name); + + using Strings = std::vector; + Strings GetAttrStrings(const char* attr_name); + + private: + const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + const OrtApi* ort_api_; + OrtShapeInferContext* ctx_; + std::vector input_shapes_; +}; + +using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); + +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + +template +struct CustomOpBase : OrtCustomOp { + CustomOpBase() { + 
OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; + + OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; + + OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; + OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputMemoryType(index); }; + + OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; + OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26409) +#endif + OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicInputMinArity(); }; + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicInputHomogeneity()); }; + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicOutputMinArity(); }; + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return 
static_cast(static_cast(this_)->GetVariadicOutputHomogeneity()); }; +#ifdef __cpp_if_constexpr + if constexpr (WithStatus) { +#else + if (WithStatus) { +#endif + OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { + return static_cast(this_)->CreateKernelV2(*api, info, op_kernel); + }; + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + return static_cast(op_kernel)->ComputeV2(context); + }; + } else { + OrtCustomOp::CreateKernelV2 = nullptr; + OrtCustomOp::KernelComputeV2 = nullptr; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + static_cast(op_kernel)->Compute(context); + }; + } + + SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; + + OrtCustomOp::GetMayInplace = nullptr; + OrtCustomOp::ReleaseMayInplace = nullptr; + OrtCustomOp::GetAliasMap = nullptr; + OrtCustomOp::ReleaseAliasMap = nullptr; + } + + // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider + const char* GetExecutionProviderType() const { return nullptr; } + + // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below + // (inputs and outputs are required by default) + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + // 
Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault + OrtMemType GetInputMemoryType(size_t /*index*/) const { + return OrtMemTypeDefault; + } + + // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input + // should expect at least 1 argument. + int GetVariadicInputMinArity() const { + return 1; + } + + // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments + // to a variadic input should be of the same type. + bool GetVariadicInputHomogeneity() const { + return true; + } + + // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output + // should produce at least 1 output value. + int GetVariadicOutputMinArity() const { + return 1; + } + + // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values + // produced by a variadic output should be of the same type. + bool GetVariadicOutputHomogeneity() const { + return true; + } + + // Declare list of session config entries used by this Custom Op. + // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs(). + // This default implementation returns an empty vector of config entries. + std::vector GetSessionConfigKeys() const { + return std::vector{}; + } + + // Ort::CustomOpBase derived class should provide the following static method with the type/shape inferencing + // implementation if needed: + // static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) + template + decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInferFn(...)
{ + OrtCustomOp::InferOutputShapeFn = {}; + } + + protected: + // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. + void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; +}; + +namespace detail { +template +struct ValueInfoImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + std::string Name() const; + ConstTypeInfo TypeInfo() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstValueInfo = detail::ValueInfoImpl>; + +/** \brief Wrapper around ::OrtValueInfo + * + */ +struct ValueInfo : detail::ValueInfoImpl { + explicit ValueInfo(std::nullptr_t) {} ///< No instance is created + /// Take ownership of a pointer created by C API + explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + + // Create ValueInfo for a tensor + explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); + + ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } +}; + +namespace detail { +template +struct NodeImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtNode + * + */ +struct Node : detail::NodeImpl { + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API + +#if !defined(ORT_MINIMAL_BUILD) + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names); + + /// + /// Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in `attributes` to do so. 
+ /// + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes); + + private: + static void Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + +namespace detail { +template +struct GraphImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + +#if !defined(ORT_MINIMAL_BUILD) + void SetInputs(std::vector& inputs); + void SetOutputs(std::vector& outputs); + void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value + void AddNode(Node& node); // Graph takes ownership of Node +#endif // !defined(ORT_MINIMAL_BUILD) +}; +} // namespace detail + +/** \brief Wrapper around ::OrtGraph + * + */ +struct Graph : detail::GraphImpl { + explicit Graph(std::nullptr_t) {} ///< No instance is created + explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API +#if !defined(ORT_MINIMAL_BUILD) + Graph(); +#endif +}; + +namespace detail { +template +struct ModelImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + +#if !defined(ORT_MINIMAL_BUILD) + void AddGraph(Graph& graph); +#endif +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstModel = detail::ModelImpl>; + +/** \brief Wrapper around ::OrtModel + * + */ +struct Model : detail::ModelImpl { + using DomainOpsetPair = std::pair; + + explicit Model(std::nullptr_t) {} ///< No instance is created + explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API + +#if !defined(ORT_MINIMAL_BUILD) + explicit Model(const std::vector& opsets); +#endif + + 
ConstModel GetConst() const { return ConstModel{this->p_}; } +}; +} // namespace Ort +#include "onnxruntime_cxx_inline.h" diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h new file mode 100644 index 000000000..053eebcb8 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h @@ -0,0 +1,2574 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead. +// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead. +// +// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter +// the main C++ file with implementation details. 
+ +#include +#include +#include +#include +#include +#include + +// Convert OrtStatus to Ort::Status and return +// instead of throwing +#define ORT_CXX_RETURN_ON_API_FAIL(expression) \ + { \ + auto ort_status = (expression); \ + if (ort_status) { \ + return Ort::Status(ort_status); \ + } \ + } + +#ifdef __cpp_if_constexpr +#define ORT_CXX_IF_CONSTEXPR if constexpr +#else +#define ORT_CXX_IF_CONSTEXPR if +#endif + +namespace Ort { + +namespace detail { +inline void ThrowStatus(const Status& st) { + std::string error_message = st.GetErrorMessage(); + OrtErrorCode error_code = st.GetErrorCode(); + ORT_CXX_API_THROW(std::move(error_message), error_code); +} +} // namespace detail + +inline void ThrowOnError(OrtStatus* ort_status) { + if (ort_status) { + Ort::Status st(ort_status); + detail::ThrowStatus(st); + } +} + +inline void ThrowOnError(const Status& st) { + if (st) { + detail::ThrowStatus(st); + } +} + +inline Status::Status(OrtStatus* status) noexcept : detail::Base{status} { +} + +inline Status::Status(const std::exception& e) noexcept { + p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); +} + +inline Status::Status(const Exception& e) noexcept { + p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); +} + +inline Status::Status(const char* message, OrtErrorCode code) noexcept { + p_ = GetApi().CreateStatus(code, message); +} + +inline std::string Status::GetErrorMessage() const { + std::string message(GetApi().GetErrorMessage(p_)); + return message; +} + +inline OrtErrorCode Status::GetErrorCode() const { + return GetApi().GetErrorCode(p_); +} + +inline bool Status::IsOK() const noexcept { + return (p_ == nullptr); +} + +// This template converts a C++ type into it's ONNXTensorElementDataType +template +struct TypeToTensorType; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; +}; + +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2; +}; +template <> +struct TypeToTensorType { + static constexpr 
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; +}; + +inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is not equal to anything, including itself. + return false; + } + return val == rhs.val; +} + +inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is unordered with respect to everything, including itself. + return false; + } + + const bool left_is_negative = IsNegative(); + if (left_is_negative != rhs.IsNegative()) { + // When the signs of left and right differ, we know that left is less than right if it is + // the negative value. The exception to this is if both values are zero, in which case IEEE + // says they should be equal, even if the signs differ. + return left_is_negative && !AreZero(*this, rhs); + } + return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); +} + +inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) + : allocator_(allocator), p_(p), size_(size) { +} + +inline MemoryAllocation::~MemoryAllocation() { + if (p_ != nullptr) { + // We do not throw out of destructor + auto ret = GetApi().AllocatorFree(allocator_, p_); + static_cast(ret); + } +} + +inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) { + *this = std::move(o); +} + +inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept { + OrtAllocator* alloc = nullptr; + void* p = nullptr; + size_t sz = 0; + + // Swap out this + std::swap(alloc, allocator_); + std::swap(p, p_); + std::swap(sz, size_); + + // Swap with incoming + std::swap(allocator_, o.allocator_); + std::swap(p_, o.p_); + std::swap(size_, o.size_); + + // Destroy this instance if needed + MemoryAllocation this_alloc(alloc, p, sz); + return *this; +} + +namespace detail { + +template +inline void* 
AllocatorImpl::Alloc(size_t size) { + void* out; + ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); + return out; +} + +template +inline MemoryAllocation AllocatorImpl::GetAllocation(size_t size) { + void* out; + ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); + MemoryAllocation result(this->p_, out, size); + return result; +} + +template +inline void AllocatorImpl::Free(void* p) { + ThrowOnError(GetApi().AllocatorFree(this->p_, p)); +} + +template +inline ConstMemoryInfo AllocatorImpl::GetInfo() const { + const OrtMemoryInfo* out; + ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out)); + return ConstMemoryInfo{out}; +} + +} // namespace detail + +inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { + ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_)); +} + +inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_)); +} + +namespace detail { + +template +inline std::string MemoryInfoImpl::GetAllocatorName() const { + const char* name = nullptr; + ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name)); + return std::string(name); +} + +template +inline OrtAllocatorType MemoryInfoImpl::GetAllocatorType() const { + OrtAllocatorType type; + ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type)); + return type; +} + +template +inline int MemoryInfoImpl::GetDeviceId() const { + int id = 0; + ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id)); + return id; +} + +template +inline OrtMemoryInfoDeviceType MemoryInfoImpl::GetDeviceType() const { + OrtMemoryInfoDeviceType type; + GetApi().MemoryInfoGetDeviceType(this->p_, &type); + return type; +} + +template +inline OrtMemType MemoryInfoImpl::GetMemoryType() const { + OrtMemType type; + ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type)); + return type; +} + +template +template +inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { + int 
comp_result = 0; + ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result)); + return comp_result == 0; +} + +} // namespace detail + +inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { + OrtMemoryInfo* p; + ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p)); + return MemoryInfo(p); +} + +inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); +} + +namespace detail { +template +inline std::vector ConstIoBindingImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + return binding_utils::GetOutputNamesHelper(this->p_, allocator); +} + +template +inline std::vector ConstIoBindingImpl::GetOutputNames(OrtAllocator* allocator) const { + return binding_utils::GetOutputNamesHelper(this->p_, allocator); +} + +template +inline std::vector ConstIoBindingImpl::GetOutputValues() const { + AllocatorWithDefaultOptions allocator; + return binding_utils::GetOutputValuesHelper(this->p_, allocator); +} + +template +inline std::vector ConstIoBindingImpl::GetOutputValues(OrtAllocator* allocator) const { + return binding_utils::GetOutputValuesHelper(this->p_, allocator); +} + +template +inline void IoBindingImpl::BindInput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindInput(this->p_, name, value)); +} + +template +inline void IoBindingImpl::BindOutput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindOutput(this->p_, name, value)); +} + +template +inline void IoBindingImpl::BindOutput(const char* name, const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info)); +} + +template +inline void IoBindingImpl::ClearBoundInputs() { + GetApi().ClearBoundInputs(this->p_); +} + +template +inline void IoBindingImpl::ClearBoundOutputs() { + GetApi().ClearBoundOutputs(this->p_); +} + +template +inline void 
IoBindingImpl::SynchronizeInputs() { + ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_)); +} + +template +inline void IoBindingImpl::SynchronizeOutputs() { + ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_)); +} + +namespace binding_utils { +inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { + std::vector result; + auto free_fn = detail::AllocatedFree(allocator); + using Ptr = std::unique_ptr; + + char* buffer = nullptr; + size_t* lengths = nullptr; + size_t count = 0; + ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count)); + + if (count == 0) { + return result; + } + + Ptr buffer_g(buffer, free_fn); + Ptr lengths_g(lengths, free_fn); + + result.reserve(count); + for (size_t i = 0; i < count; ++i) { + auto sz = *lengths; + result.emplace_back(buffer, sz); + buffer += sz; + ++lengths; + } + return result; +} + +inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { + std::vector result; + size_t owned = 0; + size_t output_count = 0; + // Lambda to release the buffer when no longer needed and + // make sure that we destroy all instances on exception + auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { + if (buffer) { + while (owned < output_count) { + auto* p = buffer + owned++; + GetApi().ReleaseValue(*p); + } + allocator->Free(allocator, buffer); + } + }; + using Ptr = std::unique_ptr; + + OrtValue** output_buffer = nullptr; + ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); + if (output_count == 0) { + return result; + } + + Ptr buffer_g(output_buffer, free_fn); + + result.reserve(output_count); + for (size_t i = 0; i < output_count; ++i) { + result.emplace_back(output_buffer[i]); + ++owned; + } + return result; +} + +} // namespace binding_utils +} // namespace detail + +inline IoBinding::IoBinding(Session& session) { + 
ThrowOnError(GetApi().CreateIoBinding(session, &this->p_)); +} + +inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) { + ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); +} + +inline ThreadingOptions::ThreadingOptions() { + ThrowOnError(GetApi().CreateThreadingOptions(&p_)); +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) { + ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() { + ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); + return *this; +} + +inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnv(logging_level, 
logid, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { + ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, + OrtLoggingLevel logging_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_)); + if (strcmp(logid, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + +inline Env& Env::EnableTelemetryEvents() { + ThrowOnError(GetApi().EnableTelemetryEvents(p_)); + return *this; +} + +inline Env& 
Env::DisableTelemetryEvents() { + ThrowOnError(GetApi().DisableTelemetryEvents(p_)); + return *this; +} + +inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) { + ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level)); + return *this; +} + +inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) { + ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg)); + return *this; +} + +inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { + std::vector keys, values; + auto num_entries = options.size(); + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + for (const auto& entry : options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries)); + return *this; +} + +inline CustomOpDomain::CustomOpDomain(const char* domain) { + ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); +} + +inline void CustomOpDomain::Add(const OrtCustomOp* op) { + ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); +} + +inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, + OrtAllocator* allocator) { + OrtLoraAdapter* p; + ThrowOnError(GetApi().CreateLoraAdapter(adapter_path.c_str(), allocator, &p)); + return LoraAdapter{p}; +} + +inline LoraAdapter LoraAdapter::CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes, + OrtAllocator* allocator) { + OrtLoraAdapter* p; + ThrowOnError(GetApi().CreateLoraAdapterFromArray(bytes, num_bytes, allocator, &p)); + return LoraAdapter{p}; +} + +inline RunOptions::RunOptions() { + ThrowOnError(GetApi().CreateRunOptions(&p_)); +} + +inline RunOptions& 
RunOptions::SetRunLogVerbosityLevel(int level) { + ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level)); + return *this; +} + +inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) { + ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level)); + return *this; +} + +inline int RunOptions::GetRunLogVerbosityLevel() const { + int out; + ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out)); + return out; +} + +inline int RunOptions::GetRunLogSeverityLevel() const { + int out; + ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out)); + return out; +} + +inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { + ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag)); + return *this; +} + +inline const char* RunOptions::GetRunTag() const { + const char* out; + ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out)); + return out; +} + +inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) { + ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value)); + return *this; +} + +inline RunOptions& RunOptions::SetTerminate() { + ThrowOnError(GetApi().RunOptionsSetTerminate(p_)); + return *this; +} + +inline RunOptions& RunOptions::UnsetTerminate() { + ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_)); + return *this; +} + +inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) { + ThrowOnError(GetApi().RunOptionsAddActiveLoraAdapter(p_, adapter)); + return *this; +} + +#if !defined(ORT_MINIMAL_BUILD) +inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) { + ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_)); +} + +inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, ConstSessionOptions session_options) { + ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, 
session_options, &this->p_)); +} + +inline Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options) { + return Ort::Status(GetCompileApi().CompileModel(env, model_compilation_options)); +} + +inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelPath( + const ORTCHAR_T* input_model_path) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModelPath(this->p_, input_model_path)); + return *this; +} + +inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelFromBuffer( + const void* input_model_data, size_t input_model_data_size) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModelFromBuffer(this->p_, input_model_data, + input_model_data_size)); + return *this; +} + +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( + const ORTCHAR_T* output_model_path) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this->p_, output_model_path)); + return *this; +} + +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile( + const ORTCHAR_T* file_path, size_t initializer_size_threshold) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( + this->p_, + file_path, + initializer_size_threshold)); + return *this; +} + +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( + OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator, + output_model_buffer_ptr, + output_model_buffer_size_ptr)); + return *this; +} + +inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( + bool embed_ep_context_in_model) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( + this->p_, + embed_ep_context_in_model)); + return *this; +} 
+#endif // !defined(ORT_MINIMAL_BUILD) + +namespace detail { + +template +inline Ort::SessionOptions ConstSessionOptionsImpl::Clone() const { + OrtSessionOptions* out; + ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out)); + return SessionOptions{out}; +} + +template +inline std::string ConstSessionOptionsImpl::GetConfigEntry(const char* config_key) const { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) const { + int out = 0; + Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out)); + return static_cast(out); +} + +template +inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { + if (!this->HasConfigEntry(config_key)) { + return def; + } + + return this->GetConfigEntry(config_key); +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetDeterministicCompute(bool value) { 
+ ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { + ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { + ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableProfiling() { + ThrowOnError(GetApi().DisableProfiling(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableOrtCustomOps() { + ThrowOnError(GetApi().EnableOrtCustomOps(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableMemPattern() { + ThrowOnError(GetApi().EnableMemPattern(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableMemPattern() { + ThrowOnError(GetApi().DisableMemPattern(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableCpuMemArena() { + ThrowOnError(GetApi().EnableCpuMemArena(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableCpuMemArena() { + ThrowOnError(GetApi().DisableCpuMemArena(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionMode execution_mode) { + ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLoadCancellationFlag(bool value) { + ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { + 
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLogSeverityLevel(int level) { + ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::Add(OrtCustomOpDomain* custom_op_domain) { + ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddConfigEntry(const char* config_key, const char* config_value) { + ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddInitializer(const char* name, const OrtValue* ort_val) { + ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::DisablePerSessionThreads() { + ThrowOnError(GetApi().DisablePerSessionThreads(this->p_)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializers(const std::vector& names, + const std::vector& ort_values) { + const size_t inputs_num = names.size(); + if (inputs_num != ort_values.size()) { + ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT); + } + std::vector names_ptr; + std::vector ort_values_ptrs; + names_ptr.reserve(inputs_num); + ort_values_ptrs.reserve(inputs_num); + for (size_t i = 0; i < inputs_num; ++i) { + names_ptr.push_back(names[i].c_str()); + ort_values_ptrs.push_back(ort_values[i]); + } + ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFromFilesInMemory(const std::vector>& file_names, + const std::vector& buffer_array, + const std::vector& file_lengths) { + const 
size_t inputs_num = file_names.size(); + if (inputs_num != buffer_array.size()) { + ORT_CXX_API_THROW("Expecting names and buffer_array to have the same length", ORT_INVALID_ARGUMENT); + } + if (inputs_num != file_lengths.size()) { + ORT_CXX_API_THROW("Expecting names and file_lengths to have the same length", ORT_INVALID_ARGUMENT); + } + std::vector names_ptr; + names_ptr.reserve(inputs_num); + for (size_t i = 0; i < inputs_num; ++i) { + names_ptr.push_back(file_names[i].c_str()); + } + ThrowOnError(GetApi().AddExternalInitializersFromFilesInMemory(this->p_, names_ptr.data(), buffer_array.data(), + file_lengths.data(), inputs_num)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options)); + return *this; +} + +template 
+inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( + const std::string& provider_name, + const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(), + keys.data(), values.data(), num_entries)); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options)); + return *this; +} + +template +inline SessionOptionsImpl& 
SessionOptionsImpl::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_, + keys.data(), values.data(), num_entries)); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, + const CustomOpConfigs& custom_op_configs) { + // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by + // the custom op library. 
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) { + AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str()); + } + + ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsUsingFunction(const char* registration_function_name) { + ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name)); + return *this; +} + +/// Session +template +inline size_t ConstSessionImpl::GetInputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out)); + return out; +} + +template +inline size_t ConstSessionImpl::GetOutputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out)); + return out; +} + +template +inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out)); + return out; +} + +template +inline std::vector ConstSessionImpl::GetInputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetInputCount(); + std::vector input_names; + input_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); + input_names.push_back(name); + allocator.Free(name); + } + + return input_names; +} + +template +inline std::vector ConstSessionImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetOutputCount(); + std::vector output_names; + output_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); + output_names.push_back(name); + allocator.Free(name); + } + + return output_names; +} + +template +inline std::vector 
ConstSessionImpl::GetOverridableInitializerNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_initializers = GetOverridableInitializerCount(); + std::vector initializer_names; + initializer_names.reserve(num_initializers); + + for (size_t i = 0; i < num_initializers; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); + initializer_names.push_back(name); + } + + return initializer_names; +} + +template +inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline AllocatedStringPtr ConstSessionImpl::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { + uint64_t out; + ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out)); + return out; +} + +template +inline ModelMetadata ConstSessionImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + +template +inline TypeInfo ConstSessionImpl::GetInputTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template 
+inline TypeInfo ConstSessionImpl::GetOutputTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +#if !defined(ORT_MINIMAL_BUILD) +template +inline int ConstSessionImpl::GetOpset(const std::string& domain) const { + int opset; + ThrowOnError(GetModelEditorApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset)); + return opset; +} +#endif // !defined(ORT_MINIMAL_BUILD) + +template +std::vector ConstSessionImpl::GetInputs() const { + const std::vector input_names = GetInputNames(); + + std::vector inputs; + inputs.reserve(input_names.size()); + + for (size_t i = 0; i < input_names.size(); ++i) { + auto type_info = GetInputTypeInfo(i); + inputs.emplace_back(ValueInfo{input_names[i], type_info.GetConst()}); + } + + return inputs; +} + +template +std::vector ConstSessionImpl::GetOutputs() const { + const std::vector output_names = GetOutputNames(); + + std::vector outputs; + outputs.reserve(output_names.size()); + + for (size_t i = 0; i < output_names.size(); ++i) { + auto type_info = GetOutputTypeInfo(i); + outputs.emplace_back(ValueInfo{output_names[i], type_info.GetConst()}); + } + + return outputs; +} + +template +inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count) { + std::vector output_values; + output_values.reserve(output_count); + for (size_t i = 0; i < output_count; i++) + output_values.emplace_back(nullptr); + Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count); + return output_values; +} + +template +inline void 
SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); +} + +template +inline void SessionImpl::Run(const RunOptions& run_options, const IoBinding& io_binding) { + ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding)); +} + +template +inline void SessionImpl::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) { + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, + ort_input_values, input_count, output_names, output_count, + ort_output_values, callback, user_data)); +} + +template +inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* allocator) { + char* out = nullptr; + ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len) { + ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); +} + +#if !defined(ORT_MINIMAL_BUILD) +template +inline void SessionImpl::FinalizeModelEditorSession(const Model& model, const SessionOptions& 
options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetModelEditorApi().ApplyModelToModelEditorSession(this->p_, model)); + ThrowOnError(GetModelEditorApi().FinalizeModelEditorSession(this->p_, options, prepacked_weights_container)); +} +#endif // #if !defined(ORT_MINIMAL_BUILD) + +} // namespace detail + +inline SessionOptions::SessionOptions() { + ThrowOnError(GetApi().CreateSessionOptions(&this->p_)); +} + +/// CustomOpConfigs +inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) { + std::string config_key = "custom_op."; + + config_key += custom_op_name; + config_key += "."; + config_key += config; + + return config_key; +} + +inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) { + const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key); + flat_configs_[full_flat_key] = config_value; + return *this; +} + +inline const std::unordered_map& CustomOpConfigs::GetFlattenedConfigs() const { + return flat_configs_; +} + +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_)); +} + +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_)); +} + +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_)); +} + +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options, 
OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options, + prepacked_weights_container, &this->p_)); +} + +#if !defined(ORT_MINIMAL_BUILD) +inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) { + ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); +} + +// static +inline Session Session::CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelEditorApi().CreateModelEditorSession(env, model_path, options, &session)); + return Session(session); +} + +// static +inline Session Session::CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelEditorApi().CreateModelEditorSessionFromArray(env, model_data, model_data_length, options, + &session)); + return Session(session); +} + +void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // #if !defined(ORT_MINIMAL_BUILD) + +inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDomain(p_, 
allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline std::vector ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const { + auto deletor = detail::AllocatedFree(allocator); + std::vector result; + + char** out = nullptr; + int64_t num_keys = 0; + ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); + if (num_keys <= 0) { + return result; + } + + // array of pointers will be freed + std::unique_ptr array_guard(out, deletor); + // reserve may throw + auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); }; + std::unique_ptr strings_guard(out, strings_deletor); + result.reserve(static_cast(num_keys)); + strings_guard.release(); + for (int64_t i = 0; i < num_keys; ++i) { + result.push_back(AllocatedStringPtr(out[i], deletor)); + } + + return result; +} + +inline int64_t ModelMetadata::GetVersion() const { + int64_t out; + ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); + return out; +} + +inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType 
element_type, + const std::vector& dims, + const std::vector* symbolic_dims) { + ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_)); + ThrowOnError(GetApi().SetTensorElementType(p_, element_type)); + ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size())); + + if (symbolic_dims) { + std::vector symbolic_dims_cstr; + symbolic_dims_cstr.reserve(symbolic_dims->size()); + std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr), + [](const std::string& s) { return s.c_str(); }); + ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size())); + } +} + +#if !defined(ORT_MINIMAL_BUILD) +// static +inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetModelEditorApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetModelEditorApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateSequenceTypeInfo(sequence_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateMapTypeInfo(key_type, value_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateOptionalTypeInfo(contained_type, &output)); + return TypeInfo{output}; +} +#endif // #if 
!defined(ORT_MINIMAL_BUILD) + +namespace detail { + +template +inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl::GetElementType() const { + ONNXTensorElementDataType out; + ThrowOnError(GetApi().GetTensorElementType(this->p_, &out)); + return out; +} + +template +inline size_t TensorTypeAndShapeInfoImpl::GetElementCount() const { + size_t out; + ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out)); + return static_cast(out); +} + +template +inline size_t TensorTypeAndShapeInfoImpl::GetDimensionsCount() const { + size_t out; + ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out)); + return out; +} + +template +inline void TensorTypeAndShapeInfoImpl::GetDimensions(int64_t* values, size_t values_count) const { + ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count)); +} + +template +inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** values, size_t values_count) const { + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); +} + +template +inline std::vector TensorTypeAndShapeInfoImpl::GetSymbolicDimensions() const { + std::vector out(GetDimensionsCount(), nullptr); + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); + return out; +} + +template +inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { + std::vector out(GetDimensionsCount(), -1); + ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); + return out; +} + +template +inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { + const OrtTensorTypeAndShapeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out)); + return ConstTensorTypeAndShapeInfo{out}; +} + +template +inline ConstSequenceTypeInfo TypeInfoImpl::GetSequenceTypeInfo() const { + const OrtSequenceTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out)); + return ConstSequenceTypeInfo{out}; +} + +template +inline 
ConstMapTypeInfo TypeInfoImpl::GetMapTypeInfo() const { + const OrtMapTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out)); + return ConstMapTypeInfo{out}; +} + +template +inline ONNXType TypeInfoImpl::GetONNXType() const { + ONNXType out; + ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out)); + return out; +} + +template +inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output)); + return TypeInfo{output}; +} + +template +inline TypeInfo OptionalTypeInfoImpl::GetOptionalElementType() const { + OrtTypeInfo* info; + ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info)); + return TypeInfo{info}; +} + +template +inline ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { + ONNXTensorElementDataType out; + ThrowOnError(GetApi().GetMapKeyType(this->p_, &out)); + return out; +} + +template +inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); + return TypeInfo{output}; +} + +template +inline ConstOptionalTypeInfo TypeInfoImpl::GetOptionalTypeInfo() const { + const OrtOptionalTypeInfo* info; + ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info)); + return ConstOptionalTypeInfo{info}; +} + +} // namespace detail + +namespace detail { + +template +template +inline void ConstValueImpl::GetOpaqueData(const char* domain, const char* type_name, R& out) const { + ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R))); +} + +template +inline bool ConstValueImpl::IsTensor() const { + int out; + ThrowOnError(GetApi().IsTensor(this->p_, &out)); + return out != 0; +} + +template +inline bool ConstValueImpl::HasValue() const { + int out; + ThrowOnError(GetApi().HasValue(this->p_, &out)); + return out != 0; +} + +template +inline size_t ConstValueImpl::GetCount() const { + size_t 
out; + ThrowOnError(GetApi().GetValueCount(this->p_, &out)); + return out; +} + +template +inline Value ConstValueImpl::GetValue(int index, OrtAllocator* allocator) const { + OrtValue* out; + ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out)); + return Value{out}; +} + +template +inline size_t ConstValueImpl::GetStringTensorDataLength() const { + size_t out; + ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out)); + return out; +} + +template +inline size_t ConstValueImpl::GetStringTensorElementLength(size_t element_index) const { + size_t out; + ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out)); + return out; +} + +template +template +inline const R* ConstValueImpl::GetTensorData() const { + R* out; + ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); + return out; +} + +template +inline const void* ConstValueImpl::GetTensorRawData() const { + void* out; + ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); + return out; +} + +template +inline TypeInfo ConstValueImpl::GetTypeInfo() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetTypeInfo(this->p_, &output)); + return TypeInfo{output}; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetTensorTypeAndShapeInfo() const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output)); + return TensorTypeAndShapeInfo{output}; +} + +template +inline ConstMemoryInfo ConstValueImpl::GetTensorMemoryInfo() const { + const OrtMemoryInfo* mem_info; + ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info)); + return ConstMemoryInfo(mem_info); +} + +template +inline void ConstValueImpl::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const { + ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer)); +} + +template +inline std::string ConstValueImpl::GetStringTensorElement(size_t 
element_index) const { + size_t buffer_length; + ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length)); + + std::string s; + s.resize(buffer_length); + ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0])); + return s; +} + +template +inline void ConstValueImpl::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { + ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count)); +} + +#if !defined(DISABLE_SPARSE_TENSORS) +template +inline OrtSparseFormat ConstValueImpl::GetSparseFormat() const { + OrtSparseFormat format; + ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format)); + return format; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorValuesTypeAndShapeInfo() const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output)); + return TensorTypeAndShapeInfo{output}; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output)); + return TensorTypeAndShapeInfo{output}; +} + +template +template +inline const R* ConstValueImpl::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const { + const void* out; + ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out)); + return reinterpret_cast(out); +} + +template +inline bool ConstValueImpl::IsSparseTensor() const { + int out; + ThrowOnError(GetApi().IsSparseTensor(this->p_, &out)); + return out != 0; +} + +template +template +inline const R* ConstValueImpl::GetSparseTensorValues() const { + const void* out; + ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out)); + 
return reinterpret_cast(out); +} + +#endif + +template +void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { + ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); +} + +template +void ValueImpl::FillStringTensorElement(const char* s, size_t index) { + ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index)); +} + +template +inline char* ValueImpl::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) { + char* result; + ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result)); + return result; +} + +template +void* ValueImpl::GetTensorMutableRawData() { + void* out; + ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out)); + return out; +} + +template +template +R* ValueImpl::GetTensorMutableData() { + R* out; + ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out)); + return out; +} + +template +template +R& ValueImpl::At(const std::vector& location) { + static_assert(!std::is_same::value, "this api does not support std::string"); + R* out; + ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out)); + return *out; +} + +#if !defined(DISABLE_SPARSE_TENSORS) +template +void ValueImpl::UseCooIndices(int64_t* indices_data, size_t indices_num) { + ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num)); +} + +template +void ValueImpl::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) { + ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num)); +} + +template +void ValueImpl::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) { + ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data)); +} + +template +void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, + const int64_t* indices_data, size_t 
indices_num) { + ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape, + values_param.values_shape_len, values_param.data.p_data, + indices_data, indices_num)); +} + +template +void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const int64_t* inner_indices_data, size_t inner_indices_num, + const int64_t* outer_indices_data, size_t outer_indices_num) { + ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, + inner_indices_data, inner_indices_num, + outer_indices_data, outer_indices_num)); +} + +template +void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const Shape& indices_shape, + const int32_t* indices_data) { + ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, + indices_shape.shape, indices_shape.shape_len, + indices_data)); +} + +#endif // !defined(DISABLE_SPARSE_TENSORS) + +} // namespace detail + +template +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len) { + return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); + return Value{out}; +} + +inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + 
ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count, + shape, shape_len, type, &out)); + return Value{out}; +} + +template +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { + return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); + return Value{out}; +} + +#if !defined(DISABLE_SPARSE_TENSORS) + +template +inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, + const Shape& values_shape) { + return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType::type); +} + +inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, + const Shape& values_shape, ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, + values_shape.shape, values_shape.shape_len, type, + &out)); + return Value{out}; +} + +template +inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) { + return CreateSparseTensor(allocator, dense_shape, TypeToTensorType::type); +} + +inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out)); + return Value{out}; +} +#endif // !defined(DISABLE_SPARSE_TENSORS) + +inline Value Value::CreateMap(const Value& keys, const Value& values) { + OrtValue* out; + const OrtValue* inputs[2] = {keys, values}; + ThrowOnError(GetApi().CreateValue(inputs, 2, 
ONNX_TYPE_MAP, &out)); + return Value{out}; +} + +inline Value Value::CreateSequence(const std::vector& values) { + OrtValue* out; + std::vector values_ort{values.data(), values.data() + values.size()}; + ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); + return Value{out}; +} + +template +inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) { + OrtValue* out; + ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out)); + return Value{out}; +} + +// +// Custom OP Inlines +// +inline Logger::Logger(const OrtLogger* logger) : logger_(logger) { + Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_)); +} + +inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept { + return cached_severity_level_; +} + +inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, + const char* func_name, const char* message) const noexcept { + OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number, + func_name); + return Status{status}; +} + +// Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security) +// for gcc and clang. The alternative is to use actual C-style variadic parameters and apply +// __attribute__(format(printf...)), which does not work with variadic templates. 
+#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wformat-nonliteral" +#pragma GCC diagnostic ignored "-Wformat-security" +#elif defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#pragma clang diagnostic ignored "-Wformat-security" +#endif +template +inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, + int line_number, const char* func_name, const char* format, + Args&&... args) const noexcept { + int msg_len = std::snprintf(nullptr, 0U, format, std::forward(args)...); + + if (msg_len < 0) { // Formatting error + return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL); + } + + OrtStatus* status = nullptr; + const size_t buffer_size = static_cast(msg_len) + 1U; + + constexpr size_t kStackBufferSize = 1024; + + if (buffer_size < kStackBufferSize) { + char buffer[kStackBufferSize]; + snprintf(buffer, kStackBufferSize, format, std::forward(args)...); + status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name); + } else { + // std::make_unique is only supported starting at C++14. 
+#if (__cplusplus >= 201402L) || (_MSC_VER >= 1900) + auto buffer = std::make_unique(buffer_size); +#else + std::unique_ptr buffer(new char[buffer_size]); +#endif + std::snprintf(buffer.get(), buffer_size, format, std::forward(args)...); + status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name); + } + + return Status{status}; +} +// Re-enable -Wformat-nonliteral and -Wformat-security +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(__clang__) +#pragma clang diagnostic pop +#endif + +inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) { +} + +inline size_t KernelContext::GetInputCount() const { + size_t out = 0; + Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out)); + return out; +} + +inline size_t KernelContext::GetOutputCount() const { + size_t out = 0; + Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out)); + return out; +} + +inline ConstValue KernelContext::GetInput(size_t index) const { + const OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out)); + return ConstValue{out}; +} + +inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const { + OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out)); + return UnownedValue(out); +} + +inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector& dims) const { + OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out)); + return UnownedValue(out); +} + +inline void* KernelContext::GetGPUComputeStream() const { + void* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out)); + return out; +} + +inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { + OrtAllocator* out = 
nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); + return out; +} + +inline Logger KernelContext::GetLogger() const { + const OrtLogger* out = nullptr; + ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out)); + return Logger{out}; +} + +inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const { + ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); +} + +inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { + Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); +} + +namespace detail { +template +inline KernelInfo KernelInfoImpl::Copy() const { + OrtKernelInfo* info_copy = nullptr; + Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy)); + return KernelInfo{info_copy}; +} + +template +inline size_t KernelInfoImpl::GetInputCount() const { + size_t out = 0; + ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out)); + return out; +} + +template +inline size_t KernelInfoImpl::GetOutputCount() const { + size_t out = 0; + ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out)); + return out; +} + +template +inline std::string KernelInfoImpl::GetInputName(size_t index) const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline std::string KernelInfoImpl::GetOutputName(size_t index) const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size)); + 
+ std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline TypeInfo KernelInfoImpl::GetInputTypeInfo(size_t index) const { + OrtTypeInfo* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline TypeInfo KernelInfoImpl::GetOutputTypeInfo(size_t index) const { + OrtTypeInfo* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; +} + +template +inline Value KernelInfoImpl::GetTensorAttribute(const char* name, OrtAllocator* allocator) const { + OrtValue* out = nullptr; + ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out)); + return Value{out}; +} + +template +inline ConstValue KernelInfoImpl::GetTensorConstantInput(size_t index, int* is_constant) const { + const OrtValue* out = nullptr; + ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out)); + return ConstValue{out}; +} + +template +inline std::string KernelInfoImpl::GetNodeName() const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; +} + +template +inline Logger KernelInfoImpl::GetLogger() const { + const OrtLogger* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out)); + return Logger{out}; +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); +} + +inline void 
attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) { + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out)); +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the string attribute + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + out.swap(result); +} + +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the attribute + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size)); + + std::vector out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size)); + out.swap(result); +} + +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the attribute + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size)); + + std::vector out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size)); + out.swap(result); +} +} // namespace detail + +inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} + +inline Op::Op(OrtOp* p) : detail::Base(p) {} + +inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, + const char** type_constraint_names, + const ONNXTensorElementDataType* type_constraint_values, + size_t type_constraint_count, + const OpAttr* attr_values, size_t attr_count, + 
size_t input_count, size_t output_count) { + static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*), + "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely"); + auto attr_input_values = reinterpret_cast(attr_values); + OrtOp* op; + Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values, + static_cast(type_constraint_count), + attr_input_values, + static_cast(attr_count), + static_cast(input_count), + static_cast(output_count), &op)); + return Op{op}; +} + +inline void Op::Invoke(const OrtKernelContext* context, + const Value* input_values, + size_t input_count, + Value* output_values, + size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), + "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast(input_count), + ort_output_values, static_cast(output_count))); +} + +inline void Op::Invoke(const OrtKernelContext* context, + const OrtValue* const* input_values, + size_t input_count, + OrtValue* const* output_values, + size_t output_count) { + Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast(input_count), + output_values, static_cast(output_count))); +} + +inline std::string GetVersionString() { + return OrtGetApiBase()->GetVersionString(); +} + +inline std::string GetBuildInfoString() { + return GetApi().GetBuildInfoString(); +} + +inline std::vector GetAvailableProviders() { + char** providers; + int len; + + auto release_fn = [&len](char** providers) { + // This should always return nullptr. 
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len)); + }; + + ThrowOnError(GetApi().GetAvailableProviders(&providers, &len)); + std::unique_ptr guard(providers, release_fn); + std::vector available_providers; + available_providers.reserve(static_cast(len)); + for (int i = 0; i < len; ++i) { + available_providers.emplace_back(providers[i]); + } + return available_providers; +} + +template +void CustomOpBase::GetSessionConfigs(std::unordered_map& out, + ConstSessionOptions options) const { + const TOp* derived = static_cast(this); + std::vector keys = derived->GetSessionConfigKeys(); + + out.reserve(keys.size()); + + std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), ""); + const size_t prefix_size = config_entry_key.length(); + + for (const auto& key : keys) { + config_entry_key.resize(prefix_size); + config_entry_key.append(key); + out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), ""); + } +} + +inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, + OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) { + size_t input_count = 0; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count)); + for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { + OrtTensorTypeAndShapeInfo* info{}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info)); + TensorTypeAndShapeInfo type_shape_info(info); + auto integer_shape = type_shape_info.GetShape(); + std::vector symbolic_shape(integer_shape.size(), {}); + if (!integer_shape.empty()) { + type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size()); + } + Shape shape; + for (size_t ith = 0; ith < integer_shape.size(); ++ith) { + if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) { + shape.emplace_back(symbolic_shape[ith]); + } else { + shape.emplace_back(integer_shape[ith]); + } + } + input_shapes_.push_back(std::move(shape)); + 
type_shape_info.release(); + } +} + +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { + OrtTensorTypeAndShapeInfo* info = {}; + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); + + using InfoPtr = std::unique_ptr>; + + InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) { + ort_api_->ReleaseTensorTypeAndShapeInfo(obj); + }); + + std::vector integer_dims; + std::vector symbolic_dims; + + for (const auto dim : shape) { + if (dim.IsInt()) { + integer_dims.push_back(dim.AsInt()); + symbolic_dims.push_back(""); + } else { + if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) { + ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT); + } + integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM); + symbolic_dims.push_back(dim.AsSym()); + } + } + + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); + return Status{nullptr}; +} + +inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); + return i; +} + +inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. 
The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); + if (status) { + size_t num_i = out / sizeof(int64_t); + ShapeInferContext::Ints ints(num_i, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); + return ints; + } else { + if (out == 0u) { + return {}; + } + return {i}; + } +} + +inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); + return f; +} + +inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. 
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); + if (status) { + size_t num_f = out / sizeof(float); + ShapeInferContext::Floats floats(num_f, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); + return floats; + } else { + if (out == 0u) { + return {}; + } + return {f}; + } +} + +inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); + return {chars.data()}; + } else { + return {c}; + } +} + +inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. 
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); + ShapeInferContext::Strings strings; + char* char_st = chars.data(); + char* char_ed = char_st + out; + while (char_st < char_ed) { + strings.emplace_back(char_st); + while (*char_st != '\0') { + char_st++; + } + char_st++; + } + return strings; + } else { + if (out == 0u) { + return {}; + } + return {std::string{c}}; + } +} + +inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { + const OrtOpAttr* attr_hdl = {}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); + return attr_hdl; +} + +namespace detail { +inline std::vector StringsToCharPtrs(const std::vector& strings) { + std::vector ptrs; + ptrs.reserve(strings.size()); + std::transform(strings.begin(), strings.end(), std::back_inserter(ptrs), + [](const std::string& s) { return s.c_str(); }); + + return ptrs; +} +} // namespace detail + +#if !defined(ORT_MINIMAL_BUILD) +// static +inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node) { + auto inputs = detail::StringsToCharPtrs(input_names); + auto outputs = detail::StringsToCharPtrs(output_names); + + std::vector attributes_ptrs; + attributes_ptrs.reserve(attributes.size()); + std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs), + [](OpAttr& attr) -> OrtOpAttr* { return attr; }); + + ThrowOnError(GetModelEditorApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(), + inputs.data(), inputs.size(), + outputs.data(), outputs.size(), + attributes_ptrs.data(), attributes_ptrs.size(), + &node)); + + // Node now owns the 
attributes + std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); }); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes) { + Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names) { + std::vector empty_attributes; + Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); +} + +inline Graph::Graph() { + ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} + +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +namespace detail { +template <> +inline std::string ValueInfoImpl::Name() const { + const char* name = nullptr; + ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); + return name; +} + +template <> +inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} + +#if !defined(ORT_MINIMAL_BUILD) +template <> +inline void GraphImpl::SetInputs(std::vector& inputs) { + std::vector inputs_ptrs; + 
inputs_ptrs.reserve(inputs.size()); + std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + + // Graph now owns the inputs + std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::SetOutputs(std::vector& outputs) { + std::vector outputs_ptrs; + outputs_ptrs.reserve(outputs.size()); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + + // Graph now owns the outputs + std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { + // Graph takes ownership of `initializer` + ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); +} + +template <> +inline void GraphImpl::AddNode(Node& node) { + // Graph takes ownership of `node` + ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); +} + +template <> +inline void ModelImpl::AddGraph(Graph& graph) { + // Model takes ownership of `graph` + ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace detail +} // namespace Ort diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_float16.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_float16.h new file mode 100644 index 000000000..408d3ccfb --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_float16.h 
@@ -0,0 +1,535 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +namespace onnxruntime_float16 { + +namespace detail { + +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error onnxruntime_float16::detail::endian is not implemented in this environment. +#endif +}; + +static_assert( + endian::native == endian::little || endian::native == endian::big, + "Only little-endian or big-endian native byte orders are supported."); + +} // namespace detail + +/// +/// Shared implementation between public and internal classes. CRTP pattern. +/// +template +struct Float16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + constexpr static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept { + return IsNaN() ? 
val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7C00U; + static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; + static constexpr uint16_t kNegativeInfinityBits = 0xFC00U; + static constexpr uint16_t kPositiveQNaNBits = 0x7E00U; + static constexpr uint16_t kNegativeQNaNBits = 0xFE00U; + static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number + static constexpr uint16_t kOneBits = 0x3C00U; + static constexpr uint16_t kMinusOneBits = 0xBC00U; + + uint16_t val{0}; + + Float16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). 
+ /// + /// True if so + bool IsNormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept { + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } + + bool operator==(const Float16Impl& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is not equal to anything, including itself. + return false; + } + return val == rhs.val; + } + + bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } + + bool operator<(const Float16Impl& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is unordered with respect to everything, including itself. 
+ return false; + } + + const bool left_is_negative = IsNegative(); + if (left_is_negative != rhs.IsNegative()) { + // When the signs of left and right differ, we know that left is less than right if it is + // the negative value. The exception to this is if both values are zero, in which case IEEE + // says they should be equal, even if the signs differ. + return left_is_negative && !AreZero(*this, rhs); + } + return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); + } +}; + +// The following Float16_t conversions are based on the code from +// Eigen library. + +// The conversion routines are Copyright (c) Fabian Giesen, 2016. +// The original license follows: +// +// Copyright (c) Fabian Giesen, 2016 +// All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted. +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +namespace detail { +union float32_bits { + unsigned int u; + float f; +}; +} // namespace detail + +template +inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept { + detail::float32_bits f{}; + f.f = v; + + constexpr detail::float32_bits f32infty = {255 << 23}; + constexpr detail::float32_bits f16max = {(127 + 16) << 23}; + constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr unsigned int sign_mask = 0x80000000u; + uint16_t val = static_cast(0x0u); + + unsigned int sign = f.u & sign_mask; + f.u ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but + // without arithmetic overflow. + f.u += 0xc8000fffU; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! 
+ val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +template +inline float Float16Impl::ToFloatImpl() const noexcept { + constexpr detail::float32_bits magic = {113 << 23}; + constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + detail::float32_bits o{}; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + unsigned int exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // re-normalize + } + + // Attempt to work around the Internal Compiler Error on ARM64 + // for bitwise | operator, including std::bitset +#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) + if (IsNegative()) { + return -o.f; + } +#else + // original code: + o.u |= (val & 0x8000U) << 16U; // sign bit +#endif + return o.f; +} + +/// Shared implementation between public and internal classes. CRTP pattern. +template +struct BFloat16Impl { + protected: + /// + /// Converts from float to uint16_t bfloat16 representation + /// + /// + /// + static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept { + return IsNaN() ?
val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7F80U; + static constexpr uint16_t kPositiveInfinityBits = 0x7F80U; + static constexpr uint16_t kNegativeInfinityBits = 0xFF80U; + static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U; + static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U; + static constexpr uint16_t kMaxValueBits = 0x7F7FU; + static constexpr uint16_t kRoundToNearest = 0x7FFFU; + static constexpr uint16_t kOneBits = 0x3F80U; + static constexpr uint16_t kMinusOneBits = 0xBF80U; + + uint16_t val{0}; + + BFloat16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. 
+ bool IsNaNOrZero() const noexcept { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { + // IEEE defines that positive and negative zero are equal, this gives us a quick equality check + // for two values by or'ing the private bits together and stripping the sign. They are both zero, + // and therefore equivalent, if the resulting value is still zero. 
+ return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } +}; + +template +inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept { + uint16_t result; + if (std::isnan(v)) { + result = kPositiveQNaNBits; + } else { + auto get_msb_half = [](float fl) { + uint16_t result; +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) { +#else + if (detail::endian::native == detail::endian::little) { +#endif + std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&result, &fl, sizeof(uint16_t)); + } + return result; + }; + + uint16_t upper_bits = get_msb_half(v); + union { + uint32_t U32; + float F32; + }; + F32 = v; + U32 += (upper_bits & 1) + kRoundToNearest; + result = get_msb_half(F32); + } + return result; +} + +template +inline float BFloat16Impl::ToFloatImpl() const noexcept { + if (IsNaN()) { + return std::numeric_limits::quiet_NaN(); + } + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) { +#else + if (detail::endian::native == detail::endian::little) { +#endif + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; +} + +} // namespace onnxruntime_float16 diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_lite_custom_op.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_lite_custom_op.h new file mode 100644 index 000000000..ce87d8c56 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_lite_custom_op.h @@ -0,0 +1,1119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +// Summary +// The header has APIs to save custom op authors the trouble of defining schemas, +// which will be inferred by functions' signatures, as long as their argument list has types supported here. +// Input could be: +// 1. Tensor of onnx data types. +// 2. Span of onnx data types. +// 3. Scalar of onnx data types. +// An input could be optional if indicated as std::optional<...>. +// For an output, it must be a tensor of onnx data types. +// Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op. +// For concrete examples, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". +// Note - all APIs in this header are ABI. + +#pragma once +#include "onnxruntime_cxx_api.h" +#include +#include +#include +#include + +namespace Ort { +namespace Custom { + +class ArgBase { + public: + ArgBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {} + virtual ~ArgBase() {}; + + protected: + struct KernelContext ctx_; + size_t indice_; + bool is_input_; +}; + +using ArgPtr = std::unique_ptr; +using ArgPtrs = std::vector; + +class TensorBase : public ArgBase { + public: + TensorBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ArgBase(ctx, indice, is_input) {} + + operator bool() const { + return shape_.has_value(); + } + + const std::vector& Shape() const { + if (!shape_.has_value()) { + ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return shape_.value(); + } + + ONNXTensorElementDataType Type() const { + return type_; + } + + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + return 0; + } + } + + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { +
shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + bool IsCpuTensor() const { + return strcmp("Cpu", mem_type_) == 0; + } + + virtual const void* DataRaw() const = 0; + virtual size_t SizeInBytes() const = 0; + + protected: + std::optional> shape_; + ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + const char* mem_type_ = "Cpu"; +}; + +template +struct Span { + const T* data_ = {}; + size_t size_ = {}; + void Assign(const T* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + T operator[](size_t indice) const { + return data_[indice]; + } + const T* data() const { return data_; } +}; + +template +class Tensor : public TensorBase { + public: + using TT = typename std::remove_reference::type; + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + const_value_ = ctx_.GetInput(indice); + auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + } + } + const TT* Data() const { + return reinterpret_cast(const_value_.GetTensorRawData()); + } + TT* Allocate(const std::vector& shape) { + shape_ = shape; + if (!data_) { + shape_ = shape; + data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData(); + } + return data_; + } + static TT GetT() { return (TT)0; } + const Span& AsSpan() { + if (!shape_.has_value() || shape_->size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + span_.Assign(Data(), static_cast((*shape_)[0])); + return span_; + } + const T& AsScalar() { + if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) { + ORT_CXX_API_THROW("invalid 
shape while trying to get a scalar from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return *Data(); + } + const void* DataRaw() const override { + return reinterpret_cast(Data()); + } + + size_t SizeInBytes() const override { + return sizeof(TT) * static_cast(NumberOfElement()); + } + + private: + ConstValue const_value_; // for input + TT* data_{}; // for output + Span span_; +}; + +template <> +class Tensor : public TensorBase { + public: + using strings = std::vector; + + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto const_value = ctx_.GetInput(indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + auto num_chars = const_value.GetStringTensorDataLength(); + // note - there will be copy ... 
+ auto num_strings = static_cast(NumberOfElement()); + if (num_strings) { + std::vector chars(num_chars + 1, '\0'); + std::vector offsets(num_strings); + const_value.GetStringTensorContent(static_cast(chars.data()), num_chars, offsets.data(), offsets.size()); + auto upper_bound = num_strings - 1; + input_strings_.resize(num_strings); + for (size_t i = upper_bound;; --i) { + if (i < upper_bound) { + chars[offsets[i + 1]] = '\0'; + } + input_strings_[i] = chars.data() + offsets[i]; + if (0 == i) { + break; + } + } + } + } + } + const strings& Data() const { + return input_strings_; + } + const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_strings_[0].c_str()); + } + size_t SizeInBytes() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0].size(); + } + void SetStringOutput(const strings& ss, const std::vector& dims) { + shape_ = dims; + std::vector raw; + for (const auto& s : ss) { + raw.push_back(s.data()); + } + auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); + // note - there will be copy ... 
+ output.FillStringTensor(raw.data(), raw.size()); + } + const Span& AsSpan() { + ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + const std::string& AsScalar() { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0]; + } + + private: + std::vector input_strings_; // for input +}; + +template <> +class Tensor : public TensorBase { + public: + using strings = std::vector; + using string_views = std::vector; + + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto const_value = ctx_.GetInput(indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + auto num_chars = const_value.GetStringTensorDataLength(); + chars_.resize(num_chars + 1, '\0'); + auto num_strings = static_cast(NumberOfElement()); + if (num_strings) { + std::vector offsets(num_strings); + const_value.GetStringTensorContent(static_cast(chars_.data()), num_chars, offsets.data(), offsets.size()); + offsets.push_back(num_chars); + for (size_t i = 0; i < num_strings; ++i) { + input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]); + } + } + } + } + const string_views& Data() const { + return input_string_views_; + } + const void* DataRaw() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_string_views_[0].data()); + } + size_t SizeInBytes() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", 
ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0].size(); + } + void SetStringOutput(const strings& ss, const std::vector& dims) { + shape_ = dims; + std::vector raw; + for (const auto& s : ss) { + raw.push_back(s.data()); + } + auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); + // note - there will be copy ... + output.FillStringTensor(raw.data(), raw.size()); + } + const Span& AsSpan() { + ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + std::string_view AsScalar() { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0]; + } + + private: + std::vector chars_; // for input + std::vector input_string_views_; // for input +}; + +using TensorPtr = std::unique_ptr; +using TensorPtrs = std::vector; + +struct TensorArray : public ArgBase { + TensorArray(OrtKernelContext* ctx, + size_t start_indice, + bool is_input) : ArgBase(ctx, + start_indice, + is_input) { + if (is_input) { + auto input_count = ctx_.GetInputCount(); + for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) { + // Probe the element type of the input being wrapped (ith_input); probing start_indice would type every later variadic input after the first one. + auto const_value = ctx_.GetInput(ith_input); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + auto type = type_shape_info.GetElementType(); + TensorPtr tensor; + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + tensor = std::make_unique>(ctx, ith_input, true); + break; + default: + ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + break; + } + tensors_.emplace_back(tensor.release()); + } // for + } + } + template + T* AllocateOutput(size_t ith_output, const std::vector& shape) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + auto raw_output = tensor.get()->Allocate(shape); + tensors_.emplace_back(tensor.release()); + return raw_output; + } + Tensor& AllocateStringTensor(size_t ith_output) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + Tensor& output = *tensor; + tensors_.emplace_back(tensor.release()); + return output; + } + size_t Size() const { + return tensors_.size(); + } + const TensorPtr& operator[](size_t ith_input) const { + // ith_input is the indice of output relative to the tensor array + return tensors_.at(ith_input); + } + + private: + TensorPtrs tensors_; +}; + +using Variadic = TensorArray; + +/* +Note: +OrtLiteCustomOp inherits from 
OrtCustomOp to bridge between a custom func/struct and ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierarchy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a C struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety. +*/ +struct OrtLiteCustomOp : public OrtCustomOp { + using ConstOptionalFloatTensor = std::optional&>; + using OptionalFloatTensor = std::optional>; + + // CreateTuple + template + static typename std::enable_if>::type + CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) { + return std::make_tuple(); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + std::tuple current = std::tuple{context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + std::tuple current = std::tuple{*context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#ifdef ORT_CUDA_CTX + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + thread_local CudaContext cuda_context; + cuda_context.Init(*context); + std::tuple current = std::tuple{cuda_context}; + auto next =
CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } +#endif + +#ifdef ORT_ROCM_CTX + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + thread_local RocmContext rocm_context; + rocm_context.Init(*context); + std::tuple current = std::tuple{rocm_context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } +#endif + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& 
args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#define CREATE_TUPLE_INPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if*>::value, 
std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ 
+ } \ + template \ + static typename std::enable_if::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE_OUTPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename 
std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_output < num_output) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE(data_type) \ + CREATE_TUPLE_INPUT(data_type) \ + CREATE_TUPLE_OUTPUT(data_type) + + CREATE_TUPLE(bool) + CREATE_TUPLE(float) + CREATE_TUPLE(Ort::Float16_t) + CREATE_TUPLE(Ort::BFloat16_t) + CREATE_TUPLE(double) + CREATE_TUPLE(int8_t) + CREATE_TUPLE(int16_t) + CREATE_TUPLE(int32_t) + CREATE_TUPLE(int64_t) + CREATE_TUPLE(uint8_t) + CREATE_TUPLE(uint16_t) + CREATE_TUPLE(uint32_t) + CREATE_TUPLE(uint64_t) + CREATE_TUPLE(std::string) + CREATE_TUPLE_INPUT(std::string_view) + CREATE_TUPLE(Ort::Float8E4M3FN_t) + CREATE_TUPLE(Ort::Float8E4M3FNUZ_t) + CREATE_TUPLE(Ort::Float8E5M2_t) + CREATE_TUPLE(Ort::Float8E5M2FNUZ_t) + + // ParseArgs ... 
+ template + static typename std::enable_if<0 == sizeof...(Ts)>::type + ParseArgs(std::vector&, std::vector&) { + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + +#ifdef ORT_CUDA_CTX + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } +#endif + +#ifdef ORT_ROCM_CTX + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } +#endif + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + 
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + +#define PARSE_INPUT_BASE(pack_type, onnx_type) \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } + +#define PARSE_INPUT(data_type, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Tensor*, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Tensor&, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Span*, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Span&, onnx_type) \ + PARSE_INPUT_BASE(data_type, onnx_type) + +#define PARSE_OUTPUT(data_type, onnx_type) \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same&>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + 
ParseArgs(input_types, output_types); \ + } + +#define PARSE_ARGS(data_type, onnx_type) \ + PARSE_INPUT(data_type, onnx_type) \ + PARSE_OUTPUT(data_type, onnx_type) + + PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) + PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) + PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) + PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) + PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) + PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) + PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) + PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) + PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) + PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) + PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) + PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) + PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) + PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output + PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN) + PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ) + PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) + PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) + + OrtLiteCustomOp(const char* op_name, + const char* execution_provider, + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { + OrtCustomOp::version = ORT_API_VERSION; + + OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; + OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return 
((OrtLiteCustomOp*)op)->execution_provider_.c_str(); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; }; + + OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->input_types_.size(); + }; + + OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice]; + }; + + OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->output_types_.size(); + }; + + OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice]; + }; + + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? 
INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; + + OrtCustomOp::CreateKernelV2 = {}; + OrtCustomOp::KernelComputeV2 = {}; + OrtCustomOp::KernelCompute = {}; + + OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; + + OrtCustomOp::GetMayInplace = {}; + OrtCustomOp::ReleaseMayInplace = {}; + OrtCustomOp::GetAliasMap = {}; + OrtCustomOp::ReleaseAliasMap = {}; + } + + const std::string op_name_; + const std::string execution_provider_; + + std::vector input_types_; + std::vector output_types_; + + ShapeInferFn shape_infer_fn_ = {}; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; +}; + +//////////////////////////// OrtLiteCustomFunc //////////////////////////////// +// The struct is to implement function-as-op. +// E.g. a function might be defined as: +// void Filter(const Ort::Custom::Tensor& floats_in, Ort::Custom::Tensor& floats_out) { ... 
} +// It could be registered this way: +// Ort::CustomOpDomain v2_domain{"v2"}; +// std::unique_ptr fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; +// v2_domain.Add(fil_op_ptr.get()); +// session_options.Add(v2_domain); +// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". +template +struct OrtLiteCustomFunc : public OrtLiteCustomOp { + using ComputeFn = void (*)(Args...); + using ComputeFnReturnStatus = Status (*)(Args...); + using MyType = OrtLiteCustomFunc; + + struct Kernel { + size_t num_input_{}; + size_t num_output_{}; + ComputeFn compute_fn_{}; + ComputeFnReturnStatus compute_fn_return_status_{}; + std::string ep_{}; + }; + + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFn compute_fn, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... 
t_args) { kernel->compute_fn_(t_args...); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } + } + + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFnReturnStatus compute_fn_return_status, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... 
t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } + } +}; // struct OrtLiteCustomFunc + +/////////////////////////// OrtLiteCustomStruct /////////////////////////// +// The struct is to implement struct-as-op. +// E.g. a struct might be defined as: +// struct Merge { +// Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...} +// void Compute(const Ort::Custom::Tensor& strings_in, +// std::string_view string_in, +// Ort::Custom::Tensor* strings_out) {...} +// bool reverse_ = false; +// }; +// It could be registered this way: +// Ort::CustomOpDomain v2_domain{"v2"}; +// std::unique_ptr mrg_op_ptr{Ort::Custom::CreateLiteCustomOp("Merge", "CPUExecutionProvider")}; +// v2_domain.Add(mrg_op_ptr.get()); +// session_options.Add(v2_domain); +// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". 
+template +struct OrtLiteCustomStruct : public OrtLiteCustomOp { + template + using CustomComputeFn = void (CustomOp::*)(Args...); + + template + using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...); + + using MyType = OrtLiteCustomStruct; + + struct Kernel { + size_t num_input_{}; + size_t num_output_{}; + std::unique_ptr custom_op_; + std::string ep_{}; + }; + + OrtLiteCustomStruct(const char* op_name, + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { + SetCompute(&CustomOp::Compute); + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + kernel->custom_op_ = std::make_unique(ort_api, info); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + SetShapeInfer(0); + } + + template + void SetCompute(CustomComputeFn) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... 
t_args) { kernel->custom_op_->Compute(t_args...); }, t); + }; + } + + template + void SetCompute(CustomComputeFnReturnStatus) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t); + }; + } + + template + decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInfer(...) { + OrtCustomOp::InferOutputShapeFn = {}; + } +}; // struct OrtLiteCustomStruct + +/////////////////////////// CreateLiteCustomOp //////////////////////////// + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + void (*custom_compute_fn)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); +} + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + Status (*custom_compute_fn_v2)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); +} + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* 
op_name, + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomStruct; + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); +} + +} // namespace Custom +} // namespace Ort diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h new file mode 100644 index 000000000..f40ea6591 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* + * This file defines RunOptions Config Keys and format of the Config Values. + * + * The Naming Convention for a RunOptions Config Key, + * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" + * Such as "ep.cuda.use_arena" + * The Config Key cannot be empty + * The maximum length of the Config Key is 128 + * + * The string format of a RunOptions Config Value is defined individually for each Config. + * The maximum length of the Config Value is 1024 + */ + +// Key for enabling shrinkages of user listed device memory arenas. +// Expects a list of semi-colon separated key value pairs separated by colon in the following format: +// "device_0:device_id_0;device_1:device_id_1" +// No white-spaces allowed in the provided list string. +// Currently, the only supported devices are : "cpu", "gpu" (case sensitive). +// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled. +// Example usage: "cpu:0;gpu:0" (or) "gpu:0" +// By default, the value for this key is empty (i.e.) 
no memory arenas are shrunk +static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; + +// Set to '1' to not synchronize execution providers with CPU at the end of session run. +// By default it is set to '0'. +// Taking CUDA EP as an example, it omits triggering cudaStreamSynchronize on the compute stream. +static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; + +// Set the QNN LoRA config file used for applying LoRA in a QNN context binary +static const char* const kOrtRunOptionsConfigQnnLoraConfig = "qnn.lora_config"; + +// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. +// The value should be an integer. If the value is not set, the default value is 0 and +// ORT session only captures one cuda graph before another capture is requested. +// If the value is set to -1, cuda graph capture/replay is disabled in that run. +// Users are not expected to set the value to 0 as it is reserved for internal use. 
+static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h new file mode 100644 index 000000000..379c74e01 --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* + * This file defines SessionOptions Config Keys and format of the Config Values. + * + * The Naming Convention for a SessionOptions Config Key, + * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" + * Such as "ep.cuda.use_arena" + * The Config Key cannot be empty + * The maximum length of the Config Key is 1024 + * + * The string format of a SessionOptions Config Value is defined individually for each Config. + * The maximum length of the Config Value is 2048 + */ + +// Key for disable PrePacking, +// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value) +static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking"; + +// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session +// will be used. Use this to override the usage of env allocators on a per session level. +static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators"; + +// Set to 'ORT' (case sensitive) to load an ORT format model. 
+// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT +static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format"; + +// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set. +// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'. +static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format"; + +// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0". +// When multiple sessions are created, a main thread doesn't override changes from succeeding session options, +// but threads in session thread pools follow option changes. +// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and +// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool. +// Note that an alternative way not using this option at runtime is to train and export a model without denormals +// and that's recommended because turning this option on may hurt model accuracy. +static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero"; + +// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not. +// "0": enable. ORT does fusion logic for QDQ format. +// "1": disable. ORT doesn't do fusion logic for QDQ format. +// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1". +static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq"; + +// It controls whether to enable Double QDQ remover and Identical Children Consolidation +// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs +// "1": disable. 
ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs +// Its default value is "0" +static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover"; + +// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been +// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the +// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to +// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on +// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization. +// As such, it's best to test to determine if enabling this works well for your scenario. +// The default value is "0" +// Available since version 1.11. +static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup"; + +// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0". +// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. +static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; + +// This setting controls whether to enable AheadOfTime function inlining. +// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model +// as possible with the help of enabled execution providers. +// This can reduce the number of function calls and improve performance because it is done before +// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, +// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. +// "0": enable; "1": disable. 
+// Its default value is "0". +static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; + +#ifdef ENABLE_TRAINING +// Specifies a path of the file containing a list of memory optimization configurations. +// The value should be a string indicating the file path of the config file. +// The content of the config file is a JSON struct like this: +// [ +// "Gelu+Cast+:1:0", +// "Dropout+:1:1" +// ] +// Taking the example of "Gelu+Cast+:1:0", +// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation +// output by ORT graph transformations. +// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute. +// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization, +// to avoid "oversaving" the memory. +static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config"; + +// Specifies the config for detecting subgraphs for memory footprint reduction. +// The value should be a string containing ints separated by a delimiter. The default value is "0:0". +static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config"; +#endif + +// This setting, if set, should contain a comma separated list of optimizer names that should be disabled. +// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer +// does not provide runtime benefits, but affects your model loading time you may disable it using this config +// entry. This option is not enabled in ORT_MINIMAL_BUILD build. +// A list of optimizers is available in onnxruntime/core/optimizer/graph_transformer_utils.cc +// +// Default is an empty string which means no optimizers are disabled.
+static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers"; + +// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". +// Using device allocators means the memory allocation is made using malloc/new. +static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; + +// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking +// "0": thread will block if it finds no job to run +// "1": default, thread will spin a number of times before blocking +static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; +static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; + +// Key for using model bytes directly for ORT format +// If a session is created using an input byte array that contains the ORT format model data, +// by default we will copy the model bytes at the time of session creation to ensure the model bytes +// buffer is valid. +// Setting this option to "1" will disable copying the model bytes, and use the model bytes directly. The caller +// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed. +static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly"; + +/// +/// Key for using the ORT format model flatbuffer bytes directly for initializers. +/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. +/// Requires `session.use_ort_model_bytes_directly` to be true. +/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire +/// duration of the InferenceSession.
+/// +static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers = + "session.use_ort_model_bytes_for_initializers"; + +// This should only be specified when exporting an ORT format model for use on a different platform. +// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0" +// Available since version 1.11. +static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed"; + +// x64 SSE4.1/AVX2/AVX512(with no VNNI) has an overflow problem with quantized matrix multiplication with U8S8. +// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if +// turned on, uses the slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512 +// platforms. +static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision"; + +// Specifies how minimal build graph optimizations are handled in a full build. +// These optimizations are at the extended level or higher. +// Possible values and their effects are: +// "save": Save runtime optimizations when saving an ORT format model. +// "apply": Only apply optimizations available in a minimal build. +// "" or unspecified: Apply optimizations available in a full build. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations = + "optimization.minimal_build_optimizations"; + +// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in +// order for them to take effect. + +// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be +// run by the NNAPI EP. +// The value should be a ","-delimited list of op types. For example, "Add,Sub". +// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op +// exclusion, set the value to "".
+static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops"; + +// Enabling dynamic block-sizing for multithreading. +// With a positive value, thread pool will split a task of N iterations to blocks of size starting from: +// N / (num_of_threads * dynamic_block_base) +// As execution progresses, the size will decrease according to the diminishing residual of N, +// meaning the task will be distributed in smaller granularity for better parallelism. +// For some models, it helps to reduce the variance of E2E inference latency and boost performance. +// The feature will not function by default, specify any positive integer, e.g. "4", to enable it. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base"; + +// This option allows to decrease CPU usage between infrequent +// requests and forces any TP threads spinning stop immediately when the last of +// concurrent Run() call returns. +// Spinning is restarted on the next Run() call. +// Applies only to internal thread-pools +static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop"; + +// "1": all inconsistencies encountered during shape and type inference +// will result in failures. +// "0": in some cases warnings will be logged but processing will continue. The default. +// May be useful to expose bugs in models. +static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference"; + +// "1": every model using a more recent opset than the latest released one will fail +// "0": the model may or may not work if onnxruntime cannot find an implementation, this option +// is used for development purpose. 
+static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only"; + +// The file saves configuration for partitioning nodes among logical streams +static const char* const kNodePartitionConfigFile = "session.node_partition_config_file"; + +// This Option allows setting affinities for intra op threads. +// Affinity string follows format: +// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id +// Semicolons separate the configurations of different threads, while commas separate the processors that the i-th thread is expected to attach to. +// e.g. 1,2,3;4,5 +// specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processor, and the 2nd thread to the 4th and 5th. +// To ease the configuration, an "interval" is also allowed: +// e.g. 1-8;9-16;17-24 +// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. +// Note: +// 1. Once set, the number of thread affinities must equal intra_op_num_threads - 1, since ort does not set affinity on the main thread which +// is started and managed by the calling app; +// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups, each with 64 logical processors, +// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. +// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. +static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities"; + +// This option will dump out the model to assist debugging any issues with layout transformation, +// and is primarily intended for developer usage. It is only relevant if an execution provider that requests +// NHWC layout is enabled such as NNAPI, XNNPACK or QNN. +// +// Default is off.
Set to "1" to enable. +// +// If modified by layout transformation the model will be dumped after these steps: +// 1) insertion of the layout transformation Transpose nodes +// 2) after those are optimized using the transpose optimizer, +// 3) after the L1 transformers are applied to the updated graph. +// The model will be saved to filename post_layout_transform_step_N.onnx, where N is the step number listed above. +static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation"; + +// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are +// assigned (i.e., "fallback") to the CPU EP by default. +// +// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP. +// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot +// fully support all of the nodes in the graph. +// +// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation +// will also fail with an error. +// +// Option values: +// - "0": CPU EP fallback is not disabled. [DEFAULT] +// - "1": CPU EP fallback is disabled.
+static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback"; + +// Use this config when serializing a large model after optimization to specify an external initializers file +static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = + "session.optimized_model_external_initializers_file_name"; + +// Use this config to control the minimum size of the initializer when externalizing it during serialization +static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = + "session.optimized_model_external_initializers_min_size_in_bytes"; + +// When loading a model from a memory buffer and the model has external initializers, +// use this config to set the external data file folder path. +// All external data files should be in the same folder +static const char* const kOrtSessionOptionsModelExternalInitializersFileFolderPath = + "session.model_external_initializers_file_folder_path"; + +// Use this config when saving pre-packed constant initializers to an external data file. +// This allows you to memory map pre-packed initializers on model load and leave it to +// the OS to manage the amount of memory consumed by the pre-packed initializers. Otherwise, +// pre-packed data resides on the heap. +// +// - "0": Default. Pre-packed initializers are not saved to a data file. +// - "1": Save pre-packed constant initializers to an external data file. +// Sample usage: sess_options.add_session_config_entry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1") +static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers = + "session.save_external_prepacked_constant_initializers"; + +// Use this config when you want to collect memory stats for each node in the graph. +// The file format is a CSV file whose columns are listed at the end of this comment. +// The file will be created if it does not exist, and will be overwritten if it does.
+// +// The content of the file can be used to estimate memory requirements at run time including +// the temporary allocations. This operation is preferably done on a CPU device, as the model may exceed +// device memory limits in constrained environments. When enabling this option, it is important to disable +// memory patterns, as they tend to allocate large blocks to avoid fragmentation and accommodate needs of multiple +// kernels. Memory patterns may make it difficult to allocate on a device with limited memory. +// +// The collected stats then can be used to partition the graph among the devices in a way that only the +// required memory is allocated on each device. +// +// node_name, initializers_memory, dynamic_outputs_sizes, temp_allocations_size +// +// - "full path to file": there is not a default for this option. If the file can not be opened for writing, an error will be returned. +static const char* const kOrtSessionOptionsCollectNodeMemoryStatsToFile = "session.collect_node_memory_stats_to_file"; + +/// This is a composite CSV setting formatted as "memory limit in kb,file name for collected stats" +/// "limit > 0": enables Capacity Aware Partitioning for Cuda EP. `limit` is optional and when absent +/// the provider may attempt to figure out the memory available automatically. +/// The setting with no limit is expected to look like: ",file name for collected stats" +/// The EP will place nodes on device "file name" : +/// this file is expected to be found at the same folder with the model. The file contains +/// pre-recorded stats collected when running with kOrtSessionOptionsCollectNodeMemoryStatsToFile enforce (see above) +static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings = + "session.resource_cuda_partitioning_settings"; + +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. 
+// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. +// "0": disable. (default) +// "1": enable. +static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; + +// Specify the file path for the Onnx model which has EP context. +// Default to original_file_name_ctx.onnx if not specified +// Folder is not a valid option +static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; + +// Flag to specify whether to dump the EP context into the Onnx model. +// "0": dump the EP context into separate file, keep the file name in the Onnx model. (default). +// "1": dump the EP context into the Onnx model. +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Specify the EPContext node name prefix to make it unique +// in case user need to merge/connect multiple EPContext nodes in one model +static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; + +// Share EP related resources across sessions +static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts"; + +// Stop to share EP related resources across sessions from then on +static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts"; + +// Use this config when dumping EP context model with an external initializers file +// All initializers will be inside the external data file if specified, otherwise all in Onnx file +static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName = + "ep.context_model_external_initializers_file_name"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. 
+static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; + +// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. +// Refer to MatMulNBits op schema for more details. +// If not provided, default is 4. +static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; + +// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME +// Meant to be used with SetEpDynamicOptions +// Specify the type of workload for this session. +// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default] +// "Efficient": OS treats this workload as efficiency oriented with low scheduling priority and efficient processor performance. +static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/webgpu_provider_factory.h b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/webgpu_provider_factory.h new file mode 100644 index 000000000..0b45b847d --- /dev/null +++ b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Headers/webgpu_provider_factory.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Dummy file to provide a signal in the ONNX Runtime C cocoapod as to whether the WebGPU EP was included in the build. +// If it was, this file will be included in the cocoapod, and a test like this can be used: +// +// #if __has_include(<onnxruntime/webgpu_provider_factory.h>) +// #define WEBGPU_EP_AVAILABLE 1 +// #else +// #define WEBGPU_EP_AVAILABLE 0 +// #endif + +// The WebGPU EP can be enabled via the generic SessionOptionsAppendExecutionProvider method, so no direct usage of +// the provider factory is required.
diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Info.plist b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Info.plist new file mode 100644 index 000000000..bce43958c Binary files /dev/null and b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/Info.plist differ diff --git a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime.1.19.0.dylib b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/onnxruntime similarity index 63% rename from mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime.1.19.0.dylib rename to mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/onnxruntime index f80172fd1..4b8db8192 100755 Binary files a/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/libonnxruntime.1.19.0.dylib and b/mobile/examples/phi-3/ios/LocalLLM/LocalLLM/lib/onnxruntime.framework/onnxruntime differ