diff --git a/internal/crypto/cipher_decrypt_test.go b/internal/crypto/cipher_decrypt_test.go index 7058110280..73f129146a 100644 --- a/internal/crypto/cipher_decrypt_test.go +++ b/internal/crypto/cipher_decrypt_test.go @@ -11,13 +11,9 @@ import ( "github.com/gotd/td/internal/testutil" ) -type Zero struct{} - -func (Zero) Read(p []byte) (n int, err error) { return len(p), nil } - func TestDecrypt(t *testing.T) { // Test vector from grammers. - c := NewClientCipher(Zero{}) + c := NewClientCipher(testutil.ZeroRand{}) var msg EncryptedMessage b := &bin.Buffer{Buf: []byte{ 122, 113, 131, 194, 193, 14, 79, 77, 249, 69, 250, 154, 154, 189, 53, 231, 195, 132, @@ -57,8 +53,8 @@ func TestCipher_Decrypt(t *testing.T) { t.Fatal(err) } - c := NewClientCipher(Zero{}) - s := NewServerCipher(Zero{}) + c := NewClientCipher(testutil.ZeroRand{}) + s := NewServerCipher(testutil.ZeroRand{}) tests := []struct { name string data []byte diff --git a/internal/crypto/cipher_encrypt_test.go b/internal/crypto/cipher_encrypt_test.go index a4c4edc3ec..46eb763798 100644 --- a/internal/crypto/cipher_encrypt_test.go +++ b/internal/crypto/cipher_encrypt_test.go @@ -6,10 +6,11 @@ import ( "github.com/stretchr/testify/require" "github.com/gotd/td/bin" + "github.com/gotd/td/internal/testutil" ) func TestEncrypt(t *testing.T) { - c := NewClientCipher(Zero{}) + c := NewClientCipher(testutil.ZeroRand{}) var authKey Key for i := 0; i < 256; i++ { diff --git a/internal/crypto/rsa_pad_test.go b/internal/crypto/rsa_pad_test.go index b1a697703a..3b64573b5d 100644 --- a/internal/crypto/rsa_pad_test.go +++ b/internal/crypto/rsa_pad_test.go @@ -28,7 +28,7 @@ t6N/byY9Nw9p21Og3AoXSL2q/2IJ1WRUhebgAdGVMlV1fkuOQoEzR7EdpqtQD9Cs a.NoError(err) data := bytes.Repeat([]byte{'a'}, 144) - encrypted, err := RSAPad(data, keys[0], Zero{}) + encrypted, err := RSAPad(data, keys[0], testutil.ZeroRand{}) a.NoError(err) a.Len(encrypted, 256) diff --git a/internal/crypto/srp/hash.go b/internal/crypto/srp/hash.go new file mode 100644 index 0000000000..4372e6cb0d --- /dev/null +++ b/internal/crypto/srp/hash.go @@ -0,0 +1,118 @@ +package srp + +import ( + "crypto/sha256" + "crypto/sha512" + "math/big" + + "github.com/go-faster/errors" + "golang.org/x/crypto/pbkdf2" +) + +// Hash computes user password hash using parameters from server. +// +// See https://core.telegram.org/api/srp#checking-the-password-with-srp. +func (s SRP) Hash(password, srpB, random []byte, i Input) (Answer, error) { + p := s.bigFromBytes(i.P) + if err := checkInput(i.G, p); err != nil { + return Answer{}, errors.Wrap(err, "validate algo") + } + + g := big.NewInt(int64(i.G)) + // It is safe to use FillBytes directly because we know that 64-bit G always smaller than + // 256-bit destination array. + var gBytes [256]byte + g.FillBytes(gBytes[:]) + + // random 2048-bit number a + a := s.bigFromBytes(random) + + // `g_a = pow(g, a) mod p` + ga, ok := s.pad256FromBig(s.bigExp(g, a, p)) + if !ok { + return Answer{}, errors.New("g_a is too big") + } + + // `g_b = srp_B` + gb := s.pad256(srpB) + + // `u = H(g_a | g_b)` + u := s.bigFromBytes(s.hash(ga[:], gb[:])) + + // `x = PH2(password, salt1, salt2)` + // `v = pow(g, x) mod p` + x, v := s.computeXV(password, i.Salt1, i.Salt2, g, p) + + // `k = (k * v) mod p` + k := s.bigFromBytes(s.hash(i.P, gBytes[:])) + + // `k_v = (k * v) % p` + kv := k.Mul(k, v).Mod(k, p) + + // `t = (g_b - k_v) % p` + t := s.bigFromBytes(srpB) + if t.Sub(t, kv).Cmp(big.NewInt(0)) == -1 { + t.Add(t, p) + } + + // `s_a = pow(t, a + u * x) mod p` + sa, ok := s.pad256FromBig(s.bigExp(t, u.Mul(u, x).Add(u, a), p)) + if !ok { + return Answer{}, errors.New("s_a is too big") + } + + // `k_a = H(s_a)` + ka := sha256.Sum256(sa[:]) + + // `M1 = H(H(p) xor H(g) | H2(salt1) | H2(salt2) | g_a | g_b | k_a)` + xorHpHg := xor32(sha256.Sum256(i.P), sha256.Sum256(gBytes[:])) + M1 := s.hash( + xorHpHg[:], + s.hash(i.Salt1), + s.hash(i.Salt2), + ga[:], + gb[:], + ka[:], + ) + + return Answer{ + A: ga[:], + M1: M1, + }, nil +} + +// The main hashing function H is sha256: +// +// H(data) := sha256(data) +func (s SRP) hash(data ...[]byte) []byte { + h := sha256.New() + for i := range data { + h.Write(data[i]) + } + return h.Sum(nil) +} + +// The salting hashing function SH is defined as follows: +// +// SH(data, salt) := H(salt | data | salt) +func (s SRP) saltHash(data, salt []byte) []byte { + return s.hash(salt, data, salt) +} + +// The primary password hashing function is defined as follows: +// +// PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2) +func (s SRP) primary(password, salt1, salt2 []byte) []byte { + return s.saltHash(s.saltHash(password, salt1), salt2) +} + +// The secondary password hashing function is defined as follows: +// +// PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2) +func (s SRP) secondary(password, salt1, salt2 []byte) []byte { + return s.saltHash(s.pbkdf2(s.primary(password, salt1, salt2), salt1, 100000), salt2) +} + +func (s SRP) pbkdf2(ph1, salt1 []byte, n int) []byte { + return pbkdf2.Key(ph1, salt1, n, 64, sha512.New) +} diff --git a/internal/crypto/srp/new_hash.go b/internal/crypto/srp/new_hash.go new file mode 100644 index 0000000000..b5bb44a9bc --- /dev/null +++ b/internal/crypto/srp/new_hash.go @@ -0,0 +1,54 @@ +package srp + +import ( + "io" + "math/big" + + "github.com/go-faster/errors" +) + +// computeXV computes following numbers +// +// `x = PH2(password, salt1, salt2)` +// `v = pow(g, x) mod p` +// +// TDLib uses terms `clientSalt` for `salt1` and `serverSalt` for `salt2`. +func (s SRP) computeXV(password, clientSalt, serverSalt []byte, g, p *big.Int) (x, v *big.Int) { + // `x = PH2(password, salt1, salt2)` + x = new(big.Int).SetBytes(s.secondary(password, clientSalt, serverSalt)) + // `v = pow(g, x) mod p` + v = new(big.Int).Exp(g, x, p) + return x, v +} + +// NewHash computes new user password hash using parameters from server. +// +// See https://core.telegram.org/api/srp#setting-a-new-2fa-password. +// +// TDLib implementation: +// See https://github.com/tdlib/td/blob/fa8feefed70d64271945e9d5fd010b957d93c8cd/td/telegram/PasswordManager.cpp#L57. +// +// TDesktop implementation: +// See https://github.com/telegramdesktop/tdesktop/blob/v3.4.8/Telegram/SourceFiles/core/core_cloud_password.cpp#L68. +func (s SRP) NewHash(password []byte, i Input) (hash, newSalt []byte, _ error) { + // Generate a new new_password_hash using the KDF algorithm specified in the new_settings, + // just append 32 sufficiently random bytes to the salt1, first. Proceed as for checking passwords with SRP, + // just stop at the generation of the v parameter, and use it as new_password_hash: + p := new(big.Int).SetBytes(i.P) + if err := checkInput(i.G, p); err != nil { + return nil, nil, errors.Wrap(err, "validate algo") + } + + // Make a copy. + newClientSalt := append([]byte(nil), i.Salt1...) + newClientSalt = append(newClientSalt, make([]byte, 32)...) + // ... append 32 sufficiently random bytes to the salt1 ... + if _, err := io.ReadFull(s.random, newClientSalt[len(newClientSalt)-32:]); err != nil { + return nil, nil, err + } + + _, v := s.computeXV(password, newClientSalt, i.Salt2, big.NewInt(int64(i.G)), p) + // As usual in big endian form, padded to 2048 bits. + padded, _ := s.pad256FromBig(v) + return padded[:], newClientSalt, nil +} diff --git a/internal/crypto/srp/new_hash_test.go b/internal/crypto/srp/new_hash_test.go new file mode 100644 index 0000000000..bcf7be97c6 --- /dev/null +++ b/internal/crypto/srp/new_hash_test.go @@ -0,0 +1,72 @@ +package srp + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gotd/td/internal/testutil" +) + +func TestSRP_NewHash(t *testing.T) { + password := []uint8{ + 110, 101, 119, 80, 97, 115, 115, 119, 111, 114, 100, + } + i := Input{ + Salt1: []uint8{ + 230, 200, 149, 125, 223, 152, 141, 72, + }, + Salt2: []uint8{ + 159, 99, 68, 130, 43, 9, 108, 255, 135, 239, 164, 38, 245, 120, 87, 182, + }, + G: 3, + P: []uint8{ + 199, 28, 174, 185, 198, 177, 201, 4, 142, 108, 82, 47, 112, 241, 63, 115, + 152, 13, 64, 35, 142, 62, 33, 193, 73, 52, 208, 55, 86, 61, 147, 15, + 72, 25, 138, 10, 167, 193, 64, 88, 34, 148, 147, 210, 37, 48, 244, 219, + 250, 51, 111, 110, 10, 201, 37, 19, 149, 67, 174, 212, 76, 206, 124, 55, + 32, 253, 81, 246, 148, 88, 112, 90, 198, 140, 212, 254, 107, 107, 19, 171, + 220, 151, 70, 81, 41, 105, 50, 132, 84, 241, 143, 175, 140, 89, 95, 100, + 36, 119, 254, 150, 187, 42, 148, 29, 91, 205, 29, 74, 200, 204, 73, 136, + 7, 8, 250, 155, 55, 142, 60, 79, 58, 144, 96, 190, 230, 124, 249, 164, + 164, 166, 149, 129, 16, 81, 144, 126, 22, 39, 83, 181, 107, 15, 107, 65, + 13, 186, 116, 216, 168, 75, 42, 20, 179, 20, 78, 14, 241, 40, 71, 84, + 253, 23, 237, 149, 13, 89, 101, 180, 185, 221, 70, 88, 45, 177, 23, 141, + 22, 156, 107, 196, 101, 176, 214, 255, 156, 163, 146, 143, 239, 91, 154, 228, + 228, 24, 252, 21, 232, 62, 190, 160, 248, 127, 169, 255, 94, 237, 112, 5, + 13, 237, 40, 73, 244, 123, 249, 89, 217, 86, 133, 12, 233, 41, 133, 31, + 13, 129, 21, 246, 53, 177, 5, 238, 46, 78, 21, 208, 75, 36, 84, 191, + 111, 79, 173, 240, 52, 177, 4, 3, 17, 156, 216, 227, 185, 47, 204, 91, + }, + } + expectedHash := []uint8{ + 24, 106, 193, 141, 204, 87, 144, 191, 107, 186, 33, 189, 149, 141, 55, 94, + 229, 72, 26, 240, 2, 158, 155, 215, 169, 198, 142, 201, 38, 189, 81, 150, + 216, 31, 140, 216, 181, 142, 224, 108, 138, 16, 173, 234, 204, 127, 86, 232, + 25, 255, 81, 72, 37, 222, 177, 91, 31, 173, 236, 106, 174, 23, 162, 68, + 203, 35, 72, 141, 23, 52, 156, 212, 38, 26, 139, 164, 218, 123, 156, 44, + 229, 196, 0, 20, 221, 158, 54, 39, 80, 172, 243, 172, 137, 184, 184, 245, + 198, 24, 240, 182, 165, 114, 195, 143, 255, 58, 85, 77, 136, 24, 160, 184, + 231, 182, 1, 94, 24, 54, 18, 138, 30, 78, 45, 92, 249, 151, 29, 29, + 208, 72, 170, 24, 29, 134, 17, 82, 234, 231, 21, 83, 150, 38, 128, 99, + 35, 135, 184, 154, 193, 134, 95, 222, 215, 200, 195, 218, 166, 78, 211, 141, + 194, 80, 54, 102, 63, 160, 207, 119, 72, 197, 46, 161, 156, 24, 126, 112, + 167, 82, 168, 5, 62, 64, 157, 72, 148, 33, 138, 66, 147, 147, 208, 51, + 130, 228, 30, 80, 183, 65, 91, 59, 138, 208, 146, 253, 7, 144, 248, 141, + 137, 78, 132, 220, 167, 143, 71, 244, 33, 137, 55, 215, 170, 153, 216, 140, + 135, 192, 155, 203, 141, 168, 144, 229, 53, 2, 102, 35, 206, 166, 252, 139, + 61, 37, 219, 112, 203, 66, 170, 164, 131, 35, 146, 125, 135, 168, 252, 241, + } + expectedNewSalt := []uint8{ + 230, 200, 149, 125, 223, 152, 141, 72, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + } + + a := require.New(t) + s := NewSRP(testutil.ZeroRand{}) + hash, newSalt, err := s.NewHash(password, i) + a.NoError(err) + a.Equal(expectedHash, hash) + a.Equal(expectedNewSalt, newSalt) +} diff --git a/internal/crypto/srp/pad.go b/internal/crypto/srp/pad.go new file mode 100644 index 0000000000..a0f29f5b6b --- /dev/null +++ b/internal/crypto/srp/pad.go @@ -0,0 +1,23 @@ +package srp + +import ( + "math/big" + + "github.com/gotd/td/internal/crypto" +) + +func (s SRP) pad256FromBig(i *big.Int) (b [256]byte, r bool) { + r = crypto.FillBytes(i, b[:]) + return b, r +} + +func (s SRP) pad256(b []byte) [256]byte { + if len(b) >= 256 { + return *(*[256]byte)(b[len(b)-256:]) + } + + var tmp [256]byte + copy(tmp[256-len(b):], b) + + return tmp +} diff --git a/internal/crypto/srp/srp.go b/internal/crypto/srp/srp.go index 0e48018ae0..a13d4b2fdd 100644 --- a/internal/crypto/srp/srp.go +++ b/internal/crypto/srp/srp.go @@ -2,13 +2,10 @@ package srp import ( - "crypto/sha512" + "crypto/sha256" "io" "math/big" - "github.com/go-faster/errors" - "golang.org/x/crypto/pbkdf2" - "github.com/go-faster/xor" "github.com/gotd/td/internal/crypto" @@ -48,100 +45,8 @@ type Answer struct { M1 []byte } -// Hash computes user password hash using parameters from server. -// Parameters -// -// See https://core.telegram.org/api/srp#checking-the-password-with-srp. -func (s SRP) Hash(password, srpB, random []byte, i Input) (Answer, error) { - if err := s.checkGP(i.G, i.P); err != nil { - return Answer{}, errors.Wrap(err, "validate algo") - } - - p := s.bigFromBytes(i.P) - g := big.NewInt(int64(i.G)) - gBytes, ok := s.paddedFromBig(g) - if !ok { - return Answer{}, errors.Errorf("invalid g (%d)", i.G) - } - - // random 2048-bit number a - a := s.bigFromBytes(random) - - // `g_a = pow(g, a) mod p` - ga, ok := s.paddedFromBig(s.bigExp(g, a, p)) - if !ok { - return Answer{}, errors.New("g_a is too big") - } - - // `g_b = srp_B` - gb := s.pad256(srpB) - - // `u = H(g_a | g_b)` - u := s.bigFromBytes(s.hash(ga, gb)) - - // `x = PH2(password, salt1, salt2)` - x := s.bigFromBytes(s.secondary(password, i.Salt1, i.Salt2)) - - // `v = pow(g, x) mod p` - v := s.bigExp(g, x, p) - - // `k = (k * v) mod p` - k := s.bigFromBytes(s.hash(i.P, gBytes)) - - // `k_v = (k * v) % p` - kv := k.Mul(k, v).Mod(k, p) - - // `t = (g_b - k_v) % p` - t := s.bigFromBytes(srpB) - if t.Sub(t, kv).Cmp(big.NewInt(0)) == -1 { - t.Add(t, p) - } - - // `s_a = pow(t, a + u * x) mod p` - sa, ok := s.paddedFromBig(s.bigExp(t, u.Mul(u, x).Add(u, a), p)) - if !ok { - return Answer{}, errors.New("s_a is too big") - } - - // `k_a = H(s_a)` - ka := s.hash(sa) - - // `M1 = H(H(p) xor H(g) | H2(salt1) | H2(salt2) | g_a | g_b | k_a)` - M1 := s.hash( - s.bytesXor(s.hash(i.P), s.hash(gBytes)), - s.hash(i.Salt1), - s.hash(i.Salt2), - ga, - gb, - ka, - ) - - return Answer{ - A: ga, - M1: M1, - }, nil -} - -func (s SRP) paddedFromBig(i *big.Int) ([]byte, bool) { - b := make([]byte, 256) - r := crypto.FillBytes(i, b) - return b, r -} - -func (s SRP) pad256(b []byte) []byte { - if len(b) >= 256 { - return b[len(b)-256:] - } - - var tmp [256]byte - copy(tmp[256-len(b):], b) - - return tmp[:] -} - -func (s SRP) bytesXor(a, b []byte) []byte { - res := make([]byte, len(a)) - xor.Bytes(res, a, b) +func xor32(a, b [sha256.Size]byte) (res [sha256.Size]byte) { + xor.Bytes(res[:], a[:], b[:]) return res } @@ -153,39 +58,6 @@ func (s SRP) bigExp(x, y, m *big.Int) *big.Int { return new(big.Int).Exp(x, y, m) } -// The main hashing function H is sha256: -// -// H(data) := sha256(data) -func (s SRP) hash(data ...[]byte) []byte { - return crypto.SHA256(data...) -} - -// The salting hashing function SH is defined as follows: -// -// SH(data, salt) := H(salt | data | salt) -func (s SRP) saltHash(data, salt []byte) []byte { - return s.hash(salt, data, salt) -} - -// The primary password hashing function is defined as follows: -// -// PH1(password, salt1, salt2) := SH(SH(password, salt1), salt2) -func (s SRP) primary(password, salt1, salt2 []byte) []byte { - return s.saltHash(s.saltHash(password, salt1), salt2) -} - -// The secondary password hashing function is defined as follows: -// -// PH2(password, salt1, salt2) := SH(pbkdf2(sha512, PH1(password, salt1, salt2), salt1, 100000), salt2) -func (s SRP) secondary(password, salt1, salt2 []byte) []byte { - return s.saltHash(s.pbkdf2(s.primary(password, salt1, salt2), salt1, 100000), salt2) -} - -func (s SRP) pbkdf2(ph1, salt1 []byte, n int) []byte { - return pbkdf2.Key(ph1, salt1, n, 64, sha512.New) -} - -func (s SRP) checkGP(g int, pBytes []byte) error { - p := s.bigFromBytes(pBytes) +func checkInput(g int, p *big.Int) error { return crypto.CheckDH(g, p) } diff --git a/internal/crypto/srp/srp_test.go b/internal/crypto/srp/srp_test.go index 6431e6df1c..0b4e82b794 100644 --- a/internal/crypto/srp/srp_test.go +++ b/internal/crypto/srp/srp_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "math/big" "testing" "github.com/stretchr/testify/assert" @@ -61,7 +62,7 @@ func TestSRP(t *testing.T) { } for i := range tests { tcase := tests[i] - t.Run(fmt.Sprintf("#%v", i), func(t *testing.T) { + t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { random := setByte(256, 1) srp := NewSRP(rand.Reader) got, err := srp.Hash(tcase.args.password, tcase.args.srpB, random, tcase.args.mp) @@ -95,15 +96,13 @@ const pBase64 = "xxyuucaxyQSObFIvcPE_c5gNQCOOPiHBSTTQN1Y9kw9IGYoKp8FAWCKUk9IlMPT "X7ZUNWWW0ud1GWC2xF40WnGvEZbDW_5yjko_vW5rk5Bj8Feg-vqD4f6n_Xu1wBQ3tKEn0e_lZ2VaFDOkphR8NgRX2NbEF7i5OFdBLJFS_b0-" + "t8DSxBAMRnNjjuS_MWw==" -func TestSRP_checkPG(t *testing.T) { - s := NewSRP(nil) - +func Test_checkInput(t *testing.T) { p, err := base64.URLEncoding.DecodeString(pBase64) if err != nil { t.Fatal("no err expected", err) } - err = s.checkGP(3, p) + err = checkInput(3, big.NewInt(0).SetBytes(p)) if err != nil { t.Fatal("no err expected", err) } diff --git a/internal/mtproto/handle_message_fuzz.go b/internal/mtproto/handle_message_fuzz.go index 3adc479a4b..4cadab31bc 100644 --- a/internal/mtproto/handle_message_fuzz.go +++ b/internal/mtproto/handle_message_fuzz.go @@ -13,14 +13,11 @@ import ( "github.com/gotd/td/internal/mt" "github.com/gotd/td/internal/proto" "github.com/gotd/td/internal/rpc" + "github.com/gotd/td/internal/testutil" "github.com/gotd/td/internal/tmap" "github.com/gotd/td/tg" ) -type Zero struct{} - -func (Zero) Read(p []byte) (n int, err error) { return len(p), nil } - type fuzzHandler struct { types *tmap.Constructor } @@ -67,7 +64,7 @@ func init() { ), } c := &Conn{ - rand: Zero{}, + rand: testutil.ZeroRand{}, rpc: rpc.New(rpc.NopSend, rpc.Options{}), log: zap.NewNop(), messageID: proto.NewMessageIDGen(time.Now), diff --git a/internal/mtproto/zero_test.go b/internal/mtproto/zero_test.go index c8759408d8..bbcdc4ff20 100644 --- a/internal/mtproto/zero_test.go +++ b/internal/mtproto/zero_test.go @@ -3,6 +3,6 @@ package mtproto -type Zero struct{} +import "github.com/gotd/td/internal/testutil" -func (Zero) Read(p []byte) (n int, err error) { return len(p), nil } +type Zero = testutil.ZeroRand diff --git a/internal/testutil/rand.go b/internal/testutil/rand.go index 29512d7ba5..685765726c 100644 --- a/internal/testutil/rand.go +++ b/internal/testutil/rand.go @@ -5,6 +5,12 @@ import ( "math/rand" ) +// ZeroRand is zero random source. +type ZeroRand struct{} + +// Read implements io.Reader. +func (ZeroRand) Read(p []byte) (n int, err error) { return len(p), nil } + func randSeed(data []byte) int64 { if len(data) == 0 { return 0 diff --git a/tdjson/encoder_test.go b/tdjson/encoder_test.go index 37ec267376..e4dd550695 100644 --- a/tdjson/encoder_test.go +++ b/tdjson/encoder_test.go @@ -88,6 +88,10 @@ func TestEncodeDecode(t *testing.T) { }, }, }, + // Test empty array. + &tdapi.ReplyMarkupInlineKeyboard{ + Rows: [][]tdapi.InlineKeyboardButton{}, + }, } for _, typ := range types { diff --git a/telegram/auth/auth_test.go b/telegram/auth/auth_test.go index e3b1c54ff9..8fde0bce0b 100644 --- a/telegram/auth/auth_test.go +++ b/telegram/auth/auth_test.go @@ -2,8 +2,13 @@ package auth import ( "crypto/rand" + "testing" + "github.com/stretchr/testify/require" + + "github.com/gotd/td/internal/testutil" "github.com/gotd/td/tg" + "github.com/gotd/td/tgmock" ) const ( @@ -19,3 +24,21 @@ func testClient(invoker tg.Invoker) *Client { appHash: testAppHash, } } + +func mockClient(t *testing.T) (*tgmock.Mock, *Client) { + mock := tgmock.New(t) + return mock, NewClient(tg.NewClient(mock), testutil.ZeroRand{}, testAppID, testAppHash) +} + +func mockTest(cb func( + a *require.Assertions, + mock *tgmock.Mock, + client *Client, +)) func(t *testing.T) { + return func(t *testing.T) { + a := require.New(t) + m, client := mockClient(t) + + cb(a, m, client) + } +} diff --git a/telegram/auth/flow.go b/telegram/auth/flow.go index 9e90bec63b..3267ccd487 100644 --- a/telegram/auth/flow.go +++ b/telegram/auth/flow.go @@ -233,12 +233,16 @@ func (t testAuth) Code(ctx context.Context, sentCode *tg.AuthSentCode) (string, GetLength() int } - typ, ok := sentCode.Type.(notFlashing) - if !ok { - return "", errors.Errorf("unexpected type: %T", sentCode.Type) + length := 5 + if sentCode != nil { + typ, ok := sentCode.Type.(notFlashing) + if !ok { + return "", errors.Errorf("unexpected type: %T", sentCode.Type) + } + length = typ.GetLength() } - return strings.Repeat(strconv.Itoa(t.dc), typ.GetLength()), nil + return strings.Repeat(strconv.Itoa(t.dc), length), nil } func (t testAuth) AcceptTermsOfService(ctx context.Context, tos tg.HelpTermsOfService) error { diff --git a/telegram/auth/flow_test.go b/telegram/auth/flow_test.go index 457f0f0a0a..155b43fc33 100644 --- a/telegram/auth/flow_test.go +++ b/telegram/auth/flow_test.go @@ -2,6 +2,7 @@ package auth_test import ( "context" + "strings" "testing" "github.com/stretchr/testify/require" @@ -81,3 +82,33 @@ func TestEnvAuth(t *testing.T) { a.NoError(err) a.Equal("password", result) } + +func TestTestAuth(t *testing.T) { + a := require.New(t) + ctx := context.Background() + testAuth := auth.Test(testutil.ZeroRand{}, 2) + + _, err := testAuth.Code(ctx, &tg.AuthSentCode{ + Type: &tg.AuthSentCodeTypeFlashCall{}, + }) + a.Error(err) + + result, err := testAuth.Code(ctx, nil) + a.NoError(err) + a.Equal("22222", result) + + result, err = testAuth.Code(ctx, &tg.AuthSentCode{ + Type: &tg.AuthSentCodeTypeApp{ + Length: 1, + }, + }) + a.NoError(err) + a.Equal("2", result) + + result, err = testAuth.Phone(ctx) + a.NoError(err) + a.True(strings.HasPrefix(result, "999662")) + + _, err = testAuth.Password(ctx) + a.ErrorIs(err, auth.ErrPasswordNotProvided) +} diff --git a/telegram/auth/password.go b/telegram/auth/password.go new file mode 100644 index 0000000000..91742ab7fc --- /dev/null +++ b/telegram/auth/password.go @@ -0,0 +1,179 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "github.com/go-faster/errors" + + "github.com/gotd/td/internal/crypto" + "github.com/gotd/td/internal/crypto/srp" + "github.com/gotd/td/tg" +) + +// PasswordHash computes password hash to log in. +// +// See https://core.telegram.org/api/srp#checking-the-password-with-srp. +func PasswordHash( + password []byte, + srpID int64, + srpB, secureRandom []byte, + alg tg.PasswordKdfAlgoClass, +) (*tg.InputCheckPasswordSRP, error) { + s := srp.NewSRP(crypto.DefaultRand()) + + algo, ok := alg.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow) + if !ok { + return nil, errors.Errorf("unsupported algo: %T", alg) + } + + a, err := s.Hash(password, srpB, secureRandom, srp.Input(*algo)) + if err != nil { + return nil, errors.Wrap(err, "create SRP answer") + } + + return &tg.InputCheckPasswordSRP{ + SRPID: srpID, + A: a.A, + M1: a.M1, + }, nil +} + +// NewPasswordHash computes new password hash to update password. +// +// Notice that NewPasswordHash mutates given alg. +// +// See https://core.telegram.org/api/srp#setting-a-new-2fa-password. +func NewPasswordHash( + password []byte, + algo *tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow, +) (hash []byte, _ error) { + s := srp.NewSRP(crypto.DefaultRand()) + + hash, newSalt, err := s.NewHash(password, srp.Input(*algo)) + if err != nil { + return nil, errors.Wrap(err, "create SRP answer") + } + algo.Salt1 = newSalt + + return hash, nil +} + +var ( + emptyPassword tg.InputCheckPasswordSRPClass = &tg.InputCheckPasswordEmpty{} +) + +// UpdatePasswordOptions is options structure for UpdatePassword. +type UpdatePasswordOptions struct { + // Hint is new password hint. + Hint string + // Password is password callback. + // + // If password was requested and Password is nil, ErrPasswordNotProvided error will be returned. + Password func(ctx context.Context) (string, error) +} + +// UpdatePassword sets new cloud password for this account. +// +// See https://core.telegram.org/api/srp#setting-a-new-2fa-password. +func (c *Client) UpdatePassword( + ctx context.Context, + newPassword string, + opts UpdatePasswordOptions, +) error { + p, err := c.api.AccountGetPassword(ctx) + if err != nil { + return errors.Wrap(err, "get SRP parameters") + } + + algo, ok := p.NewAlgo.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow) + if !ok { + return errors.Errorf("unsupported algo: %T", p.NewAlgo) + } + + newHash, err := NewPasswordHash([]byte(newPassword), algo) + if err != nil { + return errors.Wrap(err, "compute new password hash") + } + + var old = emptyPassword + if p.HasPassword { + if opts.Password == nil { + return ErrPasswordNotProvided + } + + oldPassword, err := opts.Password(ctx) + if err != nil { + return errors.Wrap(err, "get password") + } + + hash, err := PasswordHash([]byte(oldPassword), p.SRPID, p.SRPB, p.SecureRandom, p.CurrentAlgo) + if err != nil { + return errors.Wrap(err, "compute old password hash") + } + old = hash + } + + if _, err := c.api.AccountUpdatePasswordSettings(ctx, &tg.AccountUpdatePasswordSettingsRequest{ + Password: old, + NewSettings: tg.AccountPasswordInputSettings{ + NewAlgo: algo, + NewPasswordHash: newHash, + Hint: opts.Hint, + }, + }); err != nil { + return errors.Wrap(err, "update password") + } + return nil +} + +// ResetFailedWaitError reports that you recently requested a password reset that was cancel and need to wait until the +// specified date before requesting another reset. +type ResetFailedWaitError struct { + Result tg.AccountResetPasswordFailedWait +} + +// Until returns time required to wait. +func (r ResetFailedWaitError) Until() time.Duration { + retryDate := time.Unix(int64(r.Result.RetryDate), 0) + return time.Until(retryDate) +} + +// Error implements error. +func (r *ResetFailedWaitError) Error() string { + return fmt.Sprintf("wait to reset password (%s)", r.Until()) +} + +// ResetPassword resets cloud password and returns time to wait until reset be performed. +// If time is zero, password was successfully reset. +// +// May return ResetFailedWaitError. +// +// See https://core.telegram.org/api/srp#password-reset. +func (c *Client) ResetPassword(ctx context.Context) (time.Time, error) { + r, err := c.api.AccountResetPassword(ctx) + if err != nil { + return time.Time{}, errors.Wrap(err, "reset password") + } + switch v := r.(type) { + case *tg.AccountResetPasswordFailedWait: + return time.Time{}, &ResetFailedWaitError{Result: *v} + case *tg.AccountResetPasswordRequestedWait: + return time.Unix(int64(v.UntilDate), 0), nil + case *tg.AccountResetPasswordOk: + return time.Time{}, nil + default: + return time.Time{}, errors.Errorf("unexpected type %T", v) + } +} + +// CancelPasswordReset cancels password reset. +// +// See https://core.telegram.org/api/srp#password-reset. +func (c *Client) CancelPasswordReset(ctx context.Context) error { + if _, err := c.api.AccountDeclinePasswordReset(ctx); err != nil { + return errors.Wrap(err, "cancel password reset") + } + return nil +} diff --git a/telegram/auth/password_example_test.go b/telegram/auth/password_example_test.go new file mode 100644 index 0000000000..a7b4fd70c8 --- /dev/null +++ b/telegram/auth/password_example_test.go @@ -0,0 +1,62 @@ +package auth_test + +import ( + "context" + "fmt" + + "github.com/go-faster/errors" + + "github.com/gotd/td/telegram" + "github.com/gotd/td/telegram/auth" +) + +func ExampleClient_UpdatePassword() { + ctx := context.Background() + client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{}) + if err := client.Run(ctx, func(ctx context.Context) error { + // Updating password. + if err := client.Auth().UpdatePassword(ctx, "new_password", auth.UpdatePasswordOptions{ + // Hint sets new password hint. + Hint: "new password hint", + // Password will be called if old password is requested by Telegram. + // + // If password was requested and Password is nil, auth.ErrPasswordNotProvided error will be returned. + Password: func(ctx context.Context) (string, error) { + return "old_password", nil + }, + }); err != nil { + return err + } + + return nil + }); err != nil { + panic(err) + } +} + +func ExampleClient_ResetPassword() { + ctx := context.Background() + client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{}) + if err := client.Run(ctx, func(ctx context.Context) error { + wait, err := client.Auth().ResetPassword(ctx) + var waitErr *auth.ResetFailedWaitError + switch { + case errors.As(err, &waitErr): + // Telegram requested wait until making new reset request. + fmt.Printf("Wait until %s to reset password.\n", wait.String()) + case err != nil: + return err + } + + // If returned time is zero, password was successfully reset. + if wait.IsZero() { + fmt.Println("Password was reset.") + return nil + } + + fmt.Printf("Password will be reset on %s.\n", wait.String()) + return nil + }); err != nil { + panic(err) + } +} diff --git a/telegram/auth/password_test.go b/telegram/auth/password_test.go new file mode 100644 index 0000000000..dca4ea0ac9 --- /dev/null +++ b/telegram/auth/password_test.go @@ -0,0 +1,169 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gotd/td/bin" + "github.com/gotd/td/internal/testutil" + "github.com/gotd/td/tg" + "github.com/gotd/td/tgmock" +) + +func TestPasswordHash(t *testing.T) { + a := require.New(t) + _, err := PasswordHash(nil, 0, nil, nil, nil) + a.Error(err, "unsupported algo") +} + +var testAlgo = &tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow{ + Salt1: []uint8{ + 230, 200, 149, 125, 223, 152, 141, 72, + }, + Salt2: []uint8{ + 159, 99, 68, 130, 43, 9, 108, 255, 135, 239, 164, 38, 245, 120, 87, 182, + }, + G: 3, + P: []uint8{ + 199, 28, 174, 185, 198, 177, 201, 4, 142, 108, 82, 47, 112, 241, 63, 115, + 152, 13, 64, 35, 142, 62, 33, 193, 73, 52, 208, 55, 86, 61, 147, 15, + 72, 25, 138, 10, 167, 193, 64, 88, 34, 148, 147, 210, 37, 48, 244, 219, + 250, 51, 111, 110, 10, 201, 37, 19, 149, 67, 174, 212, 76, 206, 124, 55, + 32, 253, 81, 246, 148, 88, 112, 90, 198, 140, 212, 254, 107, 107, 19, 171, + 220, 151, 70, 81, 41, 105, 50, 132, 84, 241, 143, 175, 140, 89, 95, 100, + 36, 119, 254, 150, 187, 42, 148, 29, 91, 205, 29, 74, 200, 204, 73, 136, + 7, 8, 250, 155, 55, 142, 60, 79, 58, 144, 96, 190, 230, 124, 249, 164, + 164, 166, 149, 129, 16, 81, 144, 126, 22, 39, 83, 181, 107, 15, 107, 65, + 13, 186, 116, 216, 168, 75, 42, 20, 179, 20, 78, 14, 241, 40, 71, 84, + 253, 23, 237, 149, 13, 89, 101, 180, 185, 221, 70, 88, 45, 177, 23, 141, + 22, 156, 107, 196, 101, 176, 214, 255, 156, 163, 146, 143, 239, 91, 154, 228, + 228, 24, 252, 21, 232, 62, 190, 160, 248, 127, 169, 255, 94, 237, 112, 5, + 13, 237, 40, 73, 244, 123, 249, 89, 217, 86, 133, 12, 233, 41, 133, 31, + 13, 129, 21, 246, 53, 177, 5, 238, 46, 78, 21, 208, 75, 36, 84, 191, + 111, 79, 173, 240, 52, 177, 4, 3, 17, 156, 216, 227, 185, 47, 204, 91, + }, +} + +func TestClient_UpdatePassword(t *testing.T) { + ctx := context.Background() + expectCall := func(a *require.Assertions, m *tgmock.Mock, hasPassword bool) *tgmock.RequestBuilder { + p := &tg.AccountPassword{ + HasPassword: hasPassword, + NewAlgo: testAlgo, + NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{}, + } + if hasPassword { + p.CurrentAlgo = testAlgo + } + p.SetFlags() + return m.ExpectCall(&tg.AccountGetPasswordRequest{}). + ThenResult(p).ExpectFunc(func(b bin.Encoder) { + a.IsType(&tg.AccountUpdatePasswordSettingsRequest{}, b) + r := b.(*tg.AccountUpdatePasswordSettingsRequest) + + if !hasPassword { + a.Equal(emptyPassword, r.Password) + } else { + a.NotEqual(emptyPassword, r.Password) + } + a.NotEmpty(r.NewSettings.NewPasswordHash) + a.Equal("hint", r.NewSettings.Hint) + }) + } + + t.Run("PasswordNotRequired", mockTest(func( + a *require.Assertions, + m *tgmock.Mock, + client *Client, + ) { + m.ExpectCall(&tg.AccountGetPasswordRequest{}).ThenErr(testutil.TestError()) + a.Error(client.UpdatePassword(ctx, "", UpdatePasswordOptions{})) + + expectCall(a, m, false).ThenTrue() + a.NoError(client.UpdatePassword(ctx, "", UpdatePasswordOptions{ + Hint: "hint", + })) + })) + + t.Run("PasswordRequired", mockTest(func( + a *require.Assertions, + m *tgmock.Mock, + client *Client, + ) { + m.ExpectCall(&tg.AccountGetPasswordRequest{}). + ThenResult(&tg.AccountPassword{ + HasPassword: true, + NewAlgo: testAlgo, + CurrentAlgo: testAlgo, + NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{}, + }) + a.ErrorIs(client.UpdatePassword(ctx, "", UpdatePasswordOptions{}), ErrPasswordNotProvided) + + m.ExpectCall(&tg.AccountGetPasswordRequest{}). + ThenResult(&tg.AccountPassword{ + HasPassword: true, + NewAlgo: testAlgo, + CurrentAlgo: testAlgo, + NewSecureAlgo: &tg.SecurePasswordKdfAlgoUnknown{}, + }) + a.ErrorIs(client.UpdatePassword(ctx, "", UpdatePasswordOptions{ + Hint: "hint", + Password: func(ctx context.Context) (string, error) { + return "", testutil.TestError() + }, + }), testutil.TestError()) + + expectCall(a, m, true).ThenTrue() + a.NoError(client.UpdatePassword(ctx, "", UpdatePasswordOptions{ + Hint: "hint", + Password: func(ctx context.Context) (string, error) { + return "password", nil + }, + })) + })) +} + +func TestClient_ResetPassword(t *testing.T) { + ctx := context.Background() + wait := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) { + mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenErr(testutil.TestError()) + _, err := client.ResetPassword(ctx) + a.Error(err) + + mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordFailedWait{ + RetryDate: int(wait), + }) + var waitErr *ResetFailedWaitError + _, err = client.ResetPassword(ctx) + a.ErrorAs(err, &waitErr) + a.Equal(int(wait), waitErr.Result.RetryDate) + a.NotEmpty(waitErr.Error()) + + mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordOk{}) + r, err := client.ResetPassword(ctx) + a.NoError(err) + a.True(r.IsZero()) + + mock.ExpectCall(&tg.AccountResetPasswordRequest{}).ThenResult(&tg.AccountResetPasswordRequestedWait{ + UntilDate: int(wait), + }) + r, err = client.ResetPassword(ctx) + a.NoError(err) + a.False(r.IsZero()) + })(t) +} + +func TestClient_CancelPasswordReset(t *testing.T) { + ctx := context.Background() + mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) { + mock.ExpectCall(&tg.AccountDeclinePasswordResetRequest{}).ThenErr(testutil.TestError()) + a.Error(client.CancelPasswordReset(ctx)) + + mock.ExpectCall(&tg.AccountDeclinePasswordResetRequest{}).ThenTrue() + a.NoError(client.CancelPasswordReset(ctx)) + })(t) +} diff --git a/telegram/auth/self_test.go b/telegram/auth/self_test.go new file mode 100644 index 0000000000..0b655ed1a8 --- /dev/null +++ b/telegram/auth/self_test.go @@ -0,0 +1,42 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gotd/td/internal/testutil" + "github.com/gotd/td/tg" + "github.com/gotd/td/tgmock" +) + +func TestClient_self(t *testing.T) { + ctx := context.Background() + mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) { + mock.ExpectCall(&tg.UsersGetUsersRequest{ + ID: []tg.InputUserClass{&tg.InputUserSelf{}}, + }).ThenErr(testutil.TestError()) + _, err := client.self(ctx) + a.Error(err) + + mock.ExpectCall(&tg.UsersGetUsersRequest{ + ID: []tg.InputUserClass{&tg.InputUserSelf{}}, + }).ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{&tg.UserEmpty{ + ID: 10, + }}}) + _, err = client.self(ctx) + a.Error(err) + + mock.ExpectCall(&tg.UsersGetUsersRequest{ + ID: []tg.InputUserClass{&tg.InputUserSelf{}}, + }).ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{&tg.User{ + Self: true, + ID: 10, + AccessHash: 10, + }}}) + r, err := client.self(ctx) + a.NoError(err) + a.Equal(int64(10), r.ID) + })(t) +} diff --git a/telegram/auth/signup.go b/telegram/auth/signup.go index a681e49f6c..f1b30a1ff3 100644 --- a/telegram/auth/signup.go +++ b/telegram/auth/signup.go @@ -22,7 +22,7 @@ func (s *SignUpRequired) Error() string { return "account with provided number does not exist (sign up required)" } -// checkResult checks that a is *tg.AuthAuthorization and returns authorization result or error. +// checkResult checks that `a` is *tg.AuthAuthorization and returns authorization result or error. func checkResult(a tg.AuthAuthorizationClass) (*tg.AuthAuthorization, error) { switch a := a.(type) { case *tg.AuthAuthorization: diff --git a/telegram/auth/status_test.go b/telegram/auth/status_test.go index d803fd2f79..68368a611c 100644 --- a/telegram/auth/status_test.go +++ b/telegram/auth/status_test.go @@ -11,7 +11,7 @@ import ( "github.com/gotd/td/tgmock" ) -func TestClient_AuthStatus(t *testing.T) { +func TestClient_Status(t *testing.T) { ctx := context.Background() t.Run("Authorized", func(t *testing.T) { @@ -49,7 +49,7 @@ func TestClient_AuthStatus(t *testing.T) { }) } -func TestClient_AuthIfNecessary(t *testing.T) { +func TestClient_IfNecessary(t *testing.T) { ctx := context.Background() t.Run("Authorized", func(t *testing.T) { @@ -62,4 +62,46 @@ func TestClient_AuthIfNecessary(t *testing.T) { // Pass empty AuthFlow because it should not be called anyway. require.NoError(t, testClient(mock).IfNecessary(ctx, Flow{})) }) + + t.Run("Error", func(t *testing.T) { + mock := tgmock.NewRequire(t) + mock.Expect().ThenRPCErr(&tgerr.Error{ + Code: 500, + Message: "BRUH", + Type: "BRUH", + }) + + // Pass empty AuthFlow because it should not be called anyway. + require.Error(t, testClient(mock).IfNecessary(ctx, Flow{})) + }) +} + +func TestClient_Test(t *testing.T) { + ctx := context.Background() + + t.Run("Authorized", func(t *testing.T) { + mock := tgmock.NewRequire(t) + testUser := &tg.User{ + Username: "user", + } + mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}}) + + // Pass empty AuthFlow because it should not be called anyway. + require.NoError(t, testClient(mock).Test(ctx, 2)) + }) +} + +func TestClient_TestUser(t *testing.T) { + ctx := context.Background() + + t.Run("Authorized", func(t *testing.T) { + mock := tgmock.NewRequire(t) + testUser := &tg.User{ + Username: "user", + } + mock.Expect().ThenResult(&tg.UserClassVector{Elems: []tg.UserClass{testUser}}) + + // Pass empty AuthFlow because it should not be called anyway. + require.NoError(t, testClient(mock).TestUser(ctx, "phone", 2)) + }) } diff --git a/telegram/auth/user.go b/telegram/auth/user.go index 2f2c7fc09f..26c4e66a6d 100644 --- a/telegram/auth/user.go +++ b/telegram/auth/user.go @@ -5,7 +5,6 @@ import ( "github.com/go-faster/errors" - "github.com/gotd/td/internal/crypto/srp" "github.com/gotd/td/tg" "github.com/gotd/td/tgerr" ) @@ -26,15 +25,9 @@ func (c *Client) Password(ctx context.Context, password string) (*tg.AuthAuthori return nil, errors.Wrap(err, "get SRP parameters") } - algo, ok := p.CurrentAlgo.(*tg.PasswordKdfAlgoSHA256SHA256PBKDF2HMACSHA512iter100000SHA256ModPow) - if !ok { - return nil, errors.Errorf("unsupported algo: %T", p.CurrentAlgo) - } - - s := srp.NewSRP(c.rand) - a, err := s.Hash([]byte(password), p.SRPB, p.SecureRandom, srp.Input(*algo)) + a, err := PasswordHash([]byte(password), p.SRPID, p.SRPB, p.SecureRandom, p.CurrentAlgo) if err != nil { - return nil, errors.Wrap(err, "create SRP answer") + return nil, errors.Wrap(err, "compute password hash") } auth, err := c.api.AuthCheckPassword(ctx, &tg.InputCheckPasswordSRP{ diff --git a/telegram/auth/user_test.go b/telegram/auth/user_test.go index 2d48ebe1b3..ccbc2c7c0a 100644 --- a/telegram/auth/user_test.go +++ b/telegram/auth/user_test.go @@ -37,6 +37,12 @@ func TestClient_AuthSignIn(t *testing.T) { testUser := &tg.User{ID: 1} invoker := tgmock.Invoker(func(body bin.Encoder) (bin.Encoder, error) { switch req := body.(type) { + case *tg.UsersGetUsersRequest: + return nil, &tgerr.Error{ + Code: 401, + Message: "AUTH_KEY_UNREGISTERED", + Type: "AUTH_KEY_UNREGISTERED", + } case *tg.AuthSendCodeRequest: settings := tg.CodeSettings{} settings.SetCurrentNumber(true) @@ -112,14 +118,20 @@ func TestClient_AuthSignIn(t *testing.T) { require.NoError(t, err) require.Equal(t, testUser, result.User) }) - t.Run("AuthFlow", func(t *testing.T) { - // Using flow helper. - u := Constant(phone, password, CodeAuthenticatorFunc( + + flow := NewFlow( + Constant(phone, password, CodeAuthenticatorFunc( func(ctx context.Context, _ *tg.AuthSentCode) (string, error) { return code, nil }, - )) - require.NoError(t, NewFlow(u, SendCodeOptions{CurrentNumber: true}).Run(ctx, testClient(invoker))) + )), + SendCodeOptions{CurrentNumber: true}, + ) + t.Run("AuthFlow", func(t *testing.T) { + require.NoError(t, flow.Run(ctx, testClient(invoker))) + }) + t.Run("IfNecessary", func(t *testing.T) { + require.NoError(t, testClient(invoker).IfNecessary(ctx, flow)) }) } @@ -225,3 +237,18 @@ func TestClientTestSignUp(t *testing.T) { SendCodeOptions{}, ).Run(ctx, testClient(invoker))) } + +func TestClient_AcceptTOS(t *testing.T) { + ctx := context.Background() + mockTest(func(a *require.Assertions, mock *tgmock.Mock, client *Client) { + mock.Expect().ThenUnregistered() + a.Error(client.AcceptTOS(ctx, tg.DataJSON{ + Data: `{"data":"data"}`, + })) + + mock.Expect().ThenTrue() + a.NoError(client.AcceptTOS(ctx, tg.DataJSON{ + Data: `{"data":"data"}`, + })) + })(t) +} diff --git a/telegram/auth_test.go b/telegram/auth_example_test.go similarity index 60% rename from telegram/auth_test.go rename to telegram/auth_example_test.go index ea19d93044..201cfc52f5 100644 --- a/telegram/auth_test.go +++ b/telegram/auth_example_test.go @@ -3,7 +3,6 @@ package telegram_test import ( "bufio" "context" - "crypto/rand" "fmt" "log" "os" @@ -17,7 +16,46 @@ import ( "github.com/gotd/td/tg" ) -func ExampleClient_Auth() { +func ExampleClient_Auth_codeOnly() { + check := func(err error) { + if err != nil { + panic(err) + } + } + + var ( + appIDString = os.Getenv("APP_ID") + appHash = os.Getenv("APP_HASH") + phone = os.Getenv("PHONE") + ) + if appIDString == "" || appHash == "" || phone == "" { + log.Fatal("PHONE, APP_ID or APP_HASH is not set") + } + + appID, err := strconv.Atoi(appIDString) + check(err) + + ctx := context.Background() + client := telegram.NewClient(appID, appHash, telegram.Options{}) + codeAsk := func(ctx context.Context, sentCode *tg.AuthSentCode) (string, error) { + fmt.Print("code:") + code, err := bufio.NewReader(os.Stdin).ReadString('\n') + if err != nil { + return "", err + } + code = strings.ReplaceAll(code, "\n", "") + return code, nil + } + + check(client.Run(ctx, func(ctx context.Context) error { + return auth.NewFlow( + auth.CodeOnly(phone, auth.CodeAuthenticatorFunc(codeAsk)), + auth.SendCodeOptions{}, + ).Run(ctx, client.Auth()) + })) +} + +func ExampleClient_Auth_password() { check := func(err error) { if err != nil { panic(err) @@ -67,10 +105,32 @@ func ExampleClient_Auth_test() { DCList: dcs.Test(), }) if err := client.Run(ctx, func(ctx context.Context) error { - return auth.NewFlow( - auth.Test(rand.Reader, dcID), - auth.SendCodeOptions{}, - ).Run(ctx, client.Auth()) + return client.Auth().Test(ctx, dcID) + }); err != nil { + panic(err) + } +} + +func ExampleClient_Auth_bot() { + ctx := context.Background() + client := telegram.NewClient(telegram.TestAppID, telegram.TestAppHash, telegram.Options{}) + if err := client.Run(ctx, func(ctx context.Context) error { + // Checking auth status. + status, err := client.Auth().Status(ctx) + if err != nil { + return err + } + // Can be already authenticated if we have valid session in + // session storage. + if !status.Authorized { + // Otherwise, perform bot authentication. + if _, err := client.Auth().Bot(ctx, os.Getenv("BOT_TOKEN")); err != nil { + return err + } + } + + // All good, manually authenticated. + return nil }); err != nil { panic(err) }