diff --git a/src/ref.ts b/src/ref.ts index a14098cb..a080d563 100644 --- a/src/ref.ts +++ b/src/ref.ts @@ -1,5 +1,6 @@ /* eslint-disable no-param-reassign */ import type * as React from 'react'; +import type { ReactNode } from 'react'; import { isValidElement } from 'react'; import { isForwardRef, isFragment, isMemo } from 'react-is'; import useMemo from './hooks/useMemo'; @@ -37,20 +38,14 @@ export function useComposeRef(...refs: React.Ref[]): React.Ref { ); } -interface WithRef { - ref: React.Ref; -} - -export function supportRef(nodeOrComponent: any): nodeOrComponent is WithRef { - if (isFragment(nodeOrComponent)) { +export const supportRef = (value: any): value is React.RefAttributes => { + if (isFragment(value)) { return false; } - if (isForwardRef(nodeOrComponent)) { + if (isForwardRef(value)) { return true; } - const type = isMemo(nodeOrComponent) - ? nodeOrComponent.type.type - : nodeOrComponent.type; + const type = isMemo(value) ? value.type.type : value.type; // Function component node if (typeof type === 'function' && !type.prototype?.render) { @@ -58,23 +53,22 @@ export function supportRef(nodeOrComponent: any): nodeOrComponent is WithRef { } // Class component - if ( - typeof nodeOrComponent === 'function' && - !nodeOrComponent.prototype?.render - ) { + if (typeof value === 'function' && !value.prototype?.render) { return false; } return true; -} +}; -export function supportNodeRef(node: React.ReactNode): boolean { +export function supportNodeRef(node: ReactNode): boolean { if (!isValidElement(node)) { return false; } + if (isFragment(node)) { return false; } + return supportRef(node); } /* eslint-enable */