Skip to content

Commit

Permalink
Wire TokenReader into Sanitize functions
Browse files Browse the repository at this point in the history
- Add tests
  • Loading branch information
jhillyerd committed Nov 22, 2018
1 parent 8341af1 commit 1a94a9e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
24 changes: 15 additions & 9 deletions sanitize.go
Expand Up @@ -51,12 +51,12 @@ var (
// It returns a HTML string that has been sanitized by the policy or an empty
// string if an error has occurred (most likely as a consequence of extremely
// malformed input)
func (p *Policy) Sanitize(s string) string {
func (p *Policy) Sanitize(s string, filters ...TokenReader) string {
if strings.TrimSpace(s) == "" {
return s
}

return p.sanitize(strings.NewReader(s)).String()
return p.sanitize(strings.NewReader(s), filters...).String()
}

// SanitizeBytes takes a []byte that contains a HTML fragment or document and applies
Expand All @@ -65,25 +65,25 @@ func (p *Policy) Sanitize(s string) string {
// It returns a []byte containing the HTML that has been sanitized by the policy
// or an empty []byte if an error has occurred (most likely as a consequence of
// extremely malformed input)
func (p *Policy) SanitizeBytes(b []byte) []byte {
func (p *Policy) SanitizeBytes(b []byte, filters ...TokenReader) []byte {
if len(bytes.TrimSpace(b)) == 0 {
return b
}

return p.sanitize(bytes.NewReader(b)).Bytes()
return p.sanitize(bytes.NewReader(b), filters...).Bytes()
}

// SanitizeReader takes an io.Reader that contains a HTML fragment or document
// and applies the given policy whitelist.
//
// It returns a bytes.Buffer containing the HTML that has been sanitized by the
// policy. Errors during sanitization will merely return an empty result.
func (p *Policy) SanitizeReader(r io.Reader) *bytes.Buffer {
return p.sanitize(r)
func (p *Policy) SanitizeReader(r io.Reader, filters ...TokenReader) *bytes.Buffer {
return p.sanitize(r, filters...)
}

// Performs the actual sanitization process.
func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
func (p *Policy) sanitize(r io.Reader, filters ...TokenReader) *bytes.Buffer {

// It is possible that the developer has created the policy via:
// p := bluemonday.Policy{}
Expand All @@ -100,10 +100,16 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
skipClosingTag bool
closingTagToSkipStack []string
mostRecentlyStartedToken string
reader TokenReader
)

tokenizer := html.NewTokenizer(r)
reader := tokenizerReader{tokenizer}
// Chain together TokenReader filters
reader = &tokenizerReader{html.NewTokenizer(r)}
for _, f := range filters {
f.Source(reader)
reader = f
}

for {
token, err := reader.Token()
if token == nil {
Expand Down
5 changes: 5 additions & 0 deletions token.go
Expand Up @@ -32,6 +32,7 @@ package bluemonday
import "golang.org/x/net/html"

type TokenReader interface {
Source(source TokenReader)
Token() (*html.Token, error)
}

Expand All @@ -48,3 +49,7 @@ func (r *tokenizerReader) Token() (*html.Token, error) {
token := r.Tokenizer.Token()
return &token, nil
}

// Source is a no-op for tokenizerReader
func (r *tokenizerReader) Source(TokenReader) {
}
48 changes: 48 additions & 0 deletions token_test.go
@@ -0,0 +1,48 @@
package bluemonday

import (
"testing"

"golang.org/x/net/html"
"golang.org/x/net/html/atom"
)

type testRemoverReader struct {
source TokenReader
tagAtom atom.Atom
}

func (r *testRemoverReader) Token() (*html.Token, error) {
t, err := r.source.Token()
if err != nil {
return t, err
}
if (t.Type == html.StartTagToken || t.Type == html.EndTagToken) && t.DataAtom == r.tagAtom {
// Skip bold, return next token
return r.source.Token()
}
return t, nil
}

func (r *testRemoverReader) Source(s TokenReader) {
r.source = s
}

func TestTokenReader(t *testing.T) {
p := UGCPolicy()

input := "<p><b>A bold statement.</b></p>"
want := "<p><b>A bold statement.</b></p>"
got := p.Sanitize(input)
if got != want {
t.Errorf("got: %q, want: %q", got, want)
}

removeBold := &testRemoverReader{tagAtom: atom.B}
input = "<p><b>A bold statement.</b></p>"
want = "<p>A bold statement.</p>"
got = p.Sanitize(input, removeBold)
if got != want {
t.Errorf("got: %q, want: %q", got, want)
}
}

0 comments on commit 1a94a9e

Please sign in to comment.