Skip to content

Commit 2e94561

Browse files
committed
support mtproto conn type 0xee. fixes #1297
1 parent d839595 commit 2e94561

File tree

4 files changed

+62
-11
lines changed

4 files changed

+62
-11
lines changed

proxy/mtproto/auth.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mtproto
22

33
import (
4+
"context"
45
"crypto/rand"
56
"crypto/sha256"
67
"io"
@@ -13,6 +14,35 @@ const (
1314
HeaderSize = 64
1415
)
1516

17+
type SessionContext struct {
18+
ConnectionType [4]byte
19+
DataCenterID uint16
20+
}
21+
22+
func DefaultSessionContext() SessionContext {
23+
return SessionContext{
24+
ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
25+
DataCenterID: 0,
26+
}
27+
}
28+
29+
type contextKey int32
30+
31+
const (
32+
sessionContextKey contextKey = iota
33+
)
34+
35+
func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
36+
return context.WithValue(ctx, sessionContextKey, c)
37+
}
38+
39+
func SessionContextFromContext(ctx context.Context) SessionContext {
40+
if c := ctx.Value(sessionContextKey); c != nil {
41+
return c.(SessionContext)
42+
}
43+
return DefaultSessionContext()
44+
}
45+
1646
type Authentication struct {
1747
Header [HeaderSize]byte
1848
DecodingKey [32]byte
@@ -29,12 +59,18 @@ func (a *Authentication) DataCenterID() uint16 {
2959
return uint16(x) - 1
3060
}
3161

62+
func (a *Authentication) ConnectionType() [4]byte {
63+
var x [4]byte
64+
copy(x[:], a.Header[56:60])
65+
return x
66+
}
67+
3268
func (a *Authentication) ApplySecret(b []byte) {
3369
a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
3470
a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
3571
}
3672

37-
func generateRandomBytes(random []byte) {
73+
func generateRandomBytes(random []byte, connType [4]byte) {
3874
for {
3975
common.Must2(rand.Read(random))
4076

@@ -51,19 +87,16 @@ func generateRandomBytes(random []byte) {
5187
continue
5288
}
5389

54-
random[56] = 0xef
55-
random[57] = 0xef
56-
random[58] = 0xef
57-
random[59] = 0xef
90+
copy(random[56:60], connType[:])
5891

5992
return
6093
}
6194
}
6295

63-
func NewAuthentication() *Authentication {
96+
func NewAuthentication(sc SessionContext) *Authentication {
6497
auth := getAuthenticationObject()
6598
random := auth.Header[:]
66-
generateRandomBytes(random)
99+
generateRandomBytes(random, sc.ConnectionType)
67100
copy(auth.EncodingKey[:], random[8:])
68101
copy(auth.EncodingNonce[:], random[8+32:])
69102
keyivInverse := Inverse(random[8 : 8+32+16])

proxy/mtproto/auth_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func TestInverse(t *testing.T) {
3232
func TestAuthenticationReadWrite(t *testing.T) {
3333
assert := With(t)
3434

35-
a := NewAuthentication()
35+
a := NewAuthentication(DefaultSessionContext())
3636
b := bytes.NewReader(a.Header[:])
3737
a2, err := ReadAuthentication(b)
3838
assert(err, IsNil)

proxy/mtproto/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
3636
}
3737
defer conn.Close() // nolint: errcheck
3838

39-
auth := NewAuthentication()
39+
sc := SessionContextFromContext(ctx)
40+
auth := NewAuthentication(sc)
4041
defer putAuthenticationObject(auth)
4142

4243
request := func() error {

proxy/mtproto/server.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ func (s *Server) Network() net.NetworkList {
6464
}
6565
}
6666

67+
func isValidConnectionType(c [4]byte) bool {
68+
if compare.BytesAll(c[:], 0xef) {
69+
return true
70+
}
71+
if compare.BytesAll(c[:], 0xee) {
72+
return true
73+
}
74+
return false
75+
}
76+
6777
func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error {
6878
sPolicy := s.policy.ForLevel(s.user.Level)
6979

@@ -85,8 +95,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
8595
decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
8696
decryptor.XORKeyStream(auth.Header[:], auth.Header[:])
8797

88-
if !compare.BytesAll(auth.Header[56:60], 0xef) {
89-
return newError("invalid connection type: ", auth.Header[56:60])
98+
ct := auth.ConnectionType()
99+
if !isValidConnectionType(ct) {
100+
return newError("invalid connection type: ", ct)
90101
}
91102

92103
dcID := auth.DataCenterID()
@@ -104,6 +115,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
104115
timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle)
105116
ctx = core.ContextWithBufferPolicy(ctx, sPolicy.Buffer)
106117

118+
sc := SessionContext{
119+
ConnectionType: ct,
120+
DataCenterID: dcID,
121+
}
122+
ctx = ContextWithSessionContext(ctx, sc)
123+
107124
link, err := dispatcher.Dispatch(ctx, dest)
108125
if err != nil {
109126
return newError("failed to dispatch request to: ", dest).Base(err)

0 commit comments

Comments
 (0)