diff --git a/internal/graph/check.go b/internal/graph/check.go index 560dc97e1e..7d7a620cc5 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -45,7 +45,9 @@ type ValidatedCheckRequest struct { func (cc *ConcurrentChecker) Check(ctx context.Context, req ValidatedCheckRequest, relation *v0.Relation) (*v1.DispatchCheckResponse, error) { var directFunc ReduceableCheckFunc - if onrEqual(req.Subject, req.ObjectAndRelation) { + if req.Subject.ObjectId == tuple.PublicWildcard { + directFunc = checkError(NewErrInvalidArgument(errors.New("cannot perform check on wildcard"))) + } else if onrEqual(req.Subject, req.ObjectAndRelation) { // If we have found the goal's ONR, then we know that the ONR is a member. directFunc = alwaysMember() } else if relation.UsersetRewrite == nil { diff --git a/internal/graph/errors.go b/internal/graph/errors.go index 0bdf208c72..8945af80f4 100644 --- a/internal/graph/errors.go +++ b/internal/graph/errors.go @@ -122,3 +122,15 @@ func NewRelationMissingTypeInfoErr(nsName string, relationName string) error { relationName: relationName, } } + +// ErrInvalidArgument occurs when a request sent has an invalid argument. +type ErrInvalidArgument struct { + error +} + +// NewErrInvalidArgument constructs a request sent has an invalid argument. +func NewErrInvalidArgument(baseErr error) error { + return ErrInvalidArgument{ + error: baseErr, + } +} diff --git a/internal/graph/lookup.go b/internal/graph/lookup.go index 91d08b602f..52bedd10b1 100644 --- a/internal/graph/lookup.go +++ b/internal/graph/lookup.go @@ -2,6 +2,7 @@ package graph import ( "context" + "errors" "fmt" v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" @@ -46,6 +47,10 @@ const ( // Lookup performs a lookup request with the provided request and context. func (cl *ConcurrentLookup) Lookup(ctx context.Context, req ValidatedLookupRequest) (*v1.DispatchLookupResponse, error) { funcToResolve := cl.lookupInternal(ctx, req) + if req.Subject.ObjectId == tuple.PublicWildcard { + funcToResolve = returnResult(lookupResultError(req, NewErrInvalidArgument(errors.New("cannot perform lookup on wildcard")), emptyMetadata)) + } + resolved := lookupOne(ctx, req, funcToResolve) // Remove the resolved relation reference from the excluded direct list to mark that it was completely resolved. @@ -115,9 +120,7 @@ func (cl *ConcurrentLookup) lookupInternal(ctx context.Context, req ValidatedLoo var requests []ReduceableLookupFunc for _, obj := range toCheck.AsSlice() { // If we've already found the target ONR, no further resolution is necessary. - if obj.Namespace == req.Subject.Namespace && - obj.Relation == req.Subject.Relation && - obj.ObjectId == req.Subject.ObjectId { + if onrEqualOrWildcard(obj, req.Subject) { continue } @@ -206,6 +209,46 @@ func (cl *ConcurrentLookup) lookupDirect(ctx context.Context, req ValidatedLooku }) } + // Dispatch a check for the subject wildcard, if allowed. + isWildcardAllowed, err := typeSystem.IsAllowedPublicNamespace(req.ObjectRelation.Relation, req.Subject.Namespace) + if isWildcardAllowed == namespace.PublicSubjectAllowed { + requests = append(requests, func(ctx context.Context, resultChan chan<- LookupResult) { + objects := tuple.NewONRSet() + it, err := cl.ds.ReverseQueryTuples( + ctx, + tuple.UsersetToSubjectFilter(&v0.ObjectAndRelation{ + Namespace: req.Subject.Namespace, + ObjectId: tuple.PublicWildcard, + Relation: req.Subject.Relation, + }), + req.Revision, + options.WithResRelation(&options.ResourceRelation{ + Namespace: req.ObjectRelation.Namespace, + Relation: req.ObjectRelation.Relation, + }), + ) + if err != nil { + resultChan <- lookupResultError(req, err, emptyMetadata) + return + } + defer it.Close() + + for tpl := it.Next(); tpl != nil; tpl = it.Next() { + objects.Add(tpl.ObjectAndRelation) + if objects.Length() >= req.Limit { + break + } + } + + if it.Err() != nil { + resultChan <- lookupResultError(req, it.Err(), emptyMetadata) + return + } + + resultChan <- lookupResult(req, objects.AsSlice(), emptyMetadata) + }) + } + // Dispatch to any allowed subject relation types that don't match the target ONR, collect // the found object IDs, and then search for those. allowedDirect, err := typeSystem.AllowedSubjectRelations(req.ObjectRelation.Relation) diff --git a/internal/membership/foundsubject.go b/internal/membership/foundsubject.go new file mode 100644 index 0000000000..e3753cfa32 --- /dev/null +++ b/internal/membership/foundsubject.go @@ -0,0 +1,142 @@ +package membership + +import ( + "fmt" + "sort" + "strings" + + v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" + + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewFoundSubject creates a new FoundSubject for a subject and a set of its resources. +func NewFoundSubject(subject *v0.ObjectAndRelation, resources ...*v0.ObjectAndRelation) FoundSubject { + return FoundSubject{subject, tuple.NewONRSet(), tuple.NewONRSet(resources...)} +} + +// FoundSubject contains a single found subject and all the relationships in which that subject +// is a member which were found via the ONRs expansion. +type FoundSubject struct { + // subject is the subject found. + subject *v0.ObjectAndRelation + + // excludedSubjects are any subjects excluded. Only should be set if subject is a wildcard. + excludedSubjects *tuple.ONRSet + + // relations are the relations under which the subject lives that informed the locating + // of this subject for the root ONR. + relationships *tuple.ONRSet +} + +// Subject returns the Subject of the FoundSubject. +func (fs FoundSubject) Subject() *v0.ObjectAndRelation { + return fs.subject +} + +// WildcardType returns the object type for the wildcard subject, if this is a wildcard subject. +func (fs FoundSubject) WildcardType() (string, bool) { + if fs.subject.ObjectId == tuple.PublicWildcard { + return fs.subject.Namespace, true + } + + return "", false +} + +// ExcludedSubjectsFromWildcard returns those subjects excluded from the wildcard subject. +// If not a wildcard subject, returns false. +func (fs FoundSubject) ExcludedSubjectsFromWildcard() ([]*v0.ObjectAndRelation, bool) { + if fs.subject.ObjectId == tuple.PublicWildcard { + return fs.excludedSubjects.AsSlice(), true + } + + return []*v0.ObjectAndRelation{}, false +} + +// Relationships returns all the relationships in which the subject was found as per the expand. +func (fs FoundSubject) Relationships() []*v0.ObjectAndRelation { + return fs.relationships.AsSlice() +} + +// ToValidationString returns the FoundSubject in a format that is consumable by the validationfile +// package. +func (fs FoundSubject) ToValidationString() string { + onrString := tuple.StringONR(fs.Subject()) + excluded, isWildcard := fs.ExcludedSubjectsFromWildcard() + if isWildcard && len(excluded) > 0 { + excludedONRStrings := make([]string, 0, len(excluded)) + for _, excludedONR := range excluded { + excludedONRStrings = append(excludedONRStrings, tuple.StringONR(excludedONR)) + } + + sort.Strings(excludedONRStrings) + return fmt.Sprintf("%s - {%s}", onrString, strings.Join(excludedONRStrings, ", ")) + } + + return onrString +} + +// union performs merging of two FoundSubject's with the same subject. +func (fs FoundSubject) union(other FoundSubject) FoundSubject { + if toKey(fs.subject) != toKey(other.subject) { + panic("Got wrong found subject to union") + } + + relationships := fs.relationships.Union(other.relationships) + var excludedSubjects *tuple.ONRSet + + // If a wildcard, then union together excluded subjects. + _, isWildcard := fs.WildcardType() + if isWildcard { + excludedSubjects = fs.excludedSubjects.Union(other.excludedSubjects) + } + + return FoundSubject{ + subject: fs.subject, + excludedSubjects: excludedSubjects, + relationships: relationships, + } +} + +// intersect performs intersection between two FoundSubject's with the same subject. +func (fs FoundSubject) intersect(other FoundSubject) FoundSubject { + if toKey(fs.subject) != toKey(other.subject) { + panic("Got wrong found subject to intersect") + } + + relationships := fs.relationships.Union(other.relationships) + var excludedSubjects *tuple.ONRSet + + // If a wildcard, then union together excluded subjects. + _, isWildcard := fs.WildcardType() + if isWildcard { + excludedSubjects = fs.excludedSubjects.Union(other.excludedSubjects) + } + + return FoundSubject{ + subject: fs.subject, + excludedSubjects: excludedSubjects, + relationships: relationships, + } +} + +// FoundSubjects contains the subjects found for a specific ONR. +type FoundSubjects struct { + // subjects is a map from the Subject ONR (as a string) to the FoundSubject information. + subjects map[string]FoundSubject +} + +// ListFound returns a slice of all the FoundSubject's. +func (fs FoundSubjects) ListFound() []FoundSubject { + found := []FoundSubject{} + for _, sub := range fs.subjects { + found = append(found, sub) + } + return found +} + +// LookupSubject returns the FoundSubject for a matching subject, if any. +func (fs FoundSubjects) LookupSubject(subject *v0.ObjectAndRelation) (FoundSubject, bool) { + found, ok := fs.subjects[toKey(subject)] + return found, ok +} diff --git a/internal/membership/foundsubject_test.go b/internal/membership/foundsubject_test.go new file mode 100644 index 0000000000..7f87a6d882 --- /dev/null +++ b/internal/membership/foundsubject_test.go @@ -0,0 +1,62 @@ +package membership + +import ( + "fmt" + "testing" + + "github.com/authzed/spicedb/pkg/validationfile" + "github.com/stretchr/testify/require" +) + +func TestToValidationString(t *testing.T) { + testCases := []struct { + name string + fs FoundSubject + expected string + }{ + { + "basic", + fs("user", "user1", "..."), + "user:user1", + }, + { + "with exclusion", + fs("user", "*", "...", ONR("user", "user1", "...")), + "user:* - {user:user1}", + }, + { + "with some exclusion", + fs("user", "*", "...", + ONR("user", "user1", "..."), + ONR("user", "user2", "..."), + ONR("user", "user3", "..."), + ONR("user", "user4", "..."), + ONR("user", "user5", "..."), + ), + "user:* - {user:user1, user:user2, user:user3, user:user4, user:user5}", + }, + { + "with many exclusion", + fs("user", "*", "...", + ONR("user", "user1", "..."), + ONR("user", "user2", "..."), + ONR("user", "user3", "..."), + ONR("user", "user4", "..."), + ONR("user", "user5", "..."), + ONR("user", "user6", "..."), + ), + "user:* - {user:user1, user:user2, user:user3, user:user4, user:user5, user:user6}", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + require.Equal(tc.expected, tc.fs.ToValidationString()) + + sub, err := validationfile.ValidationString(fmt.Sprintf("[%s]", tc.expected)).Subject() + require.Nil(err) + require.NotNil(sub) + }) + } +} diff --git a/internal/membership/membership.go b/internal/membership/membership.go new file mode 100644 index 0000000000..123846bea6 --- /dev/null +++ b/internal/membership/membership.go @@ -0,0 +1,140 @@ +package membership + +import ( + "fmt" + + v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" + + "github.com/authzed/spicedb/pkg/tuple" +) + +// Set represents the set of membership for one or more ONRs, based on expansion +// trees. +type Set struct { + // objectsAndRelations is a map from an ONR (as a string) to the subjects found for that ONR. + objectsAndRelations map[string]FoundSubjects +} + +// SubjectsByONR returns a map from ONR (as a string) to the FoundSubjects for that ONR. +func (ms *Set) SubjectsByONR() map[string]FoundSubjects { + return ms.objectsAndRelations +} + +// NewMembershipSet constructs a new membership set. +// +// NOTE: This is designed solely for the developer API and should *not* be used in any performance +// sensitive code. +func NewMembershipSet() *Set { + return &Set{ + objectsAndRelations: map[string]FoundSubjects{}, + } +} + +// AddExpansion adds the expansion of an ONR to the membership set. Returns false if the ONR was already added. +// +// NOTE: The expansion tree *should* be the fully recursive expansion. +func (ms *Set) AddExpansion(onr *v0.ObjectAndRelation, expansion *v0.RelationTupleTreeNode) (FoundSubjects, bool, error) { + onrString := tuple.StringONR(onr) + existing, ok := ms.objectsAndRelations[onrString] + if ok { + return existing, false, nil + } + + tss, err := populateFoundSubjects(onr, expansion) + if err != nil { + return FoundSubjects{}, false, err + } + + fs := tss.ToFoundSubjects() + ms.objectsAndRelations[onrString] = fs + return fs, true, nil +} + +// AccessibleExpansionSubjects returns a TrackingSubjectSet representing the set of accessible subjects in the expansion. +func AccessibleExpansionSubjects(treeNode *v0.RelationTupleTreeNode) (TrackingSubjectSet, error) { + return populateFoundSubjects(treeNode.Expanded, treeNode) +} + +func populateFoundSubjects(rootONR *v0.ObjectAndRelation, treeNode *v0.RelationTupleTreeNode) (TrackingSubjectSet, error) { + resource := rootONR + if treeNode.Expanded != nil { + resource = treeNode.Expanded + } + + switch typed := treeNode.NodeType.(type) { + case *v0.RelationTupleTreeNode_IntermediateNode: + switch typed.IntermediateNode.Operation { + case v0.SetOperationUserset_UNION: + toReturn := NewTrackingSubjectSet() + for _, child := range typed.IntermediateNode.ChildNodes { + tss, err := populateFoundSubjects(resource, child) + if err != nil { + return nil, err + } + + toReturn.AddFrom(tss) + } + return toReturn, nil + + case v0.SetOperationUserset_INTERSECTION: + if len(typed.IntermediateNode.ChildNodes) == 0 { + return nil, fmt.Errorf("Found intersection with no children") + } + + firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0]) + if err != nil { + return nil, err + } + + toReturn := NewTrackingSubjectSet() + toReturn.AddFrom(firstChildSet) + + for _, child := range typed.IntermediateNode.ChildNodes[1:] { + childSet, err := populateFoundSubjects(rootONR, child) + if err != nil { + return nil, err + } + toReturn = toReturn.Intersect(childSet) + } + return toReturn, nil + + case v0.SetOperationUserset_EXCLUSION: + if len(typed.IntermediateNode.ChildNodes) == 0 { + return nil, fmt.Errorf("Found exclusion with no children") + } + + firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0]) + if err != nil { + return nil, err + } + + toReturn := NewTrackingSubjectSet() + toReturn.AddFrom(firstChildSet) + + for _, child := range typed.IntermediateNode.ChildNodes[1:] { + childSet, err := populateFoundSubjects(rootONR, child) + if err != nil { + return nil, err + } + toReturn = toReturn.Exclude(childSet) + } + + return toReturn, nil + + default: + panic("unknown expand operation") + } + + case *v0.RelationTupleTreeNode_LeafNode: + toReturn := NewTrackingSubjectSet() + for _, user := range typed.LeafNode.Users { + fs := NewFoundSubject(user.GetUserset()) + toReturn.Add(fs) + fs.relationships.Add(resource) + } + return toReturn, nil + + default: + panic("unknown TreeNode type") + } +} diff --git a/internal/membership/membership_test.go b/internal/membership/membership_test.go new file mode 100644 index 0000000000..f5c27de483 --- /dev/null +++ b/internal/membership/membership_test.go @@ -0,0 +1,383 @@ +package membership + +import ( + "sort" + "testing" + + v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/graph" + "github.com/authzed/spicedb/pkg/testutil" + "github.com/authzed/spicedb/pkg/tuple" +) + +var ( + ONR = tuple.ObjectAndRelation + Ellipsis = "..." +) + +var ( + _this *v0.ObjectAndRelation + + companyOwner = graph.Leaf(ONR("folder", "company", "owner"), + tuple.User(ONR("user", "owner", Ellipsis)), + ) + companyEditor = graph.Union(ONR("folder", "company", "editor"), + graph.Leaf(_this, tuple.User(ONR("user", "writer", Ellipsis))), + companyOwner, + ) + + auditorsOwner = graph.Leaf(ONR("folder", "auditors", "owner")) + + auditorsEditor = graph.Union(ONR("folder", "auditors", "editor"), + graph.Leaf(_this), + auditorsOwner, + ) + + auditorsViewerRecursive = graph.Union(ONR("folder", "auditors", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "auditor", "...")), + ), + auditorsEditor, + graph.Union(ONR("folder", "auditors", "viewer")), + ) + + companyViewerRecursive = graph.Union(ONR("folder", "company", "viewer"), + graph.Union(ONR("folder", "company", "viewer"), + auditorsViewerRecursive, + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + tuple.User(ONR("folder", "auditors", "viewer")), + ), + ), + companyEditor, + graph.Union(ONR("folder", "company", "viewer")), + ) +) + +func TestMembershipSetBasic(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + // Add some expansion trees. + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "owner"), companyOwner) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:owner") + + fse, ok, err := ms.AddExpansion(ONR("folder", "company", "editor"), companyEditor) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fse, "user:owner", "user:writer") + + fsv, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), companyViewerRecursive) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fsv, "folder:auditors#viewer", "user:auditor", "user:legal", "user:owner", "user:writer") +} + +func TestMembershipSetIntersectionBasic(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:legal") +} + +func TestMembershipSetExclusion(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + exclusion := graph.Exclusion(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), exclusion) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:owner") +} + +func TestMembershipSetExclusionMultiple(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + exclusion := graph.Exclusion(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + tuple.User(ONR("user", "third", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), exclusion) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:third") +} + +func TestMembershipSetExclusionMultipleWithWildcard(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + exclusion := graph.Exclusion(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), exclusion) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso) +} + +func TestMembershipSetExclusionMultipleMiddle(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + exclusion := graph.Exclusion(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + tuple.User(ONR("user", "third", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "another", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), exclusion) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:third", "user:legal") +} + +func TestMembershipSetIntersectionWithOneWildcard(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := + graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "*", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:legal") +} + +func TestMembershipSetIntersectionWithAllWildcardLeft(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := + graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "*", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:*", "user:owner") +} + +func TestMembershipSetIntersectionWithAllWildcardRight(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := + graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:*", "user:owner") +} + +func TestMembershipSetExclusionWithLeftWildcard(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + exclusion := + graph.Exclusion(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "*", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "legal", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), exclusion) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:*", "user:owner") +} + +func TestMembershipSetExclusionWithRightWildcard(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + exclusion := + graph.Exclusion(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), exclusion) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso) +} + +func TestMembershipSetIntersectionWithThreeWildcards(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := + graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:owner", "user:legal") +} + +func TestMembershipSetIntersectionWithOneBranchMissingWildcards(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := + graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + tuple.User(ONR("user", "*", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso, "user:owner") +} + +func TestMembershipSetIntersectionWithTwoBranchesMissingWildcards(t *testing.T) { + require := require.New(t) + ms := NewMembershipSet() + + intersection := + graph.Intersection(ONR("folder", "company", "viewer"), + graph.Leaf(_this, + tuple.User(ONR("user", "owner", "...")), + tuple.User(ONR("user", "legal", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "another", "...")), + ), + graph.Leaf(_this, + tuple.User(ONR("user", "*", "...")), + ), + ) + + fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) + require.True(ok) + require.NoError(err) + verifySubjects(t, require, fso) +} + +func verifySubjects(t *testing.T, require *require.Assertions, fs FoundSubjects, expected ...string) { + foundSubjects := []*v0.ObjectAndRelation{} + for _, found := range fs.ListFound() { + foundSubjects = append(foundSubjects, found.Subject()) + + _, ok := fs.LookupSubject(found.Subject()) + require.True(ok, "Could not find expected subject %s", found.Subject()) + } + + found := tuple.StringsONRs(foundSubjects) + sort.Strings(expected) + sort.Strings(found) + + testutil.RequireEqualEmptyNil(t, expected, found) +} diff --git a/internal/membership/trackingsubjectset.go b/internal/membership/trackingsubjectset.go new file mode 100644 index 0000000000..8f07dfa43d --- /dev/null +++ b/internal/membership/trackingsubjectset.go @@ -0,0 +1,185 @@ +package membership + +import ( + "fmt" + + v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" + + "github.com/authzed/spicedb/pkg/tuple" +) + +func isWildcard(subject *v0.ObjectAndRelation) bool { + return subject.ObjectId == tuple.PublicWildcard +} + +// TrackingSubjectSet defines a set that tracks accessible subjects and their associated +// relationships. +// +// NOTE: This is designed solely for the developer API and testing and should *not* be used in any +// performance sensitive code. +// +// NOTE: Unlike a traditional set, unions between wildcards and a concrete subject will result +// in *both* being present in the set, to maintain the proper relationship tracking and reporting +// of concrete subjects. +// +// TODO(jschorr): Once we have stable generics support, break into a standard SubjectSet and +// a tracking variant built on top of it. +type TrackingSubjectSet map[string]FoundSubject + +// NewTrackingSubjectSet creates a new TrackingSubjectSet, with optional initial subjects. +func NewTrackingSubjectSet(subjects ...FoundSubject) TrackingSubjectSet { + var toReturn TrackingSubjectSet = make(map[string]FoundSubject) + toReturn.Add(subjects...) + return toReturn +} + +// AddFrom adds the subjects found in the other set to this set. +func (tss TrackingSubjectSet) AddFrom(otherSet TrackingSubjectSet) { + for _, value := range otherSet { + tss.Add(value) + } +} + +// RemoveFrom removes any subjects found in the other set from this set. +func (tss TrackingSubjectSet) RemoveFrom(otherSet TrackingSubjectSet) { + for _, otherSAR := range otherSet { + tss.Remove(otherSAR.subject) + } +} + +// Add adds the given subjects to this set. +func (tss TrackingSubjectSet) Add(subjectsAndResources ...FoundSubject) { + tss.AddWithResources(subjectsAndResources, nil) +} + +// AddWithResources adds the given subjects to this set, with the additional resources appended +// for each subject to be included in their relationships. +func (tss TrackingSubjectSet) AddWithResources(subjectsAndResources []FoundSubject, additionalResources *tuple.ONRSet) { + for _, sar := range subjectsAndResources { + found, ok := tss[toKey(sar.subject)] + if ok { + tss[toKey(sar.subject)] = found.union(sar) + } else { + tss[toKey(sar.subject)] = sar + } + } +} + +// Get returns the found subject in the set, if any. +func (tss TrackingSubjectSet) Get(subject *v0.ObjectAndRelation) (FoundSubject, bool) { + found, ok := tss[toKey(subject)] + return found, ok +} + +// Contains returns true if the set contains the given subject. +func (tss TrackingSubjectSet) Contains(subject *v0.ObjectAndRelation) bool { + _, ok := tss[toKey(subject)] + return ok +} + +// removeExact removes the given subject(s) from the set. If the subject is a wildcard, only +// the exact matching wildcard will be removed. +func (tss TrackingSubjectSet) removeExact(subjects ...*v0.ObjectAndRelation) { + for _, subject := range subjects { + delete(tss, toKey(subject)) + } +} + +// Remove removes the given subject(s) from the set. If the subject is a wildcard, all matching +// subjects are removed. If the subject matches a wildcard in the existing set, then it is added +// to that wildcard as an exclusion. +func (tss TrackingSubjectSet) Remove(subjects ...*v0.ObjectAndRelation) { + for _, subject := range subjects { + delete(tss, toKey(subject)) + + // Delete any entries matching the wildcard, if applicable. + if isWildcard(subject) { + // Remove any subjects matching the type. + for key := range tss { + current := fromKey(key) + if current.Namespace == subject.Namespace { + delete(tss, key) + } + } + } else { + // Check for any wildcards matching and, if found, add to the exclusion. + for _, existing := range tss { + wildcardType, ok := existing.WildcardType() + if ok && wildcardType == subject.Namespace { + existing.excludedSubjects.Add(subject) + } + } + } + } +} + +// WithType returns any subjects in the set with the given object type. +func (tss TrackingSubjectSet) WithType(objectType string) []FoundSubject { + toReturn := make([]FoundSubject, 0, len(tss)) + for _, current := range tss { + if current.subject.Namespace == objectType { + toReturn = append(toReturn, current) + } + } + return toReturn +} + +// Exclude returns a new set that contains the items in this set minus those in the other set. +func (tss TrackingSubjectSet) Exclude(otherSet TrackingSubjectSet) TrackingSubjectSet { + newSet := NewTrackingSubjectSet() + newSet.AddFrom(tss) + newSet.RemoveFrom(otherSet) + return newSet +} + +// Intersect returns a new set that contains the items in this set *and* the other set. Note that +// if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found +// on the other side of the intersection. +func (tss TrackingSubjectSet) Intersect(otherSet TrackingSubjectSet) TrackingSubjectSet { + newSet := NewTrackingSubjectSet() + for _, current := range tss { + // Add directly if shared by both. + other, ok := otherSet.Get(current.subject) + if ok { + newSet.Add(current.intersect(other)) + } + + // If the current is a wildcard, and add any matching. + if isWildcard(current.subject) { + newSet.AddWithResources(otherSet.WithType(current.subject.Namespace), current.relationships) + } + } + + for _, current := range otherSet { + // If the current is a wildcard, add any matching. + if isWildcard(current.subject) { + newSet.AddWithResources(tss.WithType(current.subject.Namespace), current.relationships) + } + } + + return newSet +} + +// ToSlice returns a slice of all subjects found in the set. +func (tss TrackingSubjectSet) ToSlice() []FoundSubject { + toReturn := make([]FoundSubject, 0, len(tss)) + for _, current := range tss { + toReturn = append(toReturn, current) + } + return toReturn +} + +// ToFoundSubjects returns the set as a FoundSubjects struct. +func (tss TrackingSubjectSet) ToFoundSubjects() FoundSubjects { + return FoundSubjects{tss} +} + +func toKey(subject *v0.ObjectAndRelation) string { + return fmt.Sprintf("%s %s %s", subject.Namespace, subject.ObjectId, subject.Relation) +} + +func fromKey(key string) *v0.ObjectAndRelation { + subject := &v0.ObjectAndRelation{} + fmt.Sscanf(key, "%s %s %s", &subject.Namespace, &subject.ObjectId, &subject.Relation) + return subject +} diff --git a/internal/membership/trackingsubjectset_test.go b/internal/membership/trackingsubjectset_test.go new file mode 100644 index 0000000000..bb1860dc23 --- /dev/null +++ b/internal/membership/trackingsubjectset_test.go @@ -0,0 +1,342 @@ +package membership + +import ( + "testing" + + v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/tuple" +) + +func set(subjects ...*v0.ObjectAndRelation) TrackingSubjectSet { + newSet := NewTrackingSubjectSet() + for _, subject := range subjects { + newSet.Add(NewFoundSubject(subject)) + } + return newSet +} + +func union(firstSet TrackingSubjectSet, sets ...TrackingSubjectSet) TrackingSubjectSet { + current := firstSet + for _, set := range sets { + current.AddFrom(set) + } + return current +} + +func intersect(firstSet TrackingSubjectSet, sets ...TrackingSubjectSet) TrackingSubjectSet { + current := firstSet + for _, set := range sets { + current = current.Intersect(set) + } + return current +} + +func exclude(firstSet TrackingSubjectSet, sets ...TrackingSubjectSet) TrackingSubjectSet { + current := firstSet + for _, set := range sets { + current = current.Exclude(set) + } + return current +} + +func fs(subjectType string, subjectID string, subjectRel string, excludedSubjects ...*v0.ObjectAndRelation) FoundSubject { + return FoundSubject{ + subject: ONR(subjectType, subjectID, subjectRel), + excludedSubjects: tuple.NewONRSet(excludedSubjects...), + relationships: tuple.NewONRSet(), + } +} + +func TestTrackingSubjectSet(t *testing.T) { + testCases := []struct { + name string + set TrackingSubjectSet + expected []FoundSubject + }{ + { + "simple set", + set(ONR("user", "user1", "...")), + []FoundSubject{fs("user", "user1", "...")}, + }, + { + "simple union", + union( + set(ONR("user", "user1", "...")), + set(ONR("user", "user2", "...")), + set(ONR("user", "user3", "...")), + ), + []FoundSubject{ + fs("user", "user1", "..."), + fs("user", "user2", "..."), + fs("user", "user3", "..."), + }, + }, + { + "simple intersection", + intersect( + set( + (ONR("user", "user1", "...")), + (ONR("user", "user2", "...")), + ), + set( + (ONR("user", "user2", "...")), + (ONR("user", "user3", "...")), + ), + set( + (ONR("user", "user2", "...")), + (ONR("user", "user4", "...")), + ), + ), + []FoundSubject{fs("user", "user2", "...")}, + }, + { + "empty intersection", + intersect( + set( + (ONR("user", "user1", "...")), + (ONR("user", "user2", "...")), + ), + set( + (ONR("user", "user3", "...")), + (ONR("user", "user4", "...")), + ), + ), + []FoundSubject{}, + }, + { + "simple exclusion", + exclude( + set( + (ONR("user", "user1", "...")), + (ONR("user", "user2", "...")), + ), + set(ONR("user", "user2", "...")), + set(ONR("user", "user3", "...")), + ), + []FoundSubject{fs("user", "user1", "...")}, + }, + { + "empty exclusion", + exclude( + set( + (ONR("user", "user1", "...")), + (ONR("user", "user2", "...")), + ), + set(ONR("user", "user1", "...")), + set(ONR("user", "user2", "...")), + ), + []FoundSubject{}, + }, + { + "wildcard left side union", + union( + set( + (ONR("user", "*", "...")), + ), + set(ONR("user", "user1", "...")), + ), + []FoundSubject{ + fs("user", "*", "..."), + fs("user", "user1", "..."), + }, + }, + { + "wildcard right side union", + union( + set(ONR("user", "user1", "...")), + set( + (ONR("user", "*", "...")), + ), + ), + []FoundSubject{ + fs("user", "*", "..."), + fs("user", "user1", "..."), + }, + }, + { + "wildcard left side exclusion", + exclude( + set( + (ONR("user", "*", "...")), + (ONR("user", "user2", "...")), + ), + set(ONR("user", "user1", "...")), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "...")), + fs("user", "user2", "..."), + }, + }, + { + "wildcard right side exclusion", + exclude( + set( + (ONR("user", "user2", "...")), + ), + set(ONR("user", "*", "...")), + ), + []FoundSubject{}, + }, + { + "wildcard right side concrete exclusion", + exclude( + set( + (ONR("user", "*", "...")), + ), + set(ONR("user", "user1", "...")), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "...")), + }, + }, + { + "wildcard both sides exclusion", + exclude( + set( + (ONR("user", "user2", "...")), + (ONR("user", "*", "...")), + ), + set(ONR("user", "*", "...")), + ), + []FoundSubject{}, + }, + { + "wildcard left side intersection", + intersect( + set( + (ONR("user", "*", "...")), + (ONR("user", "user2", "...")), + ), + set(ONR("user", "user1", "...")), + ), + []FoundSubject{ + fs("user", "user1", "..."), + }, + }, + { + "wildcard right side intersection", + intersect( + set(ONR("user", "user1", "...")), + set( + (ONR("user", "*", "...")), + (ONR("user", "user2", "...")), + ), + ), + []FoundSubject{ + fs("user", "user1", "..."), + }, + }, + { + "wildcard both sides intersection", + intersect( + set( + (ONR("user", "*", "...")), + (ONR("user", "user1", "..."))), + set( + (ONR("user", "*", "...")), + (ONR("user", "user2", "...")), + ), + ), + []FoundSubject{ + fs("user", "*", "..."), + fs("user", "user1", "..."), + fs("user", "user2", "..."), + }, + }, + { + "wildcard with exclusions union", + union( + NewTrackingSubjectSet(fs("user", "*", "...", ONR("user", "user1", "..."))), + NewTrackingSubjectSet(fs("user", "*", "...", ONR("user", "user2", "..."))), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "..."), ONR("user", "user2", "...")), + }, + }, + { + "wildcard with exclusions intersection", + intersect( + NewTrackingSubjectSet(fs("user", "*", "...", ONR("user", "user1", "..."))), + NewTrackingSubjectSet(fs("user", "*", "...", ONR("user", "user2", "..."))), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "..."), ONR("user", "user2", "...")), + }, + }, + { + "wildcard with exclusions exclusion", + exclude( + NewTrackingSubjectSet( + fs("user", "*", "...", ONR("user", "user1", "...")), + ), + NewTrackingSubjectSet(fs("user", "*", "...", ONR("user", "user2", "..."))), + ), + []FoundSubject{}, + }, + { + "wildcard with exclusions excluded user added", + exclude( + NewTrackingSubjectSet( + fs("user", "*", "...", ONR("user", "user1", "...")), + ), + NewTrackingSubjectSet(fs("user", "user2", "...")), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "..."), ONR("user", "user2", "...")), + }, + }, + { + "wildcard multiple exclusions", + exclude( + NewTrackingSubjectSet( + fs("user", "*", "...", ONR("user", "user1", "...")), + ), + NewTrackingSubjectSet(fs("user", "user2", "...")), + NewTrackingSubjectSet(fs("user", "user3", "...")), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "..."), ONR("user", "user2", "..."), ONR("user", "user3", "...")), + }, + }, + { + "intersection of exclusions", + intersect( + NewTrackingSubjectSet( + fs("user", "*", "...", ONR("user", "user1", "...")), + ), + NewTrackingSubjectSet( + fs("user", "*", "...", ONR("user", "user2", "...")), + ), + ), + []FoundSubject{ + fs("user", "*", "...", ONR("user", "user1", "..."), ONR("user", "user2", "...")), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + + for _, fs := range tc.expected { + _, isWildcard := fs.WildcardType() + if isWildcard { + found, ok := tc.set.Get(fs.subject) + require.True(ok, "missing expected subject %s", fs.subject) + + expectedExcluded := fs.excludedSubjects.AsSlice() + foundExcluded := found.excludedSubjects.AsSlice() + require.Len(fs.excludedSubjects.Subtract(found.excludedSubjects).AsSlice(), 0, "mismatch on excluded subjects on %s: expected: %s, found: %s", fs.subject, expectedExcluded, foundExcluded) + require.Len(found.excludedSubjects.Subtract(fs.excludedSubjects).AsSlice(), 0, "mismatch on excluded subjects on %s: expected: %s, found: %s", fs.subject, expectedExcluded, foundExcluded) + } else { + require.True(tc.set.Contains(fs.subject), "missing expected subject %s", fs.subject) + } + tc.set.removeExact(fs.subject) + } + + require.Len(tc.set, 0) + }) + } +} diff --git a/internal/services/consistency_test.go b/internal/services/consistency_test.go index ee07ff9193..a1de2c28b5 100644 --- a/internal/services/consistency_test.go +++ b/internal/services/consistency_test.go @@ -23,12 +23,12 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/caching" "github.com/authzed/spicedb/internal/dispatch/graph" + "github.com/authzed/spicedb/internal/membership" "github.com/authzed/spicedb/internal/namespace" v1 "github.com/authzed/spicedb/internal/proto/dispatch/v1" v0svc "github.com/authzed/spicedb/internal/services/v0" v1svc "github.com/authzed/spicedb/internal/services/v1" "github.com/authzed/spicedb/internal/testfixtures" - graphpkg "github.com/authzed/spicedb/pkg/graph" "github.com/authzed/spicedb/pkg/testutil" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/validationfile" @@ -133,18 +133,36 @@ func runAssertions(t *testing.T, rel := tuple.Parse(assertTrueRel) require.NotNil(t, rel) + // Ensure the assertion passes Check. result, err := tester.Check(context.Background(), rel.ObjectAndRelation, rel.User.GetUserset(), revision) require.NoError(t, err) require.True(t, result, "Assertion `%s` returned false; true expected", tuple.String(rel)) + + // Ensure the assertion passes Lookup. + resolvedObjectIds, err := tester.Lookup(context.Background(), &v0.RelationReference{ + Namespace: rel.ObjectAndRelation.Namespace, + Relation: rel.ObjectAndRelation.Relation, + }, rel.User.GetUserset(), revision) + require.NoError(t, err) + require.Contains(t, resolvedObjectIds, rel.ObjectAndRelation.ObjectId, "Missing object %s in lookup for assertion %s", rel.ObjectAndRelation, rel) } for _, assertFalseRel := range parsedFile.Assertions.AssertFalse { rel := tuple.Parse(assertFalseRel) require.NotNil(t, rel) + // Ensure the assertion does not pass Check. result, err := tester.Check(context.Background(), rel.ObjectAndRelation, rel.User.GetUserset(), revision) require.NoError(t, err) require.False(t, result, "Assertion `%s` returned true; false expected", tuple.String(rel)) + + // Ensure the assertion does not pass Lookup. + resolvedObjectIds, err := tester.Lookup(context.Background(), &v0.RelationReference{ + Namespace: rel.ObjectAndRelation.Namespace, + Relation: rel.ObjectAndRelation.Relation, + }, rel.User.GetUserset(), revision) + require.NoError(t, err) + require.NotContains(t, resolvedObjectIds, rel.ObjectAndRelation.ObjectId, "Found unexpected object %s in lookup for false assertion %s", rel.ObjectAndRelation, rel) } } } @@ -241,12 +259,18 @@ func runConsistencyTests(t *testing.T, // Collect the set of objects and subjects. objectsPerNamespace := setmultimap.New() subjects := tuple.NewONRSet() + subjectsNoWildcard := tuple.NewONRSet() for _, tpl := range fullyResolved.Tuples { objectsPerNamespace.Put(tpl.ObjectAndRelation.Namespace, tpl.ObjectAndRelation.ObjectId) switch m := tpl.User.UserOneof.(type) { case *v0.User_Userset: + // NOTE: we skip adding wildcards as subjects or object IDs. subjects.Add(m.Userset) + if m.Userset.ObjectId != tuple.PublicWildcard { + objectsPerNamespace.Put(m.Userset.Namespace, m.Userset.ObjectId) + subjectsNoWildcard.Add(m.Userset) + } } } @@ -263,11 +287,18 @@ func runConsistencyTests(t *testing.T, for _, objectID := range allObjectIds { objectIDStr := objectID.(string) + onr := &v0.ObjectAndRelation{ Namespace: nsDef.Name, Relation: relation.Name, ObjectId: objectIDStr, } + + if subject.ObjectId == tuple.PublicWildcard { + accessibilitySet.Set(onr, subject, isWildcard) + continue + } + hasPermission, err := tester.Check(context.Background(), onr, subject, revision) require.NoError(t, err) @@ -293,6 +324,7 @@ func runConsistencyTests(t *testing.T, accessibilitySet: accessibilitySet, dispatch: dispatch, subjects: subjects, + subjectsNoWildcard: subjectsNoWildcard, tester: tester, revision: revision, } @@ -325,13 +357,9 @@ func accessibleViaWildcardOnly(t *testing.T, dispatch dispatch.Dispatcher, onr * }) require.NoError(t, err) - subjectsFound := graphpkg.Simplify(resp.TreeNode) - subjectsFoundSet := tuple.NewONRSet() - for _, subjectUser := range subjectsFound { - subjectsFoundSet.Add(subjectUser.GetUserset()) - } - - return !subjectsFoundSet.Has(subject) + subjectsFound, err := membership.AccessibleExpansionSubjects(resp.TreeNode) + require.NoError(t, err) + return !subjectsFound.Contains(subject) } type validationContext struct { @@ -339,6 +367,7 @@ type validationContext struct { objectsPerNamespace *setmultimap.MultiMap subjects *tuple.ONRSet + subjectsNoWildcard *tuple.ONRSet accessibilitySet *accessibilitySet dispatch dispatch.Dispatcher @@ -391,12 +420,13 @@ func validateValidation(t *testing.T, dev v0.DeveloperServiceServer, reqContext require.Nil(t, err) for _, validationStr := range validationStrings { - subjectONR, err := validationStr.Subject() + foundSubject, err := validationStr.Subject() require.Nil(t, err) require.True(t, - vctx.accessibilitySet.GetIsMember(onr, subjectONR) == isMember, + (vctx.accessibilitySet.GetIsMember(onr, foundSubject.Subject) == isMember || + vctx.accessibilitySet.GetIsMember(onr, foundSubject.Subject) == isWildcard), "Generated expected relations returned inaccessible member %s for %s", - tuple.StringONR(subjectONR), + tuple.StringONR(foundSubject.Subject), tuple.StringONR(onr)) } } @@ -408,7 +438,7 @@ func validateValidation(t *testing.T, dev v0.DeveloperServiceServer, reqContext for _, result := range vctx.accessibilitySet.results { if result.isMember == isMember || result.isMember == isMemberViaWildcard { trueAssertions = append(trueAssertions, fmt.Sprintf("%s@%s", tuple.StringONR(result.object), tuple.StringONR(result.subject))) - } else { + } else if result.isMember == isNotMember { falseAssertions = append(falseAssertions, fmt.Sprintf("%s@%s", tuple.StringONR(result.object), tuple.StringONR(result.subject))) } } @@ -434,7 +464,7 @@ func validateValidation(t *testing.T, dev v0.DeveloperServiceServer, reqContext func validateEditChecks(t *testing.T, dev v0.DeveloperServiceServer, reqContext *v0.RequestContext, vctx *validationContext) { for _, nsDef := range vctx.fullyResolved.NamespaceDefinitions { for _, relation := range nsDef.Relation { - for _, subject := range vctx.subjects.AsSlice() { + for _, subject := range vctx.subjectsNoWildcard.AsSlice() { objectRelation := &v0.RelationReference{ Namespace: nsDef.Name, Relation: relation.Name, @@ -490,7 +520,7 @@ func validateEditChecks(t *testing.T, dev v0.DeveloperServiceServer, reqContext func validateLookup(t *testing.T, vctx *validationContext) { for _, nsDef := range vctx.fullyResolved.NamespaceDefinitions { for _, relation := range nsDef.Relation { - for _, subject := range vctx.subjects.AsSlice() { + for _, subject := range vctx.subjectsNoWildcard.AsSlice() { objectRelation := &v0.RelationReference{ Namespace: nsDef.Name, Relation: relation.Name, @@ -619,38 +649,76 @@ func validateExpansionSubjects(t *testing.T, vctx *validationContext) { }) vrequire.NoError(err) - subjectsFound := graphpkg.Simplify(resp.TreeNode) - subjectsFoundSet := tuple.NewONRSet() - - for _, subjectUser := range subjectsFound { - subjectsFoundSet.Add(subjectUser.GetUserset()) - } + subjectsFoundSet, err := membership.AccessibleExpansionSubjects(resp.TreeNode) + vrequire.NoError(err) // Ensure all terminal subjects were found in the expansion. - vrequire.EqualValues(0, accessibleTerminalSubjects.Subtract(subjectsFoundSet).Length()) + vrequire.EqualValues(0, len(accessibleTerminalSubjects.Exclude(subjectsFoundSet).ToSlice()), "Expected %s, Found: %s", accessibleTerminalSubjects.ToSlice(), subjectsFoundSet.ToSlice()) // Ensure every subject found matches Check. - for _, subjectUser := range subjectsFound { - subject := subjectUser.GetUserset() + for _, foundSubject := range subjectsFoundSet.ToSlice() { + excludedSubjects, isWildcard := foundSubject.ExcludedSubjectsFromWildcard() - isMember, err := vctx.tester.Check(context.Background(), - &v0.ObjectAndRelation{ - Namespace: nsDef.Name, - Relation: relation.Name, - ObjectId: objectIDStr, - }, - subject, - vctx.revision, - ) - vrequire.NoError(err) - vrequire.True( - isMember, - "Found Check under Expand failure for relation %s:%s#%s and subject %s", - nsDef.Name, - objectIDStr, - relation.Name, - tuple.StringONR(subject), - ) + // If the subject is a wildcard, then check every matching subject. + if isWildcard { + excludedSubjectsSet := tuple.NewONRSet(excludedSubjects...) + + allSubjectObjectIds, ok := vctx.objectsPerNamespace.Get(foundSubject.Subject().Namespace) + if !ok { + continue + } + + for _, subjectID := range allSubjectObjectIds { + subjectIDStr := subjectID.(string) + localSubject := &v0.ObjectAndRelation{ + Namespace: foundSubject.Subject().Namespace, + Relation: foundSubject.Subject().Relation, + ObjectId: subjectIDStr, + } + isMember, err := vctx.tester.Check(context.Background(), + &v0.ObjectAndRelation{ + Namespace: nsDef.Name, + Relation: relation.Name, + ObjectId: objectIDStr, + }, + localSubject, + vctx.revision, + ) + vrequire.NoError(err) + vrequire.Equal( + !excludedSubjectsSet.Has(localSubject), + isMember, + "Found Check under Expand failure for relation %s:%s#%s and subject %s (checked because of wildcard %s). Expected: %v, Found: %v", + nsDef.Name, + objectIDStr, + relation.Name, + tuple.StringONR(localSubject), + tuple.StringONR(foundSubject.Subject()), + !excludedSubjectsSet.Has(localSubject), + isMember, + ) + } + } else { + // Otherwise, check directly. + isMember, err := vctx.tester.Check(context.Background(), + &v0.ObjectAndRelation{ + Namespace: nsDef.Name, + Relation: relation.Name, + ObjectId: objectIDStr, + }, + foundSubject.Subject(), + vctx.revision, + ) + vrequire.NoError(err) + vrequire.True( + isMember, + "Found Check under Expand failure for relation %s:%s#%s and subject %s", + nsDef.Name, + objectIDStr, + relation.Name, + tuple.StringONR(foundSubject.Subject()), + ) + } } }) } @@ -673,6 +741,7 @@ const ( isNotMember isMemberStatus = 0 isMember isMemberStatus = 1 isMemberViaWildcard isMemberStatus = 2 + isWildcard isMemberStatus = 3 ) type checkResult struct { @@ -707,16 +776,15 @@ func (rs *accessibilitySet) GetIsMember(object *v0.ObjectAndRelation, subject *v } } - panic("Missing matching result") + panic(fmt.Sprintf("Missing matching result for %s %s", object, subject)) } -// AccessibleObjectIDs returns the set of object IDs accessible for the given subject from the given relation on the namespace, -// *not* including those accessible solely via wildcard. +// AccessibleObjectIDs returns the set of object IDs accessible for the given subject from the given relation on the namespace. func (rs *accessibilitySet) AccessibleObjectIDs(namespaceName string, relationName string, subject *v0.ObjectAndRelation) []string { var accessibleObjectIDs []string subjectStr := tuple.StringONR(subject) for _, result := range rs.results { - if result.isMember != isMember { + if result.isMember == isNotMember { continue } @@ -727,17 +795,16 @@ func (rs *accessibilitySet) AccessibleObjectIDs(namespaceName string, relationNa return accessibleObjectIDs } -// AccessibleTerminalSubjects returns the set of terminal subjects with accessible for the given object on the given relation on the namespace, -// *not* including those accessible solely via wildcard. -func (rs *accessibilitySet) AccessibleTerminalSubjects(namespaceName string, relationName string, objectIDStr string) *tuple.ONRSet { - accessibleSubjects := tuple.NewONRSet() +// AccessibleTerminalSubjects returns the set of terminal subjects with accessible for the given object on the given relation on the namespace +func (rs *accessibilitySet) AccessibleTerminalSubjects(namespaceName string, relationName string, objectIDStr string) membership.TrackingSubjectSet { + accessibleSubjects := membership.NewTrackingSubjectSet() for _, result := range rs.results { - if result.isMember != isMember { + if result.isMember == isNotMember || result.isMember == isWildcard { continue } if result.object.Namespace == namespaceName && result.object.Relation == relationName && result.object.ObjectId == objectIDStr && result.subject.Relation == "..." { - accessibleSubjects.Add(result.subject) + accessibleSubjects.Add(membership.NewFoundSubject(result.subject, result.object)) } } return accessibleSubjects diff --git a/internal/services/testconfigs/bannedintersectwildcard.yaml b/internal/services/testconfigs/bannedintersectwildcard.yaml new file mode 100644 index 0000000000..634510e931 --- /dev/null +++ b/internal/services/testconfigs/bannedintersectwildcard.yaml @@ -0,0 +1,25 @@ +--- +schema: >- + definition test/user {} + + definition test/resource { + relation viewer: test/user | test/user:* + relation banned1: test/user | test/user:* + relation banned2: test/user | test/user:* + + permission banned = banned1 & banned2 + permission view = viewer - banned + } +relationships: | + test/resource:first#viewer@test/user:* + test/resource:first#banned1@test/user:somegal + test/resource:first#banned2@test/user:somegal + test/resource:first#banned1@test/user:anotheruser +assertions: + assertTrue: + - "test/resource:first#view@test/user:anotheruser" + - "test/resource:first#view@test/user:editordude" + - "test/resource:first#view@test/user:aseconduser" + - "test/resource:first#view@test/user:athirduser" + assertFalse: + - "test/resource:first#view@test/user:somegal" diff --git a/internal/services/testconfigs/simplewildcard.yaml b/internal/services/testconfigs/simplewildcard.yaml new file mode 100644 index 0000000000..7af9470557 --- /dev/null +++ b/internal/services/testconfigs/simplewildcard.yaml @@ -0,0 +1,16 @@ +--- +schema: >- + definition test/user {} + + definition test/resource { + relation viewer: test/user | test/user:* + } +relationships: | + test/resource:first#viewer@test/user:* + test/resource:first#viewer@test/user:concreteguy +assertions: + assertTrue: + - test/resource:first#viewer@test/user:concreteguy + - test/resource:first#viewer@test/user:anotheruser + - test/resource:first#viewer@test/user:aseconduser + - test/resource:first#viewer@test/user:athirduser diff --git a/internal/services/testconfigs/wildcardnested.yaml b/internal/services/testconfigs/wildcardnested.yaml new file mode 100644 index 0000000000..63ec04e20c --- /dev/null +++ b/internal/services/testconfigs/wildcardnested.yaml @@ -0,0 +1,35 @@ +--- +schema: >- + definition test/user {} + + definition test/resource { + relation viewer: test/user | test/user:* + relation banned: test/user + relation mustbehere: test/user + + permission view = viewer - banned + permission specialview = view & mustbehere + } +relationships: | + test/resource:first#viewer@test/user:* + test/resource:first#banned@test/user:bannedguy + test/resource:first#mustbehere@test/user:somegal +assertions: + assertTrue: + - test/resource:first#viewer@test/user:somegal + - test/resource:first#viewer@test/user:anotherperson + - test/resource:first#viewer@test/user:thirduser + - test/resource:first#viewer@test/user:bannedguy + + - test/resource:first#view@test/user:somegal + - test/resource:first#view@test/user:anotherperson + - test/resource:first#view@test/user:thirduser + + - test/resource:first#mustbehere@test/user:somegal + - test/resource:first#specialview@test/user:somegal + assertFalse: + - test/resource:first#view@test/user:bannedguy + + - test/resource:first#specialview@test/user:bannedguy + - test/resource:first#specialview@test/user:anotherperson + - test/resource:first#specialview@test/user:thirduser diff --git a/internal/services/testconfigs/wildcardwithintersection.yaml b/internal/services/testconfigs/wildcardwithintersection.yaml new file mode 100644 index 0000000000..c87ef6c0a2 --- /dev/null +++ b/internal/services/testconfigs/wildcardwithintersection.yaml @@ -0,0 +1,25 @@ +--- +schema: >- + definition test/user {} + + definition test/resource { + relation viewer: test/user | test/user:* + relation reader: test/user | test/user:* + + permission view = viewer & reader + } +relationships: | + test/resource:first#reader@test/user:* + test/resource:first#viewer@test/user:somegal + + test/resource:second#reader@test/user:* + test/resource:second#viewer@test/user:* +assertions: + assertTrue: + - test/resource:first#view@test/user:somegal + - test/resource:second#view@test/user:editordude + - test/resource:second#view@test/user:seconduser + assertFalse: + - "test/resource:first#view@test/user:editordude" + - "test/resource:first#view@test/user:anotheruser" + - "test/resource:first#view@test/user:aseconduser" diff --git a/internal/services/testconfigs/wildcardwithrightsideexclusion.yaml b/internal/services/testconfigs/wildcardwithrightsideexclusion.yaml new file mode 100644 index 0000000000..2304e1d1e9 --- /dev/null +++ b/internal/services/testconfigs/wildcardwithrightsideexclusion.yaml @@ -0,0 +1,27 @@ +--- +schema: >- + definition test/user {} + + definition test/resource { + relation viewer: test/user + relation banned: test/user | test/user:* + + permission view = viewer - banned + } +relationships: | + test/resource:first#banned@test/user:* + test/resource:first#viewer@test/user:somegal + + test/resource:second#banned@test/user:otherperson + test/resource:second#viewer@test/user:somegal +assertions: + assertTrue: + - test/resource:first#viewer@test/user:somegal + - test/resource:second#viewer@test/user:somegal + assertFalse: + - "test/resource:first#view@test/user:editordude" + - "test/resource:first#view@test/user:anotheruser" + - "test/resource:first#view@test/user:aseconduser" + - "test/resource:first#view@test/user:athirduser" + - "test/resource:first#view@test/user:somegal" + - "test/resource:second#view@test/user:otherperson" diff --git a/internal/services/v0/acl.go b/internal/services/v0/acl.go index f39e8994bf..5464fcb8d0 100644 --- a/internal/services/v0/acl.go +++ b/internal/services/v0/acl.go @@ -433,6 +433,9 @@ func rewriteACLError(ctx context.Context, err error) error { case errors.As(err, &graph.ErrRequestCanceled{}): return status.Errorf(codes.Canceled, "request canceled: %s", err) + case errors.As(err, &graph.ErrInvalidArgument{}): + return status.Errorf(codes.InvalidArgument, "%s", err) + case errors.As(err, &datastore.ErrInvalidRevision{}): return status.Errorf(codes.OutOfRange, "invalid zookie: %s", err) diff --git a/internal/services/v0/acl_test.go b/internal/services/v0/acl_test.go index 50ac4955a2..c7324b9b52 100644 --- a/internal/services/v0/acl_test.go +++ b/internal/services/v0/acl_test.go @@ -25,9 +25,9 @@ import ( "github.com/authzed/spicedb/internal/datastore" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch/graph" + "github.com/authzed/spicedb/internal/membership" "github.com/authzed/spicedb/internal/namespace" tf "github.com/authzed/spicedb/internal/testfixtures" - g "github.com/authzed/spicedb/pkg/graph" ns "github.com/authzed/spicedb/pkg/namespace" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/zookie" @@ -522,6 +522,13 @@ func TestCheck(t *testing.T) { {ONR("user", "aasdasd", "..."), false}, }, }, + { + ONR("document", "somedoc", "owner"), + codes.InvalidArgument, + []checkTest{ + {ONR("user", "*", "..."), false}, + }, + }, } for _, delta := range testTimedeltas { @@ -643,7 +650,9 @@ func TestExpand(t *testing.T) { require.NotNil(expanded.Revision) require.NotEmpty(expanded.Revision.Token) - require.Equal(tc.expandRelatedCount, len(g.Simplify(expanded.TreeNode))) + found, err := membership.AccessibleExpansionSubjects(expanded.TreeNode) + require.NoError(err) + require.Equal(tc.expandRelatedCount, len(found)) dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(trailer, responsemeta.DispatchedOperationsCount) require.NoError(err) @@ -773,6 +782,12 @@ func TestLookup(t *testing.T) { []string{}, codes.FailedPrecondition, }, + { + RR("document", "viewer_and_editor_derived"), + ONR("user", "*", "..."), + []string{}, + codes.InvalidArgument, + }, } for _, delta := range testTimedeltas { diff --git a/internal/services/v0/developer.go b/internal/services/v0/developer.go index 18d9b1c8b3..bcd3fc4422 100644 --- a/internal/services/v0/developer.go +++ b/internal/services/v0/developer.go @@ -16,9 +16,9 @@ import ( "google.golang.org/grpc" "google.golang.org/protobuf/encoding/prototext" + "github.com/authzed/spicedb/internal/membership" v1 "github.com/authzed/spicedb/internal/proto/dispatch/v1" "github.com/authzed/spicedb/internal/sharederrors" - "github.com/authzed/spicedb/pkg/membership" "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/validationfile" @@ -331,7 +331,7 @@ func generateValidation(membershipSet *membership.Set) (string, error) { for _, fs := range foundSubjects.ListFound() { strs = append(strs, fmt.Sprintf("[%s] is %s", - tuple.StringONR(fs.Subject()), + fs.ToValidationString(), strings.Join(wrapRelationships(tuple.StringsONRs(fs.Relationships())), "/"), )) } @@ -425,7 +425,7 @@ func validateSubjects(onr *v0.ObjectAndRelation, fs membership.FoundSubjects, va // Verify that every referenced subject is found in the membership. encounteredSubjects := map[string]struct{}{} for _, validationString := range validationStrings { - subjectONR, err := validationString.Subject() + expectedSubject, err := validationString.Subject() if err != nil { failures = append(failures, &v0.DeveloperError{ Message: fmt.Sprintf("For object and permission/relation `%s`, %s", tuple.StringONR(onr), err.Error()), @@ -436,10 +436,11 @@ func validateSubjects(onr *v0.ObjectAndRelation, fs membership.FoundSubjects, va continue } - if subjectONR == nil { + if expectedSubject == nil { continue } + subjectONR := expectedSubject.Subject encounteredSubjects[tuple.StringONR(subjectONR)] = struct{}{} expectedRelationships, err := validationString.ONRS() @@ -482,6 +483,38 @@ func validateSubjects(onr *v0.ObjectAndRelation, fs membership.FoundSubjects, va Context: string(validationString), }) } + + // Verify exclusions are the same, if any. + foundExcludedSubjects, isWildcard := subject.ExcludedSubjectsFromWildcard() + expectedExcludedSubjects := expectedSubject.Exceptions + if isWildcard { + expectedExcludedONRStrings := tuple.StringsONRs(expectedExcludedSubjects) + foundExcludedONRStrings := tuple.StringsONRs(foundExcludedSubjects) + if !cmp.Equal(expectedExcludedONRStrings, foundExcludedONRStrings) { + failures = append(failures, &v0.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, found different excluded subjects for subject `%s`: Specified: `%s`, Computed: `%s`", + tuple.StringONR(onr), + tuple.StringONR(subjectONR), + strings.Join(wrapRelationships(expectedExcludedONRStrings), ", "), + strings.Join(wrapRelationships(foundExcludedONRStrings), ", "), + ), + Source: v0.DeveloperError_VALIDATION_YAML, + Kind: v0.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, + Context: string(validationString), + }) + } + } else { + if len(expectedExcludedSubjects) > 0 { + failures = append(failures, &v0.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, found unexpected excluded subjects", + tuple.StringONR(onr), + ), + Source: v0.DeveloperError_VALIDATION_YAML, + Kind: v0.DeveloperError_EXTRA_RELATIONSHIP_FOUND, + Context: string(validationString), + }) + } + } } // Verify that every subject found was referenced. diff --git a/internal/services/v0/developer_test.go b/internal/services/v0/developer_test.go index a2f177a2d1..07e0e603f5 100644 --- a/internal/services/v0/developer_test.go +++ b/internal/services/v0/developer_test.go @@ -634,6 +634,59 @@ assertFalse: `document:somedoc#view: - '[user:*] is ' - '[user:jimmy] is ' +`, + }, + { + "wildcard exclusion", + ` + definition user {} + definition document { + relation banned: user + relation viewer: user | user:* + permission view = viewer - banned + } + `, + []*v0.RelationTuple{ + tuple.MustParse("document:somedoc#banned@user:jimmy"), + tuple.MustParse("document:somedoc#viewer@user:*"), + }, + `"document:somedoc#view": +- "[user:* - {user:jimmy}] is "`, + `assertTrue: +- document:somedoc#view@user:somegal +assertFalse: +- document:somedoc#view@user:jimmy`, + nil, + `document:somedoc#view: +- '[user:* - {user:jimmy}] is ' +`, + }, + { + "wildcard exclusion under intersection", + ` + definition user {} + definition document { + relation banned: user + relation viewer: user | user:* + relation other: user + permission view = (viewer - banned) & (viewer - other) + } + `, + []*v0.RelationTuple{ + tuple.MustParse("document:somedoc#other@user:sarah"), + tuple.MustParse("document:somedoc#banned@user:jimmy"), + tuple.MustParse("document:somedoc#viewer@user:*"), + }, + `"document:somedoc#view": +- "[user:* - {user:jimmy}] is "`, + `assertTrue: +- document:somedoc#view@user:somegal +assertFalse: +- document:somedoc#view@user:jimmy +- document:somedoc#view@user:sarah`, + nil, + `document:somedoc#view: +- '[user:* - {user:jimmy, user:sarah}] is ' `, }, } diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 5e3b2f6a60..e19eeffd71 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -228,6 +228,13 @@ func TestCheckPermissions(t *testing.T) { v1.CheckPermissionResponse_PERMISSIONSHIP_UNSPECIFIED, codes.InvalidArgument, }, + { + obj("document", "something"), + "viewer", + sub("user", "*", ""), + v1.CheckPermissionResponse_PERMISSIONSHIP_UNSPECIFIED, + codes.InvalidArgument, + }, } for _, delta := range testTimedeltas { @@ -390,6 +397,12 @@ func TestLookupResources(t *testing.T) { []string{}, codes.FailedPrecondition, }, + { + "document", "viewer_and_editor_derived", + sub("user", "*", ""), + []string{}, + codes.InvalidArgument, + }, } for _, delta := range testTimedeltas { diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index 2ee57a2858..b5e695517b 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -207,7 +207,7 @@ func (ps *permissionServer) WriteRelationships(ctx context.Context, req *v1.Writ if isAllowed != namespace.PublicSubjectAllowed { return nil, status.Errorf( codes.InvalidArgument, - "wildcardsubjects of type %s are not allowed on %v", + "wildcard subjects of type %s are not allowed on %v", update.Relationship.Subject.Object.ObjectType, tuple.StringObjectRef(update.Relationship.Resource), ) @@ -275,6 +275,9 @@ func rewritePermissionsError(ctx context.Context, err error) error { case errors.As(err, &datastore.ErrPreconditionFailed{}): return status.Errorf(codes.FailedPrecondition, "failed precondition: %s", err) + case errors.As(err, &graph.ErrInvalidArgument{}): + return status.Errorf(codes.InvalidArgument, "%s", err) + case errors.As(err, &graph.ErrRequestCanceled{}): return status.Errorf(codes.Canceled, "request canceled: %s", err) diff --git a/pkg/graph/tree.go b/pkg/graph/tree.go index 8e3b83fc17..c0b8da5926 100644 --- a/pkg/graph/tree.go +++ b/pkg/graph/tree.go @@ -1,81 +1,10 @@ package graph import ( - "fmt" - v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" ) -func Simplify(node *v0.RelationTupleTreeNode) []*v0.User { - switch typed := node.NodeType.(type) { - case *v0.RelationTupleTreeNode_IntermediateNode: - switch typed.IntermediateNode.Operation { - case v0.SetOperationUserset_UNION: - return SimplifyUnion(typed.IntermediateNode.ChildNodes) - case v0.SetOperationUserset_INTERSECTION: - return SimplifyIntersection(typed.IntermediateNode.ChildNodes) - case v0.SetOperationUserset_EXCLUSION: - return SimplifyExclusion(typed.IntermediateNode.ChildNodes) - } - case *v0.RelationTupleTreeNode_LeafNode: - var toReturn UserSet = make(map[string]struct{}) - for _, usr := range typed.LeafNode.Users { - toReturn.Add(usr) - } - return toReturn.ToSlice() - } - return nil -} - -func SimplifyUnion(children []*v0.RelationTupleTreeNode) []*v0.User { - var toReturn UserSet = make(map[string]struct{}) - for _, child := range children { - toReturn.Add(Simplify(child)...) - } - return toReturn.ToSlice() -} - -func SimplifyIntersection(children []*v0.RelationTupleTreeNode) []*v0.User { - firstChildChildren := Simplify(children[0]) - - if len(children) == 1 { - return firstChildChildren - } - - var inOthers UserSet = make(map[string]struct{}) - inOthers.Add(SimplifyIntersection(children[1:])...) - - maxChildren := len(firstChildChildren) - if len(inOthers) < maxChildren { - maxChildren = len(inOthers) - } - - toReturn := make([]*v0.User, 0, maxChildren) - for _, child := range firstChildChildren { - if inOthers.Contains(child) { - toReturn = append(toReturn, child) - } - } - - return toReturn -} - -func SimplifyExclusion(children []*v0.RelationTupleTreeNode) []*v0.User { - firstChildChildren := Simplify(children[0]) - - if len(children) == 1 || len(firstChildChildren) == 0 { - return firstChildChildren - } - - var toReturn UserSet = make(map[string]struct{}) - toReturn.Add(firstChildChildren...) - for _, child := range children[1:] { - toReturn.Remove(Simplify(child)...) - } - - return toReturn.ToSlice() -} - +// Leaf constructs a RelationTupleTreeNode leaf. func Leaf(start *v0.ObjectAndRelation, children ...*v0.User) *v0.RelationTupleTreeNode { return &v0.RelationTupleTreeNode{ NodeType: &v0.RelationTupleTreeNode_LeafNode{ @@ -103,53 +32,17 @@ func setResult( } } +// Union constructs a RelationTupleTreeNode union operation. func Union(start *v0.ObjectAndRelation, children ...*v0.RelationTupleTreeNode) *v0.RelationTupleTreeNode { return setResult(v0.SetOperationUserset_UNION, start, children) } +// Intersection constructs a RelationTupleTreeNode intersection operation. func Intersection(start *v0.ObjectAndRelation, children ...*v0.RelationTupleTreeNode) *v0.RelationTupleTreeNode { return setResult(v0.SetOperationUserset_INTERSECTION, start, children) } +// Exclusion constructs a RelationTupleTreeNode exclusion operation. func Exclusion(start *v0.ObjectAndRelation, children ...*v0.RelationTupleTreeNode) *v0.RelationTupleTreeNode { return setResult(v0.SetOperationUserset_EXCLUSION, start, children) } - -type UserSet map[string]struct{} - -func (us UserSet) Add(users ...*v0.User) { - for _, usr := range users { - us[toKey(usr)] = struct{}{} - } -} - -func (us UserSet) Contains(usr *v0.User) bool { - _, ok := us[toKey(usr)] - return ok -} - -func (us UserSet) Remove(users ...*v0.User) { - for _, usr := range users { - delete(us, toKey(usr)) - } -} - -func (us UserSet) ToSlice() []*v0.User { - toReturn := make([]*v0.User, 0, len(us)) - for key := range us { - toReturn = append(toReturn, fromKey(key)) - } - return toReturn -} - -func toKey(usr *v0.User) string { - return fmt.Sprintf("%s %s %s", usr.GetUserset().Namespace, usr.GetUserset().ObjectId, usr.GetUserset().Relation) -} - -func fromKey(key string) *v0.User { - userset := &v0.ObjectAndRelation{} - fmt.Sscanf(key, "%s %s %s", &userset.Namespace, &userset.ObjectId, &userset.Relation) - return &v0.User{ - UserOneof: &v0.User_Userset{Userset: userset}, - } -} diff --git a/pkg/graph/tree_test.go b/pkg/graph/tree_test.go deleted file mode 100644 index 3abbe51be9..0000000000 --- a/pkg/graph/tree_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package graph - -import ( - "testing" - - v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" - "github.com/stretchr/testify/require" - - "github.com/authzed/spicedb/pkg/tuple" -) - -var ONR = tuple.ObjectAndRelation - -func TestSimplify(t *testing.T) { - testCases := []struct { - name string - tree *v0.RelationTupleTreeNode - expected []*v0.ObjectAndRelation - }{ - { - "simple leaf", - Leaf(nil, tuple.User(ONR("user", "user1", "..."))), - []*v0.ObjectAndRelation{ONR("user", "user1", "...")}, - }, - { - "simple union", - Union(nil, - Leaf(nil, tuple.User(ONR("user", "user1", "..."))), - Leaf(nil, tuple.User(ONR("user", "user2", "..."))), - Leaf(nil, tuple.User(ONR("user", "user3", "..."))), - ), - []*v0.ObjectAndRelation{ - ONR("user", "user1", "..."), - ONR("user", "user2", "..."), - ONR("user", "user3", "..."), - }, - }, - { - "simple intersection", - Intersection(nil, - Leaf(nil, - tuple.User(ONR("user", "user1", "...")), - tuple.User(ONR("user", "user2", "...")), - ), - Leaf(nil, - tuple.User(ONR("user", "user2", "...")), - tuple.User(ONR("user", "user3", "...")), - ), - Leaf(nil, - tuple.User(ONR("user", "user2", "...")), - tuple.User(ONR("user", "user4", "...")), - ), - ), - []*v0.ObjectAndRelation{ONR("user", "user2", "...")}, - }, - { - "empty intersection", - Intersection(nil, - Leaf(nil, - tuple.User(ONR("user", "user1", "...")), - tuple.User(ONR("user", "user2", "...")), - ), - Leaf(nil, - tuple.User(ONR("user", "user3", "...")), - tuple.User(ONR("user", "user4", "...")), - ), - ), - []*v0.ObjectAndRelation{}, - }, - { - "simple exclusion", - Exclusion(nil, - Leaf(nil, - tuple.User(ONR("user", "user1", "...")), - tuple.User(ONR("user", "user2", "...")), - ), - Leaf(nil, tuple.User(ONR("user", "user2", "..."))), - Leaf(nil, tuple.User(ONR("user", "user3", "..."))), - ), - []*v0.ObjectAndRelation{ONR("user", "user1", "...")}, - }, - { - "empty exclusion", - Exclusion(nil, - Leaf(nil, - tuple.User(ONR("user", "user1", "...")), - tuple.User(ONR("user", "user2", "...")), - ), - Leaf(nil, tuple.User(ONR("user", "user1", "..."))), - Leaf(nil, tuple.User(ONR("user", "user2", "..."))), - ), - []*v0.ObjectAndRelation{}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - require := require.New(t) - - var simplified UserSet = make(map[string]struct{}) - simplified.Add(Simplify(tc.tree)...) - - for _, onr := range tc.expected { - usr := tuple.User(onr) - require.True(simplified.Contains(usr)) - simplified.Remove(usr) - } - - require.Len(simplified, 0) - }) - } -} diff --git a/pkg/membership/membership.go b/pkg/membership/membership.go deleted file mode 100644 index 2c3bb30d4d..0000000000 --- a/pkg/membership/membership.go +++ /dev/null @@ -1,238 +0,0 @@ -package membership - -import ( - "fmt" - - v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" - - "github.com/authzed/spicedb/pkg/tuple" -) - -// Set represents the set of membership for one or more ONRs, based on expansion -// trees. -type Set struct { - // objectsAndRelations is a map from an ONR (as a string) to the subjects found for that ONR. - objectsAndRelations map[string]FoundSubjects -} - -// SubjectsByONR returns a map from ONR (as a string) to the FoundSubjects for that ONR. -func (ms *Set) SubjectsByONR() map[string]FoundSubjects { - return ms.objectsAndRelations -} - -// FoundSubjects contains the subjects found for a specific ONR. -type FoundSubjects struct { - // subjects is a map from the Subject ONR (as a string) to the FoundSubject information. - subjects map[string]FoundSubject -} - -// ListFound returns a slice of all the FoundSubject's. -func (fs FoundSubjects) ListFound() []FoundSubject { - found := []FoundSubject{} - for _, sub := range fs.subjects { - found = append(found, sub) - } - return found -} - -// LookupSubject returns the FoundSubject for a matching subject, if any. -func (fs FoundSubjects) LookupSubject(subject *v0.ObjectAndRelation) (FoundSubject, bool) { - onrString := tuple.StringONR(subject) - found, ok := fs.subjects[onrString] - return found, ok -} - -// FoundSubject contains a single found subject and all the relationships in which that subject -// is a member which were found via the ONRs expansion. -type FoundSubject struct { - // subject is the subject found. - subject *v0.ObjectAndRelation - - // relations are the relations under which the subject lives that informed the locating - // of this subject for the root ONR. - relationships *tuple.ONRSet -} - -// Subject returns the Subject of the FoundSubject. -func (fs FoundSubject) Subject() *v0.ObjectAndRelation { - return fs.subject -} - -// Relationships returns all the relationships in which the subject was found as per the expand. -func (fs FoundSubject) Relationships() []*v0.ObjectAndRelation { - return fs.relationships.AsSlice() -} - -// NewMembershipSet constructs a new membership set. -// -// NOTE: This is designed solely for the developer API and should *not* be used in any performance -// sensitive code. -func NewMembershipSet() *Set { - return &Set{ - objectsAndRelations: map[string]FoundSubjects{}, - } -} - -// AddExpansion adds the expansion of an ONR to the membership set. Returns false if the ONR was already added. -// -// NOTE: The expansion tree *should* be the fully recursive expansion. -func (ms *Set) AddExpansion(onr *v0.ObjectAndRelation, expansion *v0.RelationTupleTreeNode) (FoundSubjects, bool, error) { - onrString := tuple.StringONR(onr) - existing, ok := ms.objectsAndRelations[onrString] - if ok { - return existing, false, nil - } - - foundSubjectsMap := map[string]FoundSubject{} - err := populateFoundSubjects(foundSubjectsMap, onr, expansion) - if err != nil { - return FoundSubjects{}, false, err - } - - fs := FoundSubjects{ - subjects: foundSubjectsMap, - } - ms.objectsAndRelations[onrString] = fs - return fs, true, nil -} - -func populateFoundSubjects(foundSubjectsMap map[string]FoundSubject, rootONR *v0.ObjectAndRelation, treeNode *v0.RelationTupleTreeNode) error { - relationship := rootONR - if treeNode.Expanded != nil { - relationship = treeNode.Expanded - } - - switch typed := treeNode.NodeType.(type) { - case *v0.RelationTupleTreeNode_IntermediateNode: - switch typed.IntermediateNode.Operation { - case v0.SetOperationUserset_UNION: - for _, child := range typed.IntermediateNode.ChildNodes { - err := populateFoundSubjects(foundSubjectsMap, rootONR, child) - if err != nil { - return err - } - } - - case v0.SetOperationUserset_INTERSECTION: - if len(typed.IntermediateNode.ChildNodes) == 0 { - return fmt.Errorf("found intersection with no children") - } - - fsm := map[string]FoundSubject{} - err := populateFoundSubjects(fsm, rootONR, typed.IntermediateNode.ChildNodes[0]) - if err != nil { - return err - } - - subjectset := newSubjectSet() - subjectset.union(fsm) - - for _, child := range typed.IntermediateNode.ChildNodes[1:] { - fsm := map[string]FoundSubject{} - if err := populateFoundSubjects(fsm, rootONR, child); err != nil { - return err - } - subjectset.intersect(fsm) - } - - subjectset.populate(foundSubjectsMap) - - case v0.SetOperationUserset_EXCLUSION: - if len(typed.IntermediateNode.ChildNodes) == 0 { - return fmt.Errorf("found exclusion with no children") - } - - fsm := map[string]FoundSubject{} - err := populateFoundSubjects(fsm, rootONR, typed.IntermediateNode.ChildNodes[0]) - if err != nil { - return err - } - - subjectset := newSubjectSet() - subjectset.union(fsm) - - for _, child := range typed.IntermediateNode.ChildNodes[1:] { - fsm := map[string]FoundSubject{} - if err := populateFoundSubjects(fsm, rootONR, child); err != nil { - return err - } - subjectset.exclude(fsm) - } - - subjectset.populate(foundSubjectsMap) - - default: - panic("unknown expand operation") - } - - case *v0.RelationTupleTreeNode_LeafNode: - for _, user := range typed.LeafNode.Users { - subjectONRString := tuple.StringONR(user.GetUserset()) - _, ok := foundSubjectsMap[subjectONRString] - if !ok { - foundSubjectsMap[subjectONRString] = FoundSubject{ - subject: user.GetUserset(), - relationships: tuple.NewONRSet(), - } - } - - foundSubjectsMap[subjectONRString].relationships.Add(relationship) - } - default: - panic("unknown TreeNode type") - } - - return nil -} - -type subjectSet struct { - subjectsMap map[string]FoundSubject -} - -func newSubjectSet() *subjectSet { - return &subjectSet{ - subjectsMap: map[string]FoundSubject{}, - } -} - -func (ss *subjectSet) populate(outgoingSubjectsMap map[string]FoundSubject) { - for key, fs := range ss.subjectsMap { - existing, ok := outgoingSubjectsMap[key] - if ok { - existing.relationships.UpdateFrom(fs.relationships) - } else { - outgoingSubjectsMap[key] = fs - } - } -} - -func (ss *subjectSet) union(subjectsMap map[string]FoundSubject) { - for key, fs := range subjectsMap { - existing, ok := ss.subjectsMap[key] - if ok { - existing.relationships.UpdateFrom(fs.relationships) - } else { - ss.subjectsMap[key] = fs - } - } -} - -func (ss *subjectSet) intersect(subjectsMap map[string]FoundSubject) { - for key, fs := range ss.subjectsMap { - other, ok := subjectsMap[key] - if ok { - fs.relationships.UpdateFrom(other.relationships) - } else { - delete(ss.subjectsMap, key) - } - } -} - -func (ss *subjectSet) exclude(subjectsMap map[string]FoundSubject) { - for key := range ss.subjectsMap { - _, ok := subjectsMap[key] - if ok { - delete(ss.subjectsMap, key) - } - } -} diff --git a/pkg/membership/membership_test.go b/pkg/membership/membership_test.go deleted file mode 100644 index bea4f36261..0000000000 --- a/pkg/membership/membership_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package membership - -import ( - "testing" - - v0 "github.com/authzed/authzed-go/proto/authzed/api/v0" - "github.com/stretchr/testify/require" - - "github.com/authzed/spicedb/pkg/graph" - "github.com/authzed/spicedb/pkg/tuple" -) - -var ( - ONR = tuple.ObjectAndRelation - Ellipsis = "..." -) - -var ( - _this *v0.ObjectAndRelation - - companyOwner = graph.Leaf(ONR("folder", "company", "owner"), - tuple.User(ONR("user", "owner", Ellipsis)), - ) - companyEditor = graph.Union(ONR("folder", "company", "editor"), - graph.Leaf(_this, tuple.User(ONR("user", "writer", Ellipsis))), - companyOwner, - ) - - auditorsOwner = graph.Leaf(ONR("folder", "auditors", "owner")) - - auditorsEditor = graph.Union(ONR("folder", "auditors", "editor"), - graph.Leaf(_this), - auditorsOwner, - ) - - auditorsViewerRecursive = graph.Union(ONR("folder", "auditors", "viewer"), - graph.Leaf(_this, - tuple.User(ONR("user", "auditor", "...")), - ), - auditorsEditor, - graph.Union(ONR("folder", "auditors", "viewer")), - ) - - companyViewerRecursive = graph.Union(ONR("folder", "company", "viewer"), - graph.Union(ONR("folder", "company", "viewer"), - auditorsViewerRecursive, - graph.Leaf(_this, - tuple.User(ONR("user", "legal", "...")), - tuple.User(ONR("folder", "auditors", "viewer")), - ), - ), - companyEditor, - graph.Union(ONR("folder", "company", "viewer")), - ) -) - -func TestMembershipSet(t *testing.T) { - require := require.New(t) - ms := NewMembershipSet() - - // Add some expansion trees. - fso, ok, err := ms.AddExpansion(ONR("folder", "company", "owner"), companyOwner) - require.True(ok) - require.NoError(err) - verifySubjects(require, fso, "user:owner") - - fse, ok, err := ms.AddExpansion(ONR("folder", "company", "editor"), companyEditor) - require.True(ok) - require.NoError(err) - verifySubjects(require, fse, "user:owner", "user:writer") - - fsv, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), companyViewerRecursive) - require.True(ok) - require.NoError(err) - verifySubjects(require, fsv, "folder:auditors#viewer", "user:auditor", "user:legal", "user:owner", "user:writer") -} - -func TestMembershipSetIntersection(t *testing.T) { - require := require.New(t) - ms := NewMembershipSet() - - intersection := graph.Intersection(ONR("folder", "company", "viewer"), - graph.Leaf(_this, - tuple.User(ONR("user", "legal", "...")), - ), - graph.Leaf(_this, - tuple.User(ONR("user", "owner", "...")), - tuple.User(ONR("user", "legal", "...")), - ), - ) - - fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) - require.True(ok) - require.NoError(err) - verifySubjects(require, fso, "user:legal") -} - -func TestMembershipSetExclusion(t *testing.T) { - require := require.New(t) - ms := NewMembershipSet() - - intersection := graph.Exclusion(ONR("folder", "company", "viewer"), - graph.Leaf(_this, - tuple.User(ONR("user", "owner", "...")), - tuple.User(ONR("user", "legal", "...")), - ), - graph.Leaf(_this, - tuple.User(ONR("user", "legal", "...")), - ), - ) - - fso, ok, err := ms.AddExpansion(ONR("folder", "company", "viewer"), intersection) - require.True(ok) - require.NoError(err) - verifySubjects(require, fso, "user:owner") -} - -func verifySubjects(require *require.Assertions, fs FoundSubjects, expected ...string) { - foundSubjects := []*v0.ObjectAndRelation{} - for _, found := range fs.ListFound() { - foundSubjects = append(foundSubjects, found.Subject()) - - _, ok := fs.LookupSubject(found.Subject()) - require.True(ok) - } - - require.Equal(expected, tuple.StringsONRs(foundSubjects)) -} diff --git a/pkg/tuple/onrset.go b/pkg/tuple/onrset.go index 22caf53f42..3406a1dcc1 100644 --- a/pkg/tuple/onrset.go +++ b/pkg/tuple/onrset.go @@ -91,6 +91,18 @@ func (ons *ONRSet) With(onr *v0.ObjectAndRelation) *ONRSet { return updated } +// Union returns a copy of this ONR set with the other set's elements added in. +func (ons *ONRSet) Union(otherSet *ONRSet) *ONRSet { + updated := NewONRSet() + for _, current := range ons.onrs { + updated.Add(current) + } + for _, current := range otherSet.onrs { + updated.Add(current) + } + return updated +} + // AsSlice returns the ONRs found in the set as a slice. func (ons *ONRSet) AsSlice() []*v0.ObjectAndRelation { slice := make([]*v0.ObjectAndRelation, 0, len(ons.onrs)) diff --git a/pkg/validationfile/fileformat.go b/pkg/validationfile/fileformat.go index 7b0f86bcd0..d0749eb83f 100644 --- a/pkg/validationfile/fileformat.go +++ b/pkg/validationfile/fileformat.go @@ -149,10 +149,20 @@ func (ors ObjectRelationString) ONR() (*v0.ObjectAndRelation, *ErrorWithSource) } var ( - vsSubjectRegex = regexp.MustCompile(`(.*?)\[(?P.*)\](.*?)`) - vsObjectAndRelationRegex = regexp.MustCompile(`(.*?)<(?P[^\>]+)>(.*?)`) + vsSubjectRegex = regexp.MustCompile(`(.*?)\[(?P.*)\](.*?)`) + vsObjectAndRelationRegex = regexp.MustCompile(`(.*?)<(?P[^\>]+)>(.*?)`) + vsSubjectWithExceptionsRegex = regexp.MustCompile(`^(.+)\s*-\s*\{([^\}]+)\}$`) ) +// SubjectWithExceptions returns the subject found in a validation string, along with any exceptions. +type SubjectWithExceptions struct { + // Subject is the subject found. + Subject *v0.ObjectAndRelation + + // Exceptions are those subjects removed from the subject, if it is a wildcard. + Exceptions []*v0.ObjectAndRelation +} + // ValidationString holds a validation string containing a Subject and one or // more Relations to the parent Object. // Example: `[tenant/user:someuser#...] is ` @@ -170,17 +180,44 @@ func (vs ValidationString) SubjectString() (string, bool) { // Subject returns the subject contained in the ValidationString, if any. If // none, returns nil. -func (vs ValidationString) Subject() (*v0.ObjectAndRelation, *ErrorWithSource) { +func (vs ValidationString) Subject() (*SubjectWithExceptions, *ErrorWithSource) { subjectStr, ok := vs.SubjectString() if !ok { return nil, nil } + subjectStr = strings.TrimSpace(subjectStr) + if strings.HasSuffix(subjectStr, "}") { + result := vsSubjectWithExceptionsRegex.FindStringSubmatch(subjectStr) + if len(result) != 3 { + return nil, &ErrorWithSource{fmt.Errorf("invalid subject: %s", subjectStr), subjectStr, 0, 0} + } + + subjectONR := tuple.ParseSubjectONR(strings.TrimSpace(result[1])) + if subjectONR == nil { + return nil, &ErrorWithSource{fmt.Errorf("invalid subject: %s", result[1]), result[1], 0, 0} + } + + exceptionsString := strings.TrimSpace(result[2]) + exceptionsStringsSlice := strings.Split(exceptionsString, ",") + exceptions := make([]*v0.ObjectAndRelation, 0, len(exceptionsStringsSlice)) + for _, exceptionString := range exceptionsStringsSlice { + exceptionONR := tuple.ParseSubjectONR(strings.TrimSpace(exceptionString)) + if exceptionONR == nil { + return nil, &ErrorWithSource{fmt.Errorf("invalid subject: %s", exceptionString), exceptionString, 0, 0} + } + + exceptions = append(exceptions, exceptionONR) + } + + return &SubjectWithExceptions{subjectONR, exceptions}, nil + } + found := tuple.ParseSubjectONR(subjectStr) if found == nil { return nil, &ErrorWithSource{fmt.Errorf("invalid subject: %s", subjectStr), subjectStr, 0, 0} } - return found, nil + return &SubjectWithExceptions{found, nil}, nil } // ONRStrings returns the ONRs contained in the ValidationString, if any. diff --git a/pkg/validationfile/fileformat_test.go b/pkg/validationfile/fileformat_test.go index f0d694ee6a..92cf17946e 100644 --- a/pkg/validationfile/fileformat_test.go +++ b/pkg/validationfile/fileformat_test.go @@ -65,6 +65,18 @@ func TestValidationString(t *testing.T) { "", []string{"tenant/document:example#viewer", "tenant/document:example#builder"}, }, + { + "subject with exclusions", + "[tenant/user:someuser#... - {test/user:1,test/user:2}] is /", + "tenant/user:someuser", + []string{"tenant/document:example#viewer", "tenant/document:example#builder"}, + }, + { + "subject with bad exclusions", + "[tenant/user:someuser#... - {te1,test/user:2}] is /", + "", + []string{"tenant/document:example#viewer", "tenant/document:example#builder"}, + }, } for _, tc := range tests { @@ -72,11 +84,13 @@ func TestValidationString(t *testing.T) { require := require.New(t) vs := ValidationString(tc.input) - subject, _ := vs.Subject() + subject, err := vs.Subject() + if tc.expectedSubject == "" { require.Nil(subject) } else { - require.Equal(tc.expectedSubject, tuple.StringONR(subject)) + require.Nil(err) + require.Equal(tc.expectedSubject, tuple.StringONR(subject.Subject)) } foundONRStrings := []string{}