Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Recursively derive seeds and add custom account resolver #2194

Merged
merged 10 commits into from Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions CHANGELOG.md
Expand Up @@ -20,10 +20,12 @@ The minor version will be incremented upon a breaking change and the patch versi
* lang: Add parsing for consts from impl blocks for IDL PDA seeds generation ([#2128](https://github.com/coral-xyz/anchor/pull/2014))
* lang: Account closing reassigns to system program and reallocates ([#2169](https://github.com/coral-xyz/anchor/pull/2169)).
* ts: Add coders for SPL programs ([#2143](https://github.com/coral-xyz/anchor/pull/2143)).
* ts: Add `has_one` relations inference so accounts mapped via has_one relationships no longer need to be provided
* ts: Add ability to set args after setting accounts and retriving pubkyes
* ts: Add `.prepare()` to builder pattern
* ts: Add `has_one` relations inference so accounts mapped via has_one relationships no longer need to be provided ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
* ts: Add ability to set args after setting accounts and retrieving pubkyes ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
* ts: Add `.prepare()` to builder pattern ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
* spl: Add `freeze_delegated_account` and `thaw_delegated_account` wrappers ([#2164](https://github.com/coral-xyz/anchor/pull/2164))
* ts: Add nested PDA inference ([#2194](https://github.com/coral-xyz/anchor/pull/2194))
* ts: Add ability to resolve missing accounts with a custom resolver ([#2194](https://github.com/coral-xyz/anchor/pull/2194))

### Fixes

Expand Down
19 changes: 19 additions & 0 deletions tests/pda-derivation/programs/pda-derivation/src/lib.rs
Expand Up @@ -67,11 +67,30 @@ pub struct InitMyAccount<'info> {
bump,
)]
account: Account<'info, MyAccount>,
nested: Nested<'info>,
#[account(mut)]
payer: Signer<'info>,
system_program: Program<'info, System>,
}

#[derive(Accounts)]
pub struct Nested<'info> {
#[account(
seeds = [
"nested-seed".as_bytes(),
b"test".as_ref(),
MY_SEED.as_ref(),
MY_SEED_STR.as_bytes(),
MY_SEED_U8.to_le_bytes().as_ref(),
&MY_SEED_U32.to_le_bytes(),
&MY_SEED_U64.to_le_bytes(),
],
bump,
)]
/// CHECK: Not needed
account_nested: AccountInfo<'info>,
}

#[account]
pub struct MyAccount {
data: u64,
Expand Down
27 changes: 27 additions & 0 deletions tests/pda-derivation/tests/typescript.spec.ts
Expand Up @@ -65,4 +65,31 @@ describe("typescript", () => {
.data;
expect(actualData.toNumber()).is.equal(1337);
});

it("should allow custom resolvers", async () => {
let called = false;
const customProgram = new Program<PdaDerivation>(
program.idl,
program.programId,
program.provider,
program.coder,
(instruction) => {
if (instruction.name === "initMyAccount") {
return async ({ accounts }) => {
called = true;
return accounts;
};
}
}
);
await customProgram.methods
.initMyAccount(seedA)
.accounts({
base: base.publicKey,
base2: base.publicKey,
})
.pubkeys();

expect(called).is.true;
});
});
78 changes: 57 additions & 21 deletions ts/packages/anchor/src/program/accounts-resolver.ts
Expand Up @@ -17,6 +17,14 @@ import { BorshAccountsCoder } from "src/coder/index.js";

type Accounts = { [name: string]: PublicKey | Accounts };

export type CustomAccountResolver<IDL extends Idl> = (params: {
args: Array<any>;
accounts: Accounts;
provider: Provider;
programId: PublicKey;
idlIx: AllInstructions<IDL>;
}) => Promise<Accounts>;

// Populates a given accounts context with PDAs and common missing accounts.
export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
_args: Array<any>;
Expand All @@ -35,7 +43,8 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
private _provider: Provider,
private _programId: PublicKey,
private _idlIx: AllInstructions<IDL>,
_accountNamespace: AccountNamespace<IDL>
_accountNamespace: AccountNamespace<IDL>,
private _customResolver?: CustomAccountResolver<IDL>
) {
this._args = _args;
this._accountStore = new AccountStore(_provider, _accountNamespace);
Expand Down Expand Up @@ -84,25 +93,22 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
}
}

for (let k = 0; k < this._idlIx.accounts.length; k += 1) {
// Cast is ok because only a non-nested IdlAccount can have a seeds
// cosntraint.
const accountDesc = this._idlIx.accounts[k] as IdlAccount;
const accountDescName = camelCase(accountDesc.name);

// PDA derived from IDL seeds.
if (
accountDesc.pda &&
accountDesc.pda.seeds.length > 0 &&
!this._accounts[accountDescName]
) {
await this.autoPopulatePda(accountDesc);
continue;
}
// Auto populate pdas and relations until we stop finding new accounts
while (
(await this.resolvePdas(this._idlIx.accounts)) +
(await this.resolveRelations(this._idlIx.accounts)) >
0
) {}

if (this._customResolver) {
this._accounts = await this._customResolver({
args: this._args,
accounts: this._accounts,
provider: this._provider,
programId: this._programId,
idlIx: this._idlIx,
});
}

// Auto populate has_one relationships until we stop finding new accounts
while ((await this.resolveRelations(this._idlIx.accounts)) > 0) {}
}

private get(path: string[]): PublicKey | undefined {
Expand Down Expand Up @@ -130,6 +136,36 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
});
}

private async resolvePdas(
accounts: IdlAccountItem[],
path: string[] = []
): Promise<number> {
let found = 0;
for (let k = 0; k < accounts.length; k += 1) {
const accountDesc = accounts[k];
const subAccounts = (accountDesc as IdlAccounts).accounts;
if (subAccounts) {
found += await this.resolvePdas(subAccounts, [
...path,
accountDesc.name,
]);
}

const accountDescCasted: IdlAccount = accountDesc as IdlAccount;
const accountDescName = camelCase(accountDesc.name);
// PDA derived from IDL seeds.
if (
accountDescCasted.pda &&
accountDescCasted.pda.seeds.length > 0 &&
!this.get([...path, accountDescName])
) {
await this.autoPopulatePda(accountDescCasted, path);
found += 1;
}
}
return found;
}

private async resolveRelations(
accounts: IdlAccountItem[],
path: string[] = []
Expand Down Expand Up @@ -172,7 +208,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
return found;
}

private async autoPopulatePda(accountDesc: IdlAccount) {
private async autoPopulatePda(accountDesc: IdlAccount, path: string[] = []) {
if (!accountDesc.pda || !accountDesc.pda.seeds)
throw new Error("Must have seeds");

Expand All @@ -183,7 +219,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
const programId = await this.parseProgramId(accountDesc);
const [pubkey] = await PublicKey.findProgramAddress(seeds, programId);

this._accounts[camelCase(accountDesc.name)] = pubkey;
this.set([...path, camelCase(accountDesc.name)], pubkey);
}

private async parseProgramId(accountDesc: IdlAccount): Promise<PublicKey> {
Expand Down
19 changes: 16 additions & 3 deletions ts/packages/anchor/src/program/index.ts
@@ -1,7 +1,7 @@
import { inflate } from "pako";
import { PublicKey } from "@solana/web3.js";
import Provider, { getProvider } from "../provider.js";
import { Idl, idlAddress, decodeIdlAccount } from "../idl.js";
import { Idl, idlAddress, decodeIdlAccount, IdlInstruction } from "../idl.js";
import { Coder, BorshCoder } from "../coder/index.js";
import NamespaceFactory, {
RpcNamespace,
Expand All @@ -16,6 +16,7 @@ import NamespaceFactory, {
import { utf8 } from "../utils/bytes/index.js";
import { EventManager } from "./event.js";
import { Address, translateAddress } from "./common.js";
import { CustomAccountResolver } from "./accounts-resolver.js";

export * from "./common.js";
export * from "./context.js";
Expand Down Expand Up @@ -263,12 +264,18 @@ export class Program<IDL extends Idl = Idl> {
* @param programId The on-chain address of the program.
* @param provider The network and wallet context to use. If not provided
* then uses [[getProvider]].
* @param getCustomResolver A function that returns a custom account resolver
* for the given instruction. This is useful for resolving
* public keys of missing accounts when building instructions
*/
public constructor(
idl: IDL,
programId: Address,
provider?: Provider,
coder?: Coder
coder?: Coder,
getCustomResolver?: (
instruction: IdlInstruction
) => CustomAccountResolver<IDL> | undefined
) {
programId = translateAddress(programId);

Expand All @@ -293,7 +300,13 @@ export class Program<IDL extends Idl = Idl> {
methods,
state,
views,
] = NamespaceFactory.build(idl, this._coder, programId, provider);
] = NamespaceFactory.build(
idl,
this._coder,
programId,
provider,
getCustomResolver ?? (() => undefined)
);
this.rpc = rpc;
this.instruction = instruction;
this.transaction = transaction;
Expand Down
11 changes: 8 additions & 3 deletions ts/packages/anchor/src/program/namespace/index.ts
Expand Up @@ -2,7 +2,7 @@ import camelCase from "camelcase";
import { PublicKey } from "@solana/web3.js";
import { Coder } from "../../coder/index.js";
import Provider from "../../provider.js";
import { Idl } from "../../idl.js";
import { Idl, IdlInstruction } from "../../idl.js";
import StateFactory, { StateClient } from "./state.js";
import InstructionFactory, { InstructionNamespace } from "./instruction.js";
import TransactionFactory, { TransactionNamespace } from "./transaction.js";
Expand All @@ -12,6 +12,7 @@ import SimulateFactory, { SimulateNamespace } from "./simulate.js";
import { parseIdlErrors } from "../common.js";
import { MethodsBuilderFactory, MethodsNamespace } from "./methods";
import ViewFactory, { ViewNamespace } from "./views";
import { CustomAccountResolver } from "../accounts-resolver.js";

// Re-exports.
export { StateClient } from "./state.js";
Expand All @@ -32,7 +33,10 @@ export default class NamespaceFactory {
idl: IDL,
coder: Coder,
programId: PublicKey,
provider: Provider
provider: Provider,
getCustomResolver?: (
instruction: IdlInstruction
) => CustomAccountResolver<IDL> | undefined
): [
RpcNamespace<IDL>,
InstructionNamespace<IDL>,
Expand Down Expand Up @@ -85,7 +89,8 @@ export default class NamespaceFactory {
rpcItem,
simulateItem,
viewItem,
account
account,
getCustomResolver && getCustomResolver(idlIx)
);
const name = camelCase(idlIx.name);

Expand Down
17 changes: 12 additions & 5 deletions ts/packages/anchor/src/program/namespace/methods.ts
Expand Up @@ -22,7 +22,10 @@ import { SimulateFn } from "./simulate.js";
import { ViewFn } from "./views.js";
import Provider from "../../provider.js";
import { AccountNamespace } from "./account.js";
import { AccountsResolver } from "../accounts-resolver.js";
import {
AccountsResolver,
CustomAccountResolver,
} from "../accounts-resolver.js";
import { Accounts } from "../context.js";

export type MethodsNamespace<
Expand All @@ -40,7 +43,8 @@ export class MethodsBuilderFactory {
rpcFn: RpcFn<IDL>,
simulateFn: SimulateFn<IDL>,
viewFn: ViewFn<IDL> | undefined,
accountNamespace: AccountNamespace<IDL>
accountNamespace: AccountNamespace<IDL>,
customResolver?: CustomAccountResolver<IDL>
): MethodsFn<IDL, I, MethodsBuilder<IDL, I>> {
return (...args) =>
new MethodsBuilder(
Expand All @@ -53,7 +57,8 @@ export class MethodsBuilderFactory {
provider,
programId,
idlIx,
accountNamespace
accountNamespace,
customResolver
);
}
}
Expand All @@ -78,7 +83,8 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
_provider: Provider,
_programId: PublicKey,
_idlIx: AllInstructions<IDL>,
_accountNamespace: AccountNamespace<IDL>
_accountNamespace: AccountNamespace<IDL>,
_customResolver?: CustomAccountResolver<IDL>
) {
this._args = _args;
this._accountsResolver = new AccountsResolver(
Expand All @@ -87,7 +93,8 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
_provider,
_programId,
_idlIx,
_accountNamespace
_accountNamespace,
_customResolver
);
}

Expand Down