Skip to content

Commit

Permalink
Add session context outbounds as slice (#3356)
Browse files Browse the repository at this point in the history
* Add session context outbounds as slice

slice is needed for dialer proxy where two outbounds work on top of each other
There are two sets of target addr for example
It also enable Xtls to correctly do splice copy by checking both outbounds are ready to do direct copy

* Fill outbound tag info

* Splice now checks capalibility from all outbounds

* Fix unit tests
  • Loading branch information
yuhan6665 committed May 14, 2024
1 parent 0735053 commit 017f53b
Show file tree
Hide file tree
Showing 42 changed files with 302 additions and 235 deletions.
22 changes: 13 additions & 9 deletions app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
if !destination.IsValid() {
panic("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
Expand Down Expand Up @@ -274,11 +275,12 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
if !destination.IsValid() {
return newError("Dispatcher: Invalid destination.")
}
ob := session.OutboundFromContext(ctx)
if ob == nil {
ob = &session.Outbound{}
ctx = session.ContextWithOutbound(ctx, ob)
outbounds := session.OutboundsFromContext(ctx)
if len(outbounds) == 0 {
outbounds = []*session.Outbound{{}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ob := outbounds[len(outbounds) - 1]
ob.OriginalTarget = destination
ob.Target = destination
content := session.ContentFromContext(ctx)
Expand Down Expand Up @@ -368,7 +370,8 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw
return contentResult, contentErr
}
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
ob := session.OutboundFromContext(ctx)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String())
if proxied != nil {
Expand Down Expand Up @@ -425,6 +428,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return
}

ob.Tag = handler.Tag()
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
if tag := handler.Tag(); tag != "" {
if inTag == "" {
Expand Down
11 changes: 6 additions & 5 deletions app/dispatcher/fakednssniffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ func newFakeDNSSniffer(ctx context.Context) (protocolSnifferWithMetadata, error)
return protocolSnifferWithMetadata{}, errNotInit
}
return protocolSnifferWithMetadata{protocolSniffer: func(ctx context.Context, bytes []byte) (SniffResult, error) {
Target := session.OutboundFromContext(ctx).Target
if Target.Network == net.Network_TCP || Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(Target.Address)
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_TCP || ob.Target.Network == net.Network_UDP {
domainFromFakeDNS := fakeDNSEngine.GetDomainFromFakeDNS(ob.Target.Address)
if domainFromFakeDNS != "" {
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
newError("fake dns got domain: ", domainFromFakeDNS, " for ip: ", ob.Target.Address.String()).WriteToLog(session.ExportIDToError(ctx))
return &fakeDNSSniffResult{domainName: domainFromFakeDNS}, nil
}
}

if ipAddressInRangeValueI := ctx.Value(ipAddressInRange); ipAddressInRangeValueI != nil {
ipAddressInRangeValue := ipAddressInRangeValueI.(*ipAddressInRangeOpt)
if fkr0, ok := fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
inPool := fkr0.IsIPInIPPool(Target.Address)
inPool := fkr0.IsIPInIPPool(ob.Target.Address)
ipAddressInRangeValue.addressInRange = &inPool
}
}
Expand Down
11 changes: 6 additions & 5 deletions app/proxyman/inbound/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (w *tcpWorker) callback(conn stat.Connection) {
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)

var outbound = &session.Outbound{}
outbounds := []*session.Outbound{{}}
if w.recvOrigDest {
var dest net.Destination
switch getTProxyType(w.stream) {
Expand All @@ -75,10 +75,10 @@ func (w *tcpWorker) callback(conn stat.Connection) {
dest = net.DestinationFromAddr(conn.LocalAddr())
}
if dest.IsValid() {
outbound.Target = dest
outbounds[0].Target = dest
}
}
ctx = session.ContextWithOutbound(ctx, outbound)
ctx = session.ContextWithOutbounds(ctx, outbounds)

if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &stat.CounterConnection{
Expand Down Expand Up @@ -309,9 +309,10 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
ctx = session.ContextWithID(ctx, sid)

if originalDest.IsValid() {
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: originalDest,
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
}
ctx = session.ContextWithInbound(ctx, &session.Inbound{
Source: source,
Expand Down
38 changes: 18 additions & 20 deletions app/proxyman/outbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ func (h *Handler) Tag() string {

// Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
outbound := session.OutboundFromContext(ctx)
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
}
if h.mux != nil {
test := func(err error) {
Expand All @@ -183,7 +184,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
common.Interrupt(link.Writer)
}
}
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 {
switch h.udp443 {
case "reject":
test(newError("XUDP rejected UDP/443 traffic").AtInfo())
Expand All @@ -192,7 +193,7 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
goto out
}
}
if h.xudp != nil && outbound.Target.Network == net.Network_UDP {
if h.xudp != nil && ob.Target.Network == net.Network_UDP {
if !h.xudp.Enabled {
goto out
}
Expand Down Expand Up @@ -243,10 +244,11 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
handler := h.outboundManager.GetHandler(tag)
if handler != nil {
newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx))
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := session.OutboundsFromContext(ctx)
ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
Target: dest,
})

Tag: tag,
})) // add another outbound in session ctx
opts := pipe.OptionsFromContext(ctx)
uplinkReader, uplinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
Expand All @@ -266,15 +268,12 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti
}

if h.senderSettings.Via != nil {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
outbound = new(session.Outbound)
ctx = session.ContextWithOutbound(ctx, outbound)
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if h.senderSettings.ViaCidr == "" {
outbound.Gateway = h.senderSettings.Via.AsAddress()
ob.Gateway = h.senderSettings.Via.AsAddress()
} else { //Get a random address.
outbound.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
}
}
}
Expand All @@ -285,10 +284,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connecti

conn, err := internet.Dial(ctx, dest, h.streamSettings)
conn = h.getStatCouterConnection(conn)
outbound := session.OutboundFromContext(ctx)
if outbound != nil {
outbound.Conn = conn
}
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
ob.Conn = conn
return conn, err
}

Expand Down
3 changes: 3 additions & 0 deletions app/proxyman/outbound/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/xtls/xray-core/app/stats"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
Expand Down Expand Up @@ -44,6 +45,7 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
Expand Down Expand Up @@ -73,6 +75,7 @@ func TestOutboundWithStatCounter(t *testing.T) {
v, _ := core.New(config)
v.AddFeature((outbound.Manager)(new(Manager)))
ctx := context.WithValue(context.Background(), xrayKey, v)
ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{{}})
h, _ := NewHandler(ctx, &core.OutboundHandlerConfig{
Tag: "tag",
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
Expand Down
12 changes: 7 additions & 5 deletions app/reverse/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ func (p *Portal) Close() error {
}

func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
outboundMeta := session.OutboundFromContext(ctx)
if outboundMeta == nil {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
if ob == nil {
return newError("outbound metadata not found").AtError()
}

if isDomain(outboundMeta.Target, p.domain) {
if isDomain(ob.Target, p.domain) {
muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
if err != nil {
return newError("failed to create mux client worker").Base(err).AtWarning()
Expand Down Expand Up @@ -206,9 +207,10 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
downlinkReader, downlinkWriter := pipe.New(opt...)

ctx := context.Background()
ctx = session.ContextWithOutbound(ctx, &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
})
}}
ctx = session.ContextWithOutbounds(ctx, outbounds)
f := client.Dispatch(ctx, &transport.Link{
Reader: uplinkReader,
Writer: downlinkWriter,
Expand Down
20 changes: 15 additions & 5 deletions app/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ func TestSimpleRouter(t *testing.T) {
HandlerSelector: mockHs,
}, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -86,7 +88,9 @@ func TestSimpleBalancer(t *testing.T) {
HandlerSelector: mockHs,
}, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -174,7 +178,9 @@ func TestIPOnDemand(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -213,7 +219,9 @@ func TestIPIfNonMatchDomain(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("example.com"), 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("example.com"), 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down Expand Up @@ -247,7 +255,9 @@ func TestIPIfNonMatchIP(t *testing.T) {
r := new(Router)
common.Must(r.Init(context.TODO(), config, mockDNS, nil, nil))

ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.LocalHostIP, 80),
}})
route, err := r.PickRoute(routing_session.AsRoutingContext(ctx))
common.Must(err)
if tag := route.GetOutboundTag(); tag != "test" {
Expand Down
14 changes: 8 additions & 6 deletions common/mux/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
}

go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
outbounds := []*session.Outbound{{
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
})
}}
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
ctx, cancel := context.WithCancel(ctx)

if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
Expand Down Expand Up @@ -242,17 +243,18 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
}

func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
dest := session.OutboundFromContext(ctx).Target
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds) - 1]
transferType := protocol.TransferTypeStream
if dest.Network == net.Network_UDP {
if ob.Target.Network == net.Network_UDP {
transferType = protocol.TransferTypePacket
}
s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx))
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()

newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
Expand Down
8 changes: 4 additions & 4 deletions common/mux/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ func TestClientWorkerClose(t *testing.T) {
}

tr1, tw1 := pipe.New(pipe.WithoutSizeLimit())
ctx1 := session.ContextWithOutbound(context.Background(), &session.Outbound{
ctx1 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
}})
common.Must(manager.Dispatch(ctx1, &transport.Link{
Reader: tr1,
Writer: tw1,
Expand All @@ -103,9 +103,9 @@ func TestClientWorkerClose(t *testing.T) {
}

tr2, tw2 := pipe.New(pipe.WithoutSizeLimit())
ctx2 := session.ContextWithOutbound(context.Background(), &session.Outbound{
ctx2 := session.ContextWithOutbounds(context.Background(), []*session.Outbound{{
Target: net.TCPDestination(net.DomainAddress("www.example.com"), 80),
})
}})
common.Must(manager.Dispatch(ctx2, &transport.Link{
Reader: tr2,
Writer: tw2,
Expand Down
10 changes: 5 additions & 5 deletions common/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func InboundFromContext(ctx context.Context) *Inbound {
return nil
}

func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context {
return context.WithValue(ctx, outboundSessionKey, outbound)
func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context {
return context.WithValue(ctx, outboundSessionKey, outbounds)
}

func OutboundFromContext(ctx context.Context) *Outbound {
if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok {
return outbound
func OutboundsFromContext(ctx context.Context) []*Outbound {
if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok {
return outbounds
}
return nil
}
Expand Down

0 comments on commit 017f53b

Please sign in to comment.