diff --git a/src/middleware.ts b/src/middleware.ts index d44d9a921c..240a60a3f3 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -14,6 +14,15 @@ import { const DEVTOOLS = Symbol() +type DevtoolsType = { + prefix: string + subscribe: (dispatch: any) => () => void + unsubscribe: () => void + send: (action: string, state: any) => void + init: (state: any) => void + error: (payload: any) => void +} + export const redux = ( reducer: (state: S, action: A) => S, @@ -24,7 +33,7 @@ export const redux = get: GetState A }>, api: StoreApi A }> & { dispatch: (a: A) => A - devtools?: any + devtools?: DevtoolsType } ): S & { dispatch: (a: A) => A } => { api.dispatch = (action: A) => { @@ -56,14 +65,12 @@ export const devtools = InnerCustomSetState extends NamedSet, InnerCustomGetState extends GetState, InnerCustomStoreApi extends StoreApi & { - setState: NamedSet + dispatch?: unknown + devtools?: DevtoolsType }, - OuterCustomSetState extends InnerCustomSetState & SetState, + OuterCustomSetState extends SetState, OuterCustomGetState extends InnerCustomGetState, - OuterCustomStoreApi extends InnerCustomStoreApi & { - dispatch?: unknown - devtools: any - } + OuterCustomStoreApi extends InnerCustomStoreApi >( fn: ( set: InnerCustomSetState, @@ -110,7 +117,7 @@ export const devtools = ) { console.warn('Please install/enable Redux devtools extension') } - api.devtools = null + delete api.devtools return fn( set as unknown as InnerCustomSetState, get as InnerCustomGetState, @@ -119,7 +126,7 @@ export const devtools = } const namedSet: NamedSet = (state, replace, name) => { set(state, replace) - if (!api.dispatch) { + if (!api.dispatch && api.devtools) { api.devtools.send(api.devtools.prefix + (name || 'action'), get()) } } @@ -142,15 +149,15 @@ export const devtools = const newState = api.getState() if (state !== newState) { savedSetState(state, replace) - if (state !== (newState as any)[DEVTOOLS]) { + if (state !== (newState as any)[DEVTOOLS] && api.devtools) { api.devtools.send(api.devtools.prefix + 'setState', api.getState()) } } } options = typeof options === 'string' ? { name: options } : options - api.devtools = extension.connect({ ...options }) - api.devtools.prefix = options?.name ? `${options.name} > ` : '' - api.devtools.subscribe((message: any) => { + const connection = (api.devtools = extension.connect({ ...options })) + connection.prefix = options?.name ? `${options.name} > ` : '' + connection.subscribe((message: any) => { if (message.type === 'ACTION' && message.payload) { try { api.setState(JSON.parse(message.payload)) @@ -178,7 +185,7 @@ export const devtools = message.type === 'DISPATCH' && message.payload?.type === 'COMMIT' ) { - api.devtools.init(api.getState()) + connection.init(api.getState()) } else if ( message.type === 'DISPATCH' && message.payload?.type === 'IMPORT_STATE' @@ -192,16 +199,16 @@ export const devtools = const action = actions[index] || 'No action found' if (index === 0) { - api.devtools.init(state) + connection.init(state) } else { savedSetState(state) - api.devtools.send(action, api.getState()) + connection.send(action, api.getState()) } } ) } }) - api.devtools.init(initialState) + connection.init(initialState) } return initialState } @@ -305,42 +312,57 @@ export function subscribeWithSelector< } type Combine = Omit & U -export const combine = - < - PrimaryState extends State, - SecondaryState extends State, - OuterCustomSetState extends SetState>, - OuterCustomGetState extends GetState>, - OuterCustomStoreApi extends StoreApi>, - InnerCustomSetState extends OuterCustomSetState extends NamedSet< - Combine - > - ? NamedSet - : SetState, - InnerCustomGetState extends GetState, - InnerCustomStoreApi extends StoreApi - >( - initialState: PrimaryState, - create: ( - set: InnerCustomSetState, - get: InnerCustomGetState, - api: InnerCustomStoreApi - ) => SecondaryState - ) => - ( - set: OuterCustomSetState, - get: OuterCustomGetState, - api: OuterCustomStoreApi + +export function combine< + PrimaryState extends State, + SecondaryState extends State +>( + initialState: PrimaryState, + create: ( + set: NamedSet, + get: GetState, + api: StoreApi + ) => SecondaryState +): ( + set: NamedSet>, + get: GetState>, + api: StoreApi> +) => Combine + +export function combine< + PrimaryState extends State, + SecondaryState extends State +>( + initialState: PrimaryState, + create: ( + set: SetState, + get: GetState, + api: StoreApi + ) => SecondaryState +): ( + set: SetState>, + get: GetState>, + api: StoreApi> +) => Combine + +export function combine< + PrimaryState extends State, + SecondaryState extends State +>( + initialState: PrimaryState, + create: ( + set: SetState, + get: GetState, + api: StoreApi + ) => SecondaryState +) { + return ( + set: SetState>, + get: GetState>, + api: StoreApi> ) => - Object.assign( - {}, - initialState, - create( - set as unknown as InnerCustomSetState, - get as unknown as InnerCustomGetState, - api as unknown as InnerCustomStoreApi - ) - ) as Combine + Object.assign({}, initialState, create(set as any, get as any, api as any)) +} type DeepPartial = { [P in keyof T]?: DeepPartial diff --git a/tests/middlewareTypes.test.tsx b/tests/middlewareTypes.test.tsx index 51659517c4..fa1c3db771 100644 --- a/tests/middlewareTypes.test.tsx +++ b/tests/middlewareTypes.test.tsx @@ -1,12 +1,6 @@ import { produce } from 'immer' import type { Draft } from 'immer' -import create, { - GetState, - State, - StateCreator, - StoreApi, - UseBoundStore, -} from 'zustand' +import create, { State, StateCreator, UseBoundStore } from 'zustand' import { NamedSet, combine, @@ -20,12 +14,7 @@ type TImmerConfigFn = ( partial: ((draft: Draft) => void) | T, replace?: boolean ) => void -type TImmerConfig< - T extends State, - CustomSetState = TImmerConfigFn, - CustomGetState = GetState, - CustomStoreApi extends StoreApi = StoreApi -> = StateCreator +type TImmerConfig = StateCreator> interface ISelectors { use: { @@ -272,6 +261,29 @@ it('should have correct type when creating store with devtool, persist and immer TestComponent }) +it('should have correct type when creating store with devtools', () => { + const useStore = create( + devtools((set) => ({ + testKey: 'test', + setTestKey: (testKey: string) => { + set((state) => ({ + testKey: state.testKey + testKey, + })) + }, + })) + ) + + const TestComponent = (): JSX.Element => { + useStore().testKey + useStore().setTestKey('') + useStore.getState().testKey + useStore.getState().setTestKey('') + + return <> + } + TestComponent +}) + it('should have correct type when creating store with redux', () => { const useStore = create( redux<{ count: number }, { type: 'INC' }>( @@ -296,6 +308,31 @@ it('should have correct type when creating store with redux', () => { TestComponent }) +it('should combine devtools and immer', () => { + const useStore = create( + devtools( + immer((set) => ({ + testKey: 'test', + setTestKey: (testKey: string) => { + set((state) => { + state.testKey = testKey + }) + }, + })) + ) + ) + + const TestComponent = (): JSX.Element => { + useStore().testKey + useStore().setTestKey('') + useStore.getState().testKey + useStore.getState().setTestKey('') + + return <> + } + TestComponent +}) + it('should combine devtools and redux', () => { const useStore = create( devtools( @@ -346,7 +383,9 @@ it('should combine subscribeWithSelector and combine', () => { const useStore = create( subscribeWithSelector( combine({ count: 1 }, (set, get) => ({ - inc: () => set({ count: get().count + 1 }, false, 'inc'), + inc: () => set({ count: get().count + 1 }, false), + // FIXME hope this to fail // @ts-expect-error + incInvalid: () => set({ count: get().count + 1 }, false, 'inc'), })) ) )