diff --git a/packages/fiber/src/core/hooks.tsx b/packages/fiber/src/core/hooks.tsx index 6a0e8d8e07..c7f6f8bc7e 100644 --- a/packages/fiber/src/core/hooks.tsx +++ b/packages/fiber/src/core/hooks.tsx @@ -5,7 +5,6 @@ import { GLTF } from 'three/examples/jsm/loaders/GLTFLoader' import { suspend, preload, clear } from 'suspend-react' import { context, RootState, RenderCallback } from './store' import { buildGraph, ObjectMap, is, useMutableCallback, useIsomorphicLayoutEffect } from './utils' -import { LoadingManager } from 'three' import { LocalState, Instance } from './renderer' export interface Loader extends THREE.Loader { @@ -17,8 +16,13 @@ export interface Loader extends THREE.Loader { ): unknown } -export type Extensions = (loader: THREE.Loader) => void +export type LoaderProto = new (...args: any) => Loader +export type LoaderReturnType> = T extends unknown + ? Awaited['loadAsync']>> + : T +// TODO: this isn't used anywhere, remove in v9 export type LoaderResult = T extends any[] ? Loader : Loader +export type Extensions }> = (loader: T['prototype']) => void export type ConditionalType = Child extends Parent ? Truthy : Falsy export type BranchingReturn = ConditionalType @@ -74,8 +78,11 @@ export function useGraph(object: THREE.Object3D) { return React.useMemo(() => buildGraph(object), [object]) } -function loadingFn(extensions?: Extensions, onProgress?: (event: ProgressEvent) => void) { - return function (Proto: new () => LoaderResult, ...input: string[]) { +function loadingFn>( + extensions?: Extensions, + onProgress?: (event: ProgressEvent) => void, +) { + return function (Proto: L, ...input: string[]) { // Construct new loader and run extensions const loader = new Proto() if (extensions) extensions(loader) @@ -105,37 +112,37 @@ function loadingFn(extensions?: Extensions, onProgress?: (event: ProgressEven * Note: this hook's caller must be wrapped with `React.Suspense` * @see https://docs.pmnd.rs/react-three-fiber/api/hooks#useloader */ -export function useLoader( - Proto: new (manager?: LoadingManager) => LoaderResult, +export function useLoader, R = LoaderReturnType>( + Proto: L, input: U, - extensions?: Extensions, + extensions?: Extensions, onProgress?: (event: ProgressEvent) => void, -): U extends any[] ? BranchingReturn[] : BranchingReturn { +): U extends any[] ? BranchingReturn[] : BranchingReturn { // Use suspense to load async assets const keys = (Array.isArray(input) ? input : [input]) as string[] - const results = suspend(loadingFn(extensions, onProgress), [Proto, ...keys], { equal: is.equ }) + const results = suspend(loadingFn(extensions, onProgress), [Proto, ...keys], { equal: is.equ }) // Return the object/s return (Array.isArray(input) ? results : results[0]) as U extends any[] - ? BranchingReturn[] - : BranchingReturn + ? BranchingReturn[] + : BranchingReturn } /** * Preloads an asset into cache as a side-effect. */ -useLoader.preload = function ( - Proto: new () => LoaderResult, +useLoader.preload = function >( + Proto: L, input: U, - extensions?: Extensions, + extensions?: Extensions, ) { const keys = (Array.isArray(input) ? input : [input]) as string[] - return preload(loadingFn(extensions), [Proto, ...keys]) + return preload(loadingFn(extensions), [Proto, ...keys]) } /** * Removes a loaded asset from cache. */ -useLoader.clear = function (Proto: new () => LoaderResult, input: U) { +useLoader.clear = function >(Proto: L, input: U) { const keys = (Array.isArray(input) ? input : [input]) as string[] return clear([Proto, ...keys]) } diff --git a/packages/fiber/tests/core/hooks.test.tsx b/packages/fiber/tests/core/hooks.test.tsx index 29342ccc39..830f8ca396 100644 --- a/packages/fiber/tests/core/hooks.test.tsx +++ b/packages/fiber/tests/core/hooks.test.tsx @@ -211,6 +211,21 @@ describe('hooks', () => { expect(scene.children[0]).toBe(MockMesh) }) + it('can handle useLoader with a loader extension', async () => { + class Loader extends THREE.Loader { + load = (_url: string) => null + } + + let proto!: Loader + + function Test() { + return useLoader(Loader, '', (loader) => (proto = loader)) + } + await act(async () => createRoot(canvas).render()) + + expect(proto).toBeInstanceOf(Loader) + }) + it('can handle useGraph hook', async () => { const group = new THREE.Group() const mat1 = new THREE.MeshBasicMaterial()