Commit

update
ArthurZucker committed Oct 20, 2022
1 parent a6a7d57 commit 2e2be49
Showing 2 changed files with 30 additions and 20 deletions.
@@ -46,9 +46,12 @@
     "token_embedder": "shared",
     "encoder_norm": "final_layer_norm",
     "decoder_norm": "final_layer_norm",
-    "relpos_bias/rel_embedding": "TO_RENAME",
+    "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight",
     "router/router_weights/w/": "router/classifier/",
+    "roer/roer_weights/w/": "router/classifier/",
+
+
 }

 FLAX_MODELS = {
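
For reference, entries in MOE_LAYER_NAME_MAPPING are applied as plain substring replacements over flattened parameter paths (step 2 of rename_keys below). A minimal, self-contained sketch; the input key is a toy example, not one taken from a real checkpoint:

# Toy illustration of how one mapping entry rewrites a flattened key.
MOE_LAYER_NAME_MAPPING = {
    "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight",
    "router/router_weights/w/": "router/classifier/",
}

key = "encoder/relpos_bias/rel_embedding"
for old, new in MOE_LAYER_NAME_MAPPING.items():
    if old in key:
        key = key.replace(old, new)
print(key)  # encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight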
@@ -70,49 +73,61 @@ def rename_keys(s_dict):
     keys = list(s_dict.keys())
     for key in keys:
         layer_to_block_of_layer = r".*/layers_(\d+)"

         new_key = key
         if re.match(layer_to_block_of_layer, key):
-            new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", key)
-            s_dict[new_key] = s_dict.pop(key)
-            key = new_key
+            new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key)
+            # s_dict[new_key] = s_dict.pop(key)

         layer_to_block_of_layer = r"(encoder|decoder)\/"

         if re.match(layer_to_block_of_layer, key):
-            groups = re.match(layer_to_block_of_layer, key).groups()
+            groups = re.match(layer_to_block_of_layer, new_key).groups()
             if groups[0] == "encoder":
-                new_key = re.sub(r"/mlp/", r"/1/mlp/", key)
+                new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key)
                 new_key = re.sub(r"/pre_mlp_layer_norm/", r"/0/layer_norm/", new_key)

             elif groups[0] == "decoder":
-                new_key = re.sub(r"/mlp/", r"/2/mlp/", key)
+                new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key)
                 new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key)

         # 2. Convert other classic mappings
         for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items():
             if old_key in new_key:
                 new_key = new_key.replace(old_key, temp_key)

         print(f"{key} -> {new_key}")
-        s_dict[new_key] = s_dict.pop(key)
-    # 3. Take extra care of the EXPERTS layer
+        s_dict[new_key] = s_dict.pop(key)
+
+
+    # 3. Take extra care of the EXPERTS layer
     for key in list(s_dict.keys()):
         if "expert" in key:
             num_experts = s_dict[key].shape[0]
             expert_weihts = s_dict[key]
             for idx in range(num_experts):
                 s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx]
             s_dict.pop(key)

     return s_dict
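
Step 3 above fans the stacked expert tensor out into one key per expert. A minimal sketch of that fan-out with toy NumPy weights; the key and shapes are illustrative, not taken from a real checkpoint:

import numpy as np

# Toy flattened state dict: one stacked expert tensor (2 experts, 4x4 kernels each).
s_dict = {"encoder/block/1/layer/1/mlp/expert/wi/kernel": np.zeros((2, 4, 4))}

# Same logic as step 3 above: split along the expert axis, one key per expert.
for key in list(s_dict.keys()):
    if "expert" in key:
        num_experts = s_dict[key].shape[0]
        expert_weights = s_dict[key]
        for idx in range(num_experts):
            s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weights[idx]
        s_dict.pop(key)

print(sorted(s_dict))
# ['encoder/block/1/layer/1/mlp/experts/expert_0/wi/kernel',
#  'encoder/block/1/layer/1/mlp/experts/expert_1/wi/kernel']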


 def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorch_dump_path):
     # Initialise PyTorch model

     print(f"Loading flax weights from : {flax_checkpoint_path}")
-    # get_file_from_repo(flax_checkpoint_path, "flax_params.flax")
-    with open(os.path.join(flax_checkpoint_path, "flax_params.flax"), "rb") as f:
+    path = get_file_from_repo(flax_checkpoint_path, "flax_params.flax", use_auth_token="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")
+
+    config_file = get_file_from_repo(flax_checkpoint_path, "config.json", use_auth_token="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")
+    with open(os.path.join(path), "rb") as f:
         params = msgpack_restore(f.read())

-    config = SwitchTransformersConfig.from_pretrained(config_file)
+    config = SwitchTransformersConfig.from_pretrained(config_file, relative_attention_num_buckets=12)
     pt_model = SwitchTransformersForConditionalGeneration(config)

     params = flatten_dict(params, sep="/")
     params = rename_keys(params)
+    breakpoint()
     params = unflatten_dict(params, sep="/")

     load_flax_weights_in_pytorch_model(pt_model, params)
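
A hedged usage sketch of the converter above. The repo id and output directory are placeholders, not values named in this commit; note that the function as written re-resolves config.json from flax_checkpoint_path, so the config_file argument is effectively overridden:

# Hypothetical invocation; "google/switch-base-8" and the dump path are placeholders.
convert_flax_checkpoint_to_pytorch(
    flax_checkpoint_path="google/switch-base-8",  # repo expected to contain flax_params.flax
    config_file="google/switch-base-8",           # overridden inside the function (see above)
    pytorch_dump_path="./switch-base-8-pt",       # presumably where the PyTorch weights are saved
)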
@@ -729,13 +729,8 @@ def __init__(self, config, embed_tokens=None):
         self.block = nn.ModuleList()
         for i in range(config.num_layers):

-            # is_sparse = (i % sparse_step == 1) if sparse_step > 0 else False
-            if self.is_decoder:
-                even = 1
-            else:
-                even = 0
-
-            is_sparse = (i % sparse_step == even) if sparse_step > 0 else False
+            is_sparse = (i % sparse_step == 1) if sparse_step > 0 else False

             self.block.append(
                 SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse)
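
With the committed line, encoder and decoder now use the same rule: layer i is a sparse (MoE) block whenever i % sparse_step == 1, with no decoder offset. A small sketch of which indices that selects; the num_layers and sparse_step values are illustrative, not from a real config:

# Illustrative values; a real model would read these from SwitchTransformersConfig.
num_layers, sparse_step = 12, 2

sparse_layers = [i for i in range(num_layers) if sparse_step > 0 and i % sparse_step == 1]
print(sparse_layers)  # [1, 3, 5, 7, 9, 11] -> every second block is sparse, starting at index 1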