
Poor performance for reading Numpy #460

Open

ZHUI opened this issue Apr 2, 2024 · 9 comments

Comments

@ZHUI

ZHUI commented Apr 2, 2024

System Info

I tested a 7B model with fp32 weights, stored in numpy format. I found that, compared with pickle, loading is more than 50% slower!

-rw-r--r-- 1 root root 3.6G Apr  2 16:26 checkpoint-12/model-00001-of-00008.pdparams
-rw-r--r-- 1 root root 3.6G Apr  2 16:32 checkpoint-12/model-00001-of-00008.safetensors
-rw-r--r-- 1 root root 3.1G Apr  2 16:26 checkpoint-12/model-00002-of-00008.pdparams
-rw-r--r-- 1 root root 3.1G Apr  2 16:32 checkpoint-12/model-00002-of-00008.safetensors
-rw-r--r-- 1 root root 3.1G Apr  2 16:27 checkpoint-12/model-00003-of-00008.pdparams
-rw-r--r-- 1 root root 3.1G Apr  2 16:32 checkpoint-12/model-00003-of-00008.safetensors
-rw-r--r-- 1 root root 3.1G Apr  2 16:27 checkpoint-12/model-00004-of-00008.pdparams
-rw-r--r-- 1 root root 3.1G Apr  2 16:32 checkpoint-12/model-00004-of-00008.safetensors
-rw-r--r-- 1 root root 3.1G Apr  2 16:27 checkpoint-12/model-00005-of-00008.pdparams
-rw-r--r-- 1 root root 3.1G Apr  2 16:32 checkpoint-12/model-00005-of-00008.safetensors
-rw-r--r-- 1 root root 3.1G Apr  2 16:27 checkpoint-12/model-00006-of-00008.pdparams
-rw-r--r-- 1 root root 3.1G Apr  2 16:33 checkpoint-12/model-00006-of-00008.safetensors
-rw-r--r-- 1 root root 3.1G Apr  2 16:28 checkpoint-12/model-00007-of-00008.pdparams
-rw-r--r-- 1 root root 3.1G Apr  2 16:33 checkpoint-12/model-00007-of-00008.safetensors
-rw-r--r-- 1 root root 3.6G Apr  2 16:28 checkpoint-12/model-00008-of-00008.pdparams
-rw-r--r-- 1 root root 3.6G Apr  2 16:33 checkpoint-12/model-00008-of-00008.safetensors
-rw-r--r-- 1 root root  25K Apr  2 16:10 checkpoint-12/model.safetensors.index.json

Time usage:

sf model-00001-of-00008.safetensors 3.4121220111846924
pk model-00001-of-00008.safetensors 2.195117473602295
sf model-00002-of-00008.safetensors 3.004627227783203
pk model-00002-of-00008.safetensors 1.9284331798553467
sf model-00003-of-00008.safetensors 2.887206792831421
pk model-00003-of-00008.safetensors 1.8887608051300049
sf model-00004-of-00008.safetensors 2.8507916927337646
pk model-00004-of-00008.safetensors 2.080396890640259
sf model-00005-of-00008.safetensors 2.830484390258789
pk model-00005-of-00008.safetensors 1.8540270328521729
sf model-00006-of-00008.safetensors 2.8113412857055664
pk model-00006-of-00008.safetensors 1.916459321975708
sf model-00007-of-00008.safetensors 2.8474719524383545
pk model-00007-of-00008.safetensors 1.835508108139038
sf model-00008-of-00008.safetensors 3.3513264656066895
pk model-00008-of-00008.safetensors 2.2008140087127686

Information

  • The official example scripts
  • My own modified scripts

Reproduction

Python 3.10
safetensors==0.4.2

Expected behavior

$ cat test_safetensor.py

import pickle
import time

import numpy as np
from safetensors.numpy import load_file, save_file

# 256 * 1024 * 1024 float64 elements -> a 2 GB tensor
w = np.random.randn(256, 1024, 1024)
state_dict = {"weight": w}

fp = "test_file.safetensors"
t1 = time.time()
save_file(state_dict, fp)
print("sf save:", time.time() - t1)

t1 = time.time()
state = load_file(fp)
print("sf load:", time.time() - t1)

fp = fp.replace(".safetensors", ".pickle")
t1 = time.time()
with open(fp, "wb") as f:
    pickle.dump(state_dict, f)
print("pickle save:", time.time() - t1)

t1 = time.time()
with open(fp, "rb") as f:
    state = pickle.load(f)
print("pickle load:", time.time() - t1)

results:

sf save: 2.818842887878418
sf load: 1.8608193397521973
pickle save: 2.3684301376342773
pickle load: 1.004188060760498

@ZHUI
Author

ZHUI commented Apr 2, 2024

@Narsil please help

@ZHUI
Author

ZHUI commented Apr 2, 2024

sf save: 2.6092689037323
sf load: 1.7551813125610352
sf load2: 2.1516408920288086
pickle save: 2.2464730739593506
pickle load: 0.9010257720947266

The load API is even slower than load_file:

import pickle
import time

import numpy as np
from safetensors.numpy import load, load_file, save_file

# 256 * 1024 * 1024 float64 elements -> a 2 GB tensor
w = np.empty([256, 1024, 1024])
state_dict = {"weight": w}

fp = "test_file.safetensors"
t1 = time.time()
save_file(state_dict, fp)
print("sf save:", time.time() - t1)

t1 = time.time()
state = load_file(fp)
print("sf load:", time.time() - t1)

# load() takes bytes already read into memory
t1 = time.time()
with open(fp, "rb") as f:
    data = f.read()
loaded = load(data)
print("sf load2:", time.time() - t1)

fp = fp.replace(".safetensors", ".pickle")
t1 = time.time()
with open(fp, "wb") as f:
    pickle.dump(state_dict, f)
print("pickle save:", time.time() - t1)

t1 = time.time()
with open(fp, "rb") as f:
    state = pickle.load(f)
print("pickle load:", time.time() - t1)
    

@ZHUI
Author

ZHUI commented Apr 2, 2024

@mishig25 can you give some help?

@ZHUI
Author

ZHUI commented Apr 3, 2024

@LysandreJik can you give some help?

@ZHUI
Author

ZHUI commented Apr 10, 2024

For the load_file API

The core problem is that memcpy from mmap'd memory is very slow. See:

https://stackoverflow.com/questions/52845387/improving-mmap-memcpy-file-read-performance

In my case, open(filename); f.read() reaches about 2 GB/s, while memcpy(mmap(filename)) reaches about 1.3 GB/s, which is much slower than reading the file.

Can we have a faster way to support this?
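The contrast above can be sketched with a small stand-alone benchmark (this is not the safetensors internals): it writes a throwaway scratch file, then times a plain buffered read() against a full copy out of an mmap'd view of the same file. The file name and the 64 MB size are arbitrary choices for illustration; absolute numbers depend heavily on the page cache and the disk.

```python
import mmap
import os
import time

fp = "readpath_test.bin"  # hypothetical scratch file, removed at the end
size = 64 * 1024 * 1024
with open(fp, "wb") as f:
    f.write(os.urandom(size))

# Path 1: buffered read() -- file -> user buffer via the kernel read path.
t1 = time.time()
with open(fp, "rb") as f:
    data = f.read()
print("read()    : %.2f GB/s" % (size / (time.time() - t1) / 1e9))

# Path 2: mmap the file, then copy the whole mapping out (one big memcpy
# that also pays a page fault the first time each page is touched).
t1 = time.time()
with open(fp, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
    copy = bytes(m)
print("mmap+copy : %.2f GB/s" % (size / (time.time() - t1) / 1e9))

os.remove(fp)
```

Both timings here mostly hit the page cache, since the file was just written; dropping caches first gives numbers closer to the cold-start checkpoint-loading case.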

For the load API

There is an additional MEM→MEM copy!

f.read() already copies the file into memory, but the load API uses PyByteArray::new, which causes an additional MEM→MEM copy:

let pydata: PyObject = PyByteArray::new(py, tensor.data()).into();

Suggestions

  • Can we support loading a file without the additional MEM→MEM copy?
  • If memcpy + mmap is inevitable, can we have a substitution?
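One possible direction for the first suggestion, sketched as a minimal stand-alone loader (this is not the library's implementation): read the file once, then expose each tensor as a numpy view into that single buffer via np.frombuffer, so no second MEM→MEM copy occurs. The header layout follows the published safetensors format (an 8-byte little-endian header size, then a JSON table of dtype/shape/data_offsets); the dtype table here is deliberately partial, and load_views is a hypothetical name.

```python
import json
import struct

import numpy as np

# Partial map from safetensors dtype codes to numpy dtypes.
_DTYPES = {"F64": np.float64, "F32": np.float32, "F16": np.float16,
           "I64": np.int64, "I32": np.int32, "U8": np.uint8}

def load_views(path):
    """Load a .safetensors file as zero-copy numpy views over one buffer."""
    with open(path, "rb") as f:
        buf = f.read()                      # the single file -> MEM copy
    (hlen,) = struct.unpack("<Q", buf[:8])  # header length, little-endian u64
    header = json.loads(buf[8:8 + hlen])
    base = 8 + hlen                         # tensor data starts here
    out = {}
    for name, meta in header.items():
        if name == "__metadata__":
            continue
        start, end = meta["data_offsets"]
        dt = np.dtype(_DTYPES[meta["dtype"]])
        view = np.frombuffer(buf, dtype=dt, offset=base + start,
                             count=(end - start) // dt.itemsize)
        out[name] = view.reshape(meta["shape"])  # zero-copy view into buf
    return out
```

The returned arrays are read-only views over the single buffer; writing to them requires an explicit copy.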

@Narsil
Collaborator

Narsil commented Apr 15, 2024

In my case, open(filename); f.read() reaches about 2 GB/s, while memcpy(mmap(filename)) reaches about 1.3 GB/s, which is much slower than reading the file.

Something is wrong in your system. What are you using? Windows + WSL is a usual culprit for very poor mmap support/performance. HDDs are also a big source of it, although they are much less commonplace these days.

In order to make things "fast" we could always skip a few things, but that makes things unsafe (necessarily so, since Python doesn't have ownership semantics).

PyO3 0.21 could enable something a bit faster, though, since we could skip the Rust-owned version of the tensors.

@ZHUI
Author

ZHUI commented Apr 15, 2024

see PyO3/pyo3#4058 (comment)

https://stackoverflow.com/questions/52845387/improving-mmap-memcpy-file-read-performance

My OS is Ubuntu 18.04. You can run a test using the scripts above.

There are some suggestions for using madvise(..., MADV_SEQUENTIAL):
PyO3/pyo3#4058 (comment)
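The madvise suggestion can be sketched in pure Python (mmap.madvise and the MADV_* constants exist on Python 3.8+ on Linux); read_mapped_sequential is a hypothetical helper that hints the kernel to read ahead aggressively before the big copy:

```python
import mmap

def read_mapped_sequential(path):
    """Map a file, hint sequential access, then copy it out in one go."""
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
            if hasattr(mmap, "MADV_SEQUENTIAL"):  # not available on all OSes
                m.madvise(mmap.MADV_SEQUENTIAL)   # enable aggressive read-ahead
            return bytes(m)  # the memcpy this thread is benchmarking
```

Whether this closes the gap with plain read() depends on the kernel and the storage; it only changes the read-ahead behavior, not the copy itself.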


This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label May 16, 2024
@ZHUI
Author

ZHUI commented May 16, 2024

Still a big problem.

@github-actions github-actions bot removed the Stale label May 17, 2024