diff --git a/common/Services/IExtensionService.cs b/common/Services/IExtensionService.cs index 998f3c5db..941e61e58 100644 --- a/common/Services/IExtensionService.cs +++ b/common/Services/IExtensionService.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Threading.Tasks; -using Windows.ApplicationModel.AppExtensions; namespace DevHome.Common.Services; @@ -16,12 +15,10 @@ public interface IExtensionService Task> GetInstalledExtensionsAsync(Microsoft.Windows.DevHome.SDK.ProviderType providerType, bool includeDisabledExtensions = false); - Task> GetAllExtensionsAsync(); + IExtensionWrapper? GetInstalledExtension(string extensionUniqueId); Task SignalStopExtensionsAsync(); - Task> GetInstalledAppExtensionsAsync(); - public event EventHandler OnExtensionsChanged; public void EnableExtension(string extensionUniqueId); diff --git a/src/Models/ExtensionWrapper.cs b/src/Models/ExtensionWrapper.cs index 1cf86570e..215ae01e0 100644 --- a/src/Models/ExtensionWrapper.cs +++ b/src/Models/ExtensionWrapper.cs @@ -107,12 +107,13 @@ await Task.Run(() => try { var hr = PInvoke.CoCreateInstance(Guid.Parse(ExtensionClassId), null, CLSCTX.CLSCTX_LOCAL_SERVER, typeof(IExtension).GUID, out var extensionObj); - extensionPtr = Marshal.GetIUnknownForObject(extensionObj); if (hr < 0) { Marshal.ThrowExceptionForHR(hr); } + extensionPtr = Marshal.GetIUnknownForObject(extensionObj); + _extensionObject = MarshalInterface.FromAbi(extensionPtr); } finally diff --git a/src/Services/ExtensionService.cs b/src/Services/ExtensionService.cs index e38d301cc..e09d7c6ea 100644 --- a/src/Services/ExtensionService.cs +++ b/src/Services/ExtensionService.cs @@ -9,12 +9,10 @@ using DevHome.Telemetry; using Microsoft.UI.Xaml; using Microsoft.Windows.DevHome.SDK; -using Newtonsoft.Json.Linq; using Serilog; using Windows.ApplicationModel; using Windows.ApplicationModel.AppExtensions; using Windows.Foundation.Collections; -using YamlDotNet.Core.Tokens; using static DevHome.Common.Helpers.ManagementInfrastructureHelper; namespace DevHome.Services; @@ -37,11 +35,9 @@ public class ExtensionService : IExtensionService, IDisposable private const string CreateInstanceProperty = "CreateInstance"; private const string ClassIdProperty = "@ClassId"; -#pragma warning disable IDE0044 // Add readonly modifier - private static List _installedExtensions = new(); - private static List _enabledExtensions = new(); - private static List _installedWidgetsPackageFamilyNames = new(); -#pragma warning restore IDE0044 // Add readonly modifier + private static readonly List _installedExtensions = new(); + private static readonly List _enabledExtensions = new(); + private static readonly List _installedWidgetsPackageFamilyNames = new(); public ExtensionService(ILocalSettingsService settingsService) { @@ -158,7 +154,7 @@ private async Task IsValidDevHomeExtension(Package package) return (devHomeProvider, classIds); } - public async Task> GetInstalledAppExtensionsAsync() + private async Task> GetInstalledAppExtensionsAsync() { return await AppExtensionCatalog.Open("com.microsoft.devhome").FindAllAsync(); } @@ -227,6 +223,12 @@ public async Task> GetInstalledExtensionsAsync(bo } } + public IExtensionWrapper? GetInstalledExtension(string extensionUniqueId) + { + var extension = _installedExtensions.Where(extension => extension.ExtensionUniqueId.Equals(extensionUniqueId, StringComparison.Ordinal)); + return extension.FirstOrDefault(); + } + private async Task> GetInstalledWidgetExtensionsAsync() { await _getInstalledWidgetsLock.WaitAsync(); @@ -259,20 +261,6 @@ public async Task> GetInstalledDevHomeWidgetPackageFamilyNam return ids; } - public async Task> GetAllExtensionsAsync() - { - var installedExtensions = await GetInstalledExtensionsAsync(); - foreach (var installedExtension in installedExtensions) - { - if (!installedExtension.IsRunning()) - { - await installedExtension.StartExtensionAsync(); - } - } - - return installedExtensions; - } - public async Task SignalStopExtensionsAsync() { var installedExtensions = await GetInstalledExtensionsAsync(); @@ -381,13 +369,13 @@ private List GetCreateInstanceList(IPropertySet activationPropSet) public void EnableExtension(string extensionUniqueId) { - var extension = _installedExtensions.Where(extension => extension.ExtensionUniqueId == extensionUniqueId); + var extension = _installedExtensions.Where(extension => extension.ExtensionUniqueId.Equals(extensionUniqueId, StringComparison.Ordinal)); _enabledExtensions.Add(extension.First()); } public void DisableExtension(string extensionUniqueId) { - var extension = _enabledExtensions.Where(extension => extension.ExtensionUniqueId == extensionUniqueId); + var extension = _enabledExtensions.Where(extension => extension.ExtensionUniqueId.Equals(extensionUniqueId, StringComparison.Ordinal)); _enabledExtensions.Remove(extension.First()); } diff --git a/tools/Dashboard/DevHome.Dashboard/Extensions/ServiceExtensions.cs b/tools/Dashboard/DevHome.Dashboard/Extensions/ServiceExtensions.cs index a414ab11b..cb3b0a45e 100644 --- a/tools/Dashboard/DevHome.Dashboard/Extensions/ServiceExtensions.cs +++ b/tools/Dashboard/DevHome.Dashboard/Extensions/ServiceExtensions.cs @@ -29,6 +29,7 @@ public static IServiceCollection AddDashboard(this IServiceCollection services, services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); return services; } diff --git a/tools/Dashboard/DevHome.Dashboard/Services/IWidgetExtensionService.cs b/tools/Dashboard/DevHome.Dashboard/Services/IWidgetExtensionService.cs new file mode 100644 index 000000000..a3247d7df --- /dev/null +++ b/tools/Dashboard/DevHome.Dashboard/Services/IWidgetExtensionService.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Threading.Tasks; + +namespace DevHome.Dashboard.Services; + +internal interface IWidgetExtensionService +{ + /// + /// Gets whether the given providerDefinitionId represents a CoreWidgetProvider of any build ring + /// + /// True if the given providerDefinitionId represents a CoreWidgetProvider, otherwise false. + bool IsCoreWidgetProvider(string providerDefinitionId); + + Task EnsureCoreWidgetExtensionStarted(string providerDefinitionId); +} diff --git a/tools/Dashboard/DevHome.Dashboard/Services/WidgetExtensionService.cs b/tools/Dashboard/DevHome.Dashboard/Services/WidgetExtensionService.cs new file mode 100644 index 000000000..4d3553289 --- /dev/null +++ b/tools/Dashboard/DevHome.Dashboard/Services/WidgetExtensionService.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using DevHome.Common.Services; + +namespace DevHome.Dashboard.Services; + +internal sealed class WidgetExtensionService : IWidgetExtensionService +{ + private const string ExtensionUniqueIdStable = "Microsoft.Windows.DevHome_8wekyb3d8bbwe!App!PG-SP-ID1"; + private const string ExtensionUniqueIdCanary = "Microsoft.Windows.DevHome.Canary_8wekyb3d8bbwe!App!PG-SP-ID1"; + private const string ExtensionUniqueIdDev = "Microsoft.Windows.DevHome.Dev_8wekyb3d8bbwe!App!PG-SP-ID1"; + + private const string ProviderDefinitionStable = "Microsoft.Windows.DevHome_8wekyb3d8bbwe!App!!CoreWidgetProvider"; + private const string ProviderDefinitionCanary = "Microsoft.Windows.DevHome.Canary_8wekyb3d8bbwe!App!!CoreWidgetProvider"; + private const string ProviderDefinitionDev = "Microsoft.Windows.DevHome.Dev_8wekyb3d8bbwe!App!!CoreWidgetProvider"; + + private readonly IExtensionService _extensionService; + + public WidgetExtensionService(IExtensionService extensionService) + { + _extensionService = extensionService; + } + + /// + public bool IsCoreWidgetProvider(string providerDefinitionId) + { + return providerDefinitionId.Equals(ProviderDefinitionStable, StringComparison.Ordinal) || + providerDefinitionId.Equals(ProviderDefinitionCanary, StringComparison.Ordinal) || + providerDefinitionId.Equals(ProviderDefinitionDev, StringComparison.Ordinal); + } + + public async Task EnsureCoreWidgetExtensionStarted(string providerDefinitionId) + { + if (providerDefinitionId.StartsWith(ProviderDefinitionStable, StringComparison.Ordinal)) + { + await EnsureExtensionStarted(ExtensionUniqueIdStable); + } + else if (providerDefinitionId.StartsWith(ProviderDefinitionCanary, StringComparison.Ordinal)) + { + await EnsureExtensionStarted(ExtensionUniqueIdCanary); + } + else if (providerDefinitionId.StartsWith(ProviderDefinitionDev, StringComparison.Ordinal)) + { + await EnsureExtensionStarted(ExtensionUniqueIdDev); + } + } + + private async Task EnsureExtensionStarted(string extensionUniqueId) + { + var extensionWrapper = _extensionService.GetInstalledExtension(extensionUniqueId); + if (!extensionWrapper.IsRunning()) + { + await extensionWrapper.StartExtensionAsync(); + } + } +} diff --git a/tools/Dashboard/DevHome.Dashboard/Views/DashboardView.xaml.cs b/tools/Dashboard/DevHome.Dashboard/Views/DashboardView.xaml.cs index 90dcc39f1..f5c0b53d6 100644 --- a/tools/Dashboard/DevHome.Dashboard/Views/DashboardView.xaml.cs +++ b/tools/Dashboard/DevHome.Dashboard/Views/DashboardView.xaml.cs @@ -50,6 +50,7 @@ public partial class DashboardView : ToolPage, IDisposable private static DispatcherQueue _dispatcherQueue; private readonly ILocalSettingsService _localSettingsService; + private readonly IWidgetExtensionService _widgetExtensionService; private bool _disposedValue; private const string DraggedWidget = "DraggedWidget"; @@ -67,6 +68,7 @@ public DashboardView() _dispatcherQueue = Application.Current.GetService(); _localSettingsService = Application.Current.GetService(); + _widgetExtensionService = Application.Current.GetService(); #if DEBUG Loaded += AddResetButton; @@ -592,6 +594,14 @@ await Task.Run(async () => return; } + // The WidgetService will start the widget provider, however Dev Home won't know about it and won't be + // able to send disposed events when Dev Home closes. Ensure the provider is started here so we can + // tell the extension to dispose later. + if (_widgetExtensionService.IsCoreWidgetProvider(comSafeWidgetDefinition.ProviderDefinitionId)) + { + await _widgetExtensionService.EnsureCoreWidgetExtensionStarted(comSafeWidgetDefinition.ProviderDefinitionId); + } + TelemetryFactory.Get().Log( "Dashboard_ReportPinnedWidget", LogLevel.Critical, diff --git a/tools/ExtensionLibrary/DevHome.ExtensionLibrary/ViewModels/ExtensionLibraryViewModel.cs b/tools/ExtensionLibrary/DevHome.ExtensionLibrary/ViewModels/ExtensionLibraryViewModel.cs index 5efa72977..2cee317b3 100644 --- a/tools/ExtensionLibrary/DevHome.ExtensionLibrary/ViewModels/ExtensionLibraryViewModel.cs +++ b/tools/ExtensionLibrary/DevHome.ExtensionLibrary/ViewModels/ExtensionLibraryViewModel.cs @@ -10,7 +10,6 @@ using CommunityToolkit.Mvvm.Input; using CommunityToolkit.WinUI; using DevHome.Common.Extensions; -using DevHome.Common.Helpers; using DevHome.Common.Services; using Microsoft.UI.Dispatching; using Microsoft.UI.Xaml; @@ -27,18 +26,18 @@ public partial class ExtensionLibraryViewModel : ObservableObject { private readonly ILogger _log = Log.ForContext("SourceContext", nameof(ExtensionLibraryViewModel)); - private readonly string devHomeProductId = "9N8MHTPHNGVV"; + private const string DevHomeProductId = "9N8MHTPHNGVV"; private readonly IExtensionService _extensionService; private readonly DispatcherQueue _dispatcherQueue; // All internal Dev Home extensions that should allow users to enable/disable them, should add // their class Ids to this set. - private readonly HashSet _internalClassIdsToBeShownInExtensionsPage = new() - { + private readonly HashSet _internalClassIdsToBeShownInExtensionsPage = + [ HyperVExtensionClassId, WSLExtensionClassId, - }; + ]; public ObservableCollection StorePackagesList { get; set; } @@ -68,7 +67,7 @@ public async Task GetUpdatesButtonAsync() [RelayCommand] public async Task LoadedAsync() { - await GetInstalledExtensionsAsync(); + await GetInstalledPackagesAndExtensionsAsync(); GetAvailablePackages(); } @@ -77,12 +76,12 @@ private async void OnExtensionsChanged(object? sender, EventArgs e) await _dispatcherQueue.EnqueueAsync(async () => { ShouldShowStoreError = false; - await GetInstalledExtensionsAsync(); + await GetInstalledPackagesAndExtensionsAsync(); GetAvailablePackages(); }); } - private async Task GetInstalledExtensionsAsync() + private async Task GetInstalledPackagesAndExtensionsAsync() { var extensionWrappers = await _extensionService.GetInstalledExtensionsAsync(true); @@ -166,7 +165,7 @@ private async void GetAvailablePackages() var productId = productObj.GetNamedString("ProductId"); // Don't show self as available. - if (productId == devHomeProductId) + if (productId == DevHomeProductId) { continue; }