Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

creds/aws: Add support for DSA signature verification for EC2 #12340

Merged
merged 8 commits into from Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions builtin/credential/aws/backend_test.go
Expand Up @@ -1129,6 +1129,22 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndAccessListIdentity(t *testing
}
}

// Configure additional metadata to be returned for ec2 logins.
identity := map[string]interface{}{
"ec2_metadata": []string{"instance_id", "region", "ami_id"},
}

// store the identity
_, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Storage: storage,
Path: "config/identity",
Data: identity,
})
if err != nil {
t.Fatal(err)
}

loginInput := map[string]interface{}{
"pkcs7": pkcs7,
"nonce": "vault-client-nonce",
Expand Down Expand Up @@ -1241,6 +1257,7 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndAccessListIdentity(t *testing
delete(loginInput, "pkcs7")
loginInput["identity"] = identityDoc
loginInput["signature"] = identityDocSig

resp, err = b.HandleRequest(context.Background(), loginRequest)
if err != nil {
t.Fatal(err)
Expand Down
6 changes: 3 additions & 3 deletions builtin/credential/aws/path_login.go
Expand Up @@ -20,13 +20,13 @@ import (
awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/fullsailor/pkcs7"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
Expand Down Expand Up @@ -348,8 +348,8 @@ func (b *backend) parseIdentityDocument(ctx context.Context, s logical.Storage,

// Verify extracts the authenticated attributes in the PKCS#7 signature, and verifies
// the authenticity of the content using 'dsa.PublicKey' embedded in the public certificate.
if pkcs7Data.Verify() != nil {
return nil, fmt.Errorf("failed to verify the signature")
if err := pkcs7Data.Verify(); err != nil {
return nil, fmt.Errorf("failed to verify the signature: %w", err)
}

// Check if the signature has content inside of it
Expand Down
5 changes: 5 additions & 0 deletions builtin/credential/aws/pkcs7/README.md
@@ -0,0 +1,5 @@
# PKCS7

This code is used to verify PKCS7 signatures for the EC2 auth method. The code
was forked from [mozilla-services/pkcs7](https://github.com/mozilla-services/pkcs7)
and modified for Vault.
251 changes: 251 additions & 0 deletions builtin/credential/aws/pkcs7/ber.go
@@ -0,0 +1,251 @@
package pkcs7

import (
"bytes"
"errors"
)

var encodeIndent = 0

type asn1Object interface {
EncodeTo(writer *bytes.Buffer) error
}

type asn1Structured struct {
tagBytes []byte
content []asn1Object
}

func (s asn1Structured) EncodeTo(out *bytes.Buffer) error {
//fmt.Printf("%s--> tag: % X\n", strings.Repeat("| ", encodeIndent), s.tagBytes)
encodeIndent++
inner := new(bytes.Buffer)
for _, obj := range s.content {
err := obj.EncodeTo(inner)
if err != nil {
return err
}
}
encodeIndent--
out.Write(s.tagBytes)
encodeLength(out, inner.Len())
out.Write(inner.Bytes())
return nil
}

type asn1Primitive struct {
tagBytes []byte
length int
content []byte
}

func (p asn1Primitive) EncodeTo(out *bytes.Buffer) error {
_, err := out.Write(p.tagBytes)
if err != nil {
return err
}
if err = encodeLength(out, p.length); err != nil {
return err
}
//fmt.Printf("%s--> tag: % X length: %d\n", strings.Repeat("| ", encodeIndent), p.tagBytes, p.length)
//fmt.Printf("%s--> content length: %d\n", strings.Repeat("| ", encodeIndent), len(p.content))
out.Write(p.content)

return nil
}

func ber2der(ber []byte) ([]byte, error) {
if len(ber) == 0 {
return nil, errors.New("ber2der: input ber is empty")
}
//fmt.Printf("--> ber2der: Transcoding %d bytes\n", len(ber))
out := new(bytes.Buffer)

obj, _, err := readObject(ber, 0)
if err != nil {
return nil, err
}
obj.EncodeTo(out)

// if offset < len(ber) {
// return nil, fmt.Errorf("ber2der: Content longer than expected. Got %d, expected %d", offset, len(ber))
//}

return out.Bytes(), nil
}

// encodes lengths that are longer than 127 into string of bytes
func marshalLongLength(out *bytes.Buffer, i int) (err error) {
n := lengthLength(i)

for ; n > 0; n-- {
err = out.WriteByte(byte(i >> uint((n-1)*8)))
if err != nil {
return
}
}

return nil
}

// computes the byte length of an encoded length value
func lengthLength(i int) (numBytes int) {
numBytes = 1
for i > 255 {
numBytes++
i >>= 8
}
return
}

// encodes the length in DER format
// If the length fits in 7 bits, the value is encoded directly.
//
// Otherwise, the number of bytes to encode the length is first determined.
// This number is likely to be 4 or less for a 32bit length. This number is
// added to 0x80. The length is encoded in big endian encoding follow after
//
// Examples:
// length | byte 1 | bytes n
// 0 | 0x00 | -
// 120 | 0x78 | -
// 200 | 0x81 | 0xC8
// 500 | 0x82 | 0x01 0xF4
//
func encodeLength(out *bytes.Buffer, length int) (err error) {
if length >= 128 {
l := lengthLength(length)
err = out.WriteByte(0x80 | byte(l))
if err != nil {
return
}
err = marshalLongLength(out, length)
if err != nil {
return
}
} else {
err = out.WriteByte(byte(length))
if err != nil {
return
}
}
return
}

func readObject(ber []byte, offset int) (asn1Object, int, error) {
berLen := len(ber)
if offset >= berLen {
return nil, 0, errors.New("ber2der: offset is after end of ber data")
}
tagStart := offset
b := ber[offset]
offset++
if offset >= berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
tag := b & 0x1F // last 5 bits
if tag == 0x1F {
tag = 0
for ber[offset] >= 0x80 {
tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
// jvehent 20170227: this doesn't appear to be used anywhere...
//tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
tagEnd := offset

kind := b & 0x20
if kind == 0 {
debugprint("--> Primitive\n")
} else {
debugprint("--> Constructed\n")
}
// read length
var length int
l := ber[offset]
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
hack := 0
if l > 0x80 {
numberOfBytes := (int)(l & 0x7F)
if numberOfBytes > 4 { // int is only guaranteed to be 32bit
return nil, 0, errors.New("ber2der: BER tag length too long")
}
if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F {
return nil, 0, errors.New("ber2der: BER tag length is negative")
}
if (int)(ber[offset]) == 0x0 {
return nil, 0, errors.New("ber2der: BER tag length has leading zero")
}
debugprint("--> (compute length) indicator byte: %x\n", l)
debugprint("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes])
for i := 0; i < numberOfBytes; i++ {
length = length*256 + (int)(ber[offset])
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
} else if l == 0x80 {
// find length by searching content
markerIndex := bytes.LastIndex(ber[offset:], []byte{0x0, 0x0})
if markerIndex == -1 {
return nil, 0, errors.New("ber2der: Invalid BER format")
}
length = markerIndex
hack = 2
debugprint("--> (compute length) marker found at offset: %d\n", markerIndex+offset)
} else {
length = (int)(l)
}
if length < 0 {
return nil, 0, errors.New("ber2der: invalid negative value found in BER tag length")
}
//fmt.Printf("--> length : %d\n", length)
contentEnd := offset + length
if contentEnd > len(ber) {
return nil, 0, errors.New("ber2der: BER tag length is more than available data")
}
debugprint("--> content start : %d\n", offset)
debugprint("--> content end : %d\n", contentEnd)
debugprint("--> content : % X\n", ber[offset:contentEnd])
var obj asn1Object
if kind == 0 {
obj = asn1Primitive{
tagBytes: ber[tagStart:tagEnd],
length: length,
content: ber[offset:contentEnd],
}
} else {
var subObjects []asn1Object
for offset < contentEnd {
var subObj asn1Object
var err error
subObj, offset, err = readObject(ber[:contentEnd], offset)
if err != nil {
return nil, 0, err
}
subObjects = append(subObjects, subObj)
}
obj = asn1Structured{
tagBytes: ber[tagStart:tagEnd],
content: subObjects,
}
}

return obj, contentEnd + hack, nil
}

func debugprint(format string, a ...interface{}) {
//fmt.Printf(format, a)
}
62 changes: 62 additions & 0 deletions builtin/credential/aws/pkcs7/ber_test.go
@@ -0,0 +1,62 @@
package pkcs7

import (
"bytes"
"encoding/asn1"
"strings"
"testing"
)

func TestBer2Der(t *testing.T) {
// indefinite length fixture
ber := []byte{0x30, 0x80, 0x02, 0x01, 0x01, 0x00, 0x00}
expected := []byte{0x30, 0x03, 0x02, 0x01, 0x01}
der, err := ber2der(ber)
if err != nil {
t.Fatalf("ber2der failed with error: %v", err)
}
if !bytes.Equal(der, expected) {
t.Errorf("ber2der result did not match.\n\tExpected: % X\n\tActual: % X", expected, der)
}

if der2, err := ber2der(der); err != nil {
t.Errorf("ber2der on DER bytes failed with error: %v", err)
} else {
if !bytes.Equal(der, der2) {
t.Error("ber2der is not idempotent")
}
}
var thing struct {
Number int
}
rest, err := asn1.Unmarshal(der, &thing)
if err != nil {
t.Errorf("Cannot parse resulting DER because: %v", err)
} else if len(rest) > 0 {
t.Errorf("Resulting DER has trailing data: % X", rest)
}
}

func TestBer2Der_Negatives(t *testing.T) {
fixtures := []struct {
Input []byte
ErrorContains string
}{
{[]byte{0x30, 0x85}, "tag length too long"},
{[]byte{0x30, 0x84, 0x80, 0x0, 0x0, 0x0}, "length is negative"},
{[]byte{0x30, 0x82, 0x0, 0x1}, "length has leading zero"},
{[]byte{0x30, 0x80, 0x1, 0x2}, "Invalid BER format"},
{[]byte{0x30, 0x03, 0x01, 0x02}, "length is more than available data"},
{[]byte{0x30}, "end of ber data reached"},
}

for _, fixture := range fixtures {
_, err := ber2der(fixture.Input)
if err == nil {
t.Errorf("No error thrown. Expected: %s", fixture.ErrorContains)
}
if !strings.Contains(err.Error(), fixture.ErrorContains) {
t.Errorf("Unexpected error thrown.\n\tExpected: /%s/\n\tActual: %s", fixture.ErrorContains, err.Error())
}
}
}