blob: 67a019db2d40677d68517e2d44330419f12e743b [file] [log] [blame]
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asn1
import (
"bytes"
"errors"
"fmt"
"io"
"math/big"
"reflect"
"time"
"unicode/utf8"
)
// forkableWriter is an in-memory byte buffer that can be split ("forked")
// once into two child writers. After
//
//	pre, post := w.fork()
//
// the bytes represented by w are, in order, w's own buffer, then everything
// held by pre, then everything held by post. This lets a caller reserve a
// slot (pre) whose contents are only known after a later part (post) has
// been produced.
type forkableWriter struct {
	// Buffer holds the bytes written directly to this writer.
	*bytes.Buffer
	// pre is the first child created by fork, or nil if not forked.
	pre *forkableWriter
	// post is the second child created by fork, or nil if not forked.
	post *forkableWriter
}
// newForkableWriter returns an empty, unforked forkableWriter backed by a
// fresh bytes.Buffer.
func newForkableWriter() *forkableWriter {
	w := new(forkableWriter)
	w.Buffer = new(bytes.Buffer)
	return w
}
// fork attaches two new, empty child writers to f and returns them. Bytes
// written to pre are emitted after f's own bytes and before those of post.
// A writer may be forked only once; a second call panics.
func (f *forkableWriter) fork() (pre, post *forkableWriter) {
	if !(f.pre == nil && f.post == nil) {
		panic("have already forked")
	}
	pre, post = newForkableWriter(), newForkableWriter()
	f.pre, f.post = pre, post
	return pre, post
}
// Len returns the total number of bytes held by f: its own buffer plus,
// recursively, the buffers of any children created by fork.
func (f *forkableWriter) Len() (l int) {
	l = f.Buffer.Len()
	for _, child := range [2]*forkableWriter{f.pre, f.post} {
		if child != nil {
			l += child.Len()
		}
	}
	return l
}
// writeTo flattens f into out: first f's own bytes, then the pre child,
// then the post child, each recursively. It returns the number of bytes
// written and stops at the first write error.
func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
	if n, err = out.Write(f.Bytes()); err != nil {
		return n, err
	}
	for _, child := range [2]*forkableWriter{f.pre, f.post} {
		if child == nil {
			continue
		}
		written, werr := child.writeTo(out)
		n += written
		if werr != nil {
			return n, werr
		}
	}
	return n, nil
}
// marshalBase128Int writes n to out as a sequence of 7-bit groups, most
// significant group first, with the 0x80 continuation bit set on every byte
// except the last. Zero is written as a single 0x00 byte. A negative n
// produces no output.
func marshalBase128Int(out *forkableWriter, n int64) (err error) {
	if n == 0 {
		return out.WriteByte(0)
	}

	// Count how many 7-bit groups are needed to hold n.
	groups := 0
	for rest := n; rest > 0; rest >>= 7 {
		groups++
	}

	for shift := (groups - 1) * 7; shift >= 0; shift -= 7 {
		septet := byte(n>>uint(shift)) & 0x7f
		if shift != 0 {
			// More groups follow this one.
			septet |= 0x80
		}
		if err = out.WriteByte(septet); err != nil {
			return err
		}
	}
	return nil
}
// marshalInt64 writes i to out in big-endian two's-complement form using
// the number of bytes reported by int64Length.
func marshalInt64(out *forkableWriter, i int64) (err error) {
	for remaining := int64Length(i); remaining > 0; remaining-- {
		octet := byte(i >> uint((remaining-1)*8))
		if err = out.WriteByte(octet); err != nil {
			return err
		}
	}
	return nil
}
// int64Length returns the number of bytes needed to hold i in two's
// complement form, i.e. the smallest width whose signed range contains i.
// The result is always at least 1.
func int64Length(i int64) (numBytes int) {
	numBytes = 1
	// Drop one low-order byte at a time until what remains fits in a
	// single signed byte.
	for ; i > 127 || i < -128; i >>= 8 {
		numBytes++
	}
	return numBytes
}
// marshalBigInt writes n to out as a big-endian two's-complement integer of
// minimal length. Zero is written as a single 0x00 byte.
func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
	switch sign := n.Sign(); {
	case sign < 0:
		// Two's complement of a negative number is the bitwise inverse
		// of (|n| - 1).
		magnitude := new(big.Int).Neg(n)
		magnitude.Sub(magnitude, bigOne)
		encoded := magnitude.Bytes()
		for i := range encoded {
			encoded[i] ^= 0xff
		}
		// If the top bit is clear (or there are no bytes at all, as for
		// -1), a leading 0xff keeps the value negative.
		if len(encoded) == 0 || encoded[0]&0x80 == 0 {
			if err = out.WriteByte(0xff); err != nil {
				return err
			}
		}
		_, err = out.Write(encoded)
	case sign == 0:
		// Zero is a single zero byte rather than an empty encoding.
		err = out.WriteByte(0x00)
	default:
		encoded := n.Bytes()
		// A set top bit would read as negative, so prefix a 0x00 byte.
		if len(encoded) > 0 && encoded[0]&0x80 != 0 {
			if err = out.WriteByte(0); err != nil {
				return err
			}
		}
		_, err = out.Write(encoded)
	}
	return err
}
// marshalLength writes i to out as a big-endian unsigned integer using the
// number of bytes reported by lengthLength.
func marshalLength(out *forkableWriter, i int) (err error) {
	for remaining := lengthLength(i); remaining > 0; remaining-- {
		octet := byte(i >> uint((remaining-1)*8))
		if err = out.WriteByte(octet); err != nil {
			return err
		}
	}
	return nil
}
// lengthLength returns the number of bytes needed to hold i as an unsigned
// big-endian integer. The result is always at least 1.
func lengthLength(i int) (numBytes int) {
	numBytes = 1
	// Every non-zero byte above the lowest one adds a byte to the encoding.
	for rest := i >> 8; rest > 0; rest >>= 8 {
		numBytes++
	}
	return numBytes
}
// marshalTagAndLength writes the identifier and length octets described by
// t to out.
//
// The first byte carries the class in its top two bits and 0x20 when the
// value is compound. Tags below 31 are stored in the low five bits; larger
// tags set those bits to 0x1f and follow with the tag in base 128. Lengths
// below 128 take a single byte; larger lengths are written as 0x80|count
// followed by count big-endian length bytes.
func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
	ident := uint8(t.class) << 6
	if t.isCompound {
		ident |= 0x20
	}

	if t.tag < 31 {
		ident |= uint8(t.tag)
		if err = out.WriteByte(ident); err != nil {
			return err
		}
	} else {
		ident |= 0x1f
		if err = out.WriteByte(ident); err != nil {
			return err
		}
		if err = marshalBase128Int(out, int64(t.tag)); err != nil {
			return err
		}
	}

	if t.length < 128 {
		return out.WriteByte(byte(t.length))
	}
	count := lengthLength(t.length)
	if err = out.WriteByte(0x80 | byte(count)); err != nil {
		return err
	}
	return marshalLength(out, t.length)
}
// marshalBitString writes the contents of b to out: one byte giving the
// number of unused bits in the final byte, followed by b.Bytes.
func marshalBitString(out *forkableWriter, b BitString) (err error) {
	unused := byte((8 - b.BitLength%8) % 8)
	if err = out.WriteByte(unused); err != nil {
		return err
	}
	_, err = out.Write(b.Bytes)
	return err
}
// marshalObjectIdentifier writes the contents of an OBJECT IDENTIFIER to
// out. The first two components are combined as oid[0]*40+oid[1]; every
// value is written in base 128. It returns a StructuralError when oid has
// fewer than two components, when the first exceeds 2, or when the first is
// below 2 and the second is 40 or more.
func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
		return StructuralError{"invalid object identifier"}
	}

	if err = marshalBase128Int(out, int64(oid[0]*40+oid[1])); err != nil {
		return err
	}
	for _, arc := range oid[2:] {
		if err = marshalBase128Int(out, int64(arc)); err != nil {
			return err
		}
	}
	return nil
}
// marshalPrintableString writes the bytes of s to out after checking that
// every byte is accepted by isPrintable. It returns a StructuralError and
// writes nothing if any byte is rejected.
func marshalPrintableString(out *forkableWriter, s string) (err error) {
	for i := 0; i < len(s); i++ {
		if !isPrintable(s[i]) {
			return StructuralError{"PrintableString contains invalid character"}
		}
	}
	_, err = out.Write([]byte(s))
	return err
}
// marshalIA5String writes the bytes of s to out after checking that every
// byte is at most 127. It returns a StructuralError and writes nothing if
// any byte is larger.
func marshalIA5String(out *forkableWriter, s string) (err error) {
	for i := 0; i < len(s); i++ {
		if s[i] > 127 {
			return StructuralError{"IA5String contains invalid character"}
		}
	}
	_, err = out.Write([]byte(s))
	return err
}
// marshalUTF8String writes the bytes of s to out unchanged. It performs no
// validation of its own.
func marshalUTF8String(out *forkableWriter, s string) (err error) {
	if _, err = out.Write([]byte(s)); err != nil {
		return err
	}
	return nil
}
// marshalTwoDigits writes the two least significant decimal digits of v to
// out as ASCII, tens digit first.
func marshalTwoDigits(out *forkableWriter, v int) (err error) {
	tens := byte('0' + (v/10)%10)
	ones := byte('0' + v%10)
	if err = out.WriteByte(tens); err != nil {
		return err
	}
	return out.WriteByte(ones)
}
// marshalFourDigits writes the four least significant decimal digits of v
// to out as ASCII, most significant digit first.
func marshalFourDigits(out *forkableWriter, v int) (err error) {
	var digits [4]byte
	// Fill from the right so the least significant digit lands last.
	for pos := len(digits) - 1; pos >= 0; pos-- {
		digits[pos] = '0' + byte(v%10)
		v /= 10
	}
	_, err = out.Write(digits[:])
	return err
}
// outsideUTCRange reports whether the year of t lies outside [1950, 2050),
// the span that marshalUTCTime is able to encode.
func outsideUTCRange(t time.Time) bool {
	y := t.Year()
	return !(1950 <= y && y < 2050)
}
// marshalUTCTime writes t to out with a two-digit year followed by the
// fields produced by marshalTimeCommon. Years 1950-1999 are written as
// year-1900 and years 2000-2049 as year-2000; any other year yields a
// StructuralError.
func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
	var shortYear int
	switch year := t.Year(); {
	case 1950 <= year && year < 2000:
		shortYear = int(year - 1900)
	case 2000 <= year && year < 2050:
		shortYear = int(year - 2000)
	default:
		return StructuralError{"cannot represent time as UTCTime"}
	}

	if err = marshalTwoDigits(out, shortYear); err != nil {
		return err
	}
	return marshalTimeCommon(out, t)
}
// marshalGeneralizedTime writes t to out with a four-digit year followed by
// the fields produced by marshalTimeCommon. Years outside 0-9999 yield a
// StructuralError.
func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
	year := t.Year()
	if !(0 <= year && year <= 9999) {
		return StructuralError{"cannot represent time as GeneralizedTime"}
	}

	err = marshalFourDigits(out, year)
	if err != nil {
		return err
	}
	return marshalTimeCommon(out, t)
}
// marshalTimeCommon writes the part of a time encoding that follows the
// year: month, day, hour, minute and second as two digits each, then the
// zone. A zone offset of less than one minute is written as "Z"; otherwise
// the zone is a sign followed by two-digit hours and two-digit minutes.
func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
	_, month, day := t.Date()
	hour, minute, second := t.Clock()
	for _, field := range [...]int{int(month), day, hour, minute, second} {
		if err = marshalTwoDigits(out, field); err != nil {
			return err
		}
	}

	_, offset := t.Zone()
	offsetMinutes := offset / 60
	if offsetMinutes == 0 {
		return out.WriteByte('Z')
	}

	// offsetMinutes is non-zero here, so offset is strictly positive or
	// strictly negative.
	sign := byte('+')
	if offset < 0 {
		sign = '-'
		offsetMinutes = -offsetMinutes
	}
	if err = out.WriteByte(sign); err != nil {
		return err
	}
	if err = marshalTwoDigits(out, offsetMinutes/60); err != nil {
		return err
	}
	return marshalTwoDigits(out, offsetMinutes%60)
}
// stripTagAndLength returns in with its leading tag and length header
// removed, using the offset reported by parseTagAndLength. If the header
// cannot be parsed, in is returned unchanged.
func stripTagAndLength(in []byte) []byte {
	if _, bodyStart, err := parseTagAndLength(in, 0); err == nil {
		return in[bodyStart:]
	}
	return in
}
// marshalBody writes the contents octets of value (no tag, no length) to
// out. Types with a dedicated encoding are matched on their exact type
// first; everything else is dispatched on its reflect.Kind. Struct fields
// and non-byte slice elements are each encoded via marshalField. A kind
// with no known encoding yields a StructuralError.
func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
	switch value.Type() {
	case flagType:
		return nil
	case timeType:
		t := value.Interface().(time.Time)
		if params.timeType == tagGeneralizedTime || outsideUTCRange(t) {
			return marshalGeneralizedTime(out, t)
		}
		return marshalUTCTime(out, t)
	case bitStringType:
		return marshalBitString(out, value.Interface().(BitString))
	case objectIdentifierType:
		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
	case bigIntType:
		return marshalBigInt(out, value.Interface().(*big.Int))
	}

	switch value.Kind() {
	case reflect.Bool:
		if value.Bool() {
			return out.WriteByte(255)
		}
		return out.WriteByte(0)

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return marshalInt64(out, value.Int())

	case reflect.Struct:
		structType := value.Type()
		firstField := 0

		// A non-empty leading RawContents field replaces the encoding of
		// the whole struct; an empty one is simply skipped.
		if structType.NumField() > 0 && structType.Field(0).Type == rawContentsType {
			rawField := value.Field(0)
			if n := rawField.Len(); n > 0 {
				raw := make([]byte, n)
				for i := range raw {
					raw[i] = uint8(rawField.Index(i).Uint())
				}
				// RawContents includes its own tag and length, which
				// the caller writes separately, so drop them here.
				_, err = out.Write(stripTagAndLength(raw))
				return err
			}
			firstField = 1
		}

		for i := firstField; i < structType.NumField(); i++ {
			// Each field goes into the pre half of a new fork; later
			// fields continue in the post half.
			var pre *forkableWriter
			pre, out = out.fork()
			fieldParams := parseFieldParameters(structType.Field(i).Tag.Get("asn1"))
			if err = marshalField(pre, value.Field(i), fieldParams); err != nil {
				return err
			}
		}
		return nil

	case reflect.Slice:
		// Slices of bytes are written verbatim.
		if value.Type().Elem().Kind() == reflect.Uint8 {
			octets := make([]byte, value.Len())
			for i := range octets {
				octets[i] = uint8(value.Index(i).Uint())
			}
			_, err = out.Write(octets)
			return err
		}

		// Any other slice is written element by element with default
		// field parameters.
		var elemParams fieldParameters
		for i := 0; i < value.Len(); i++ {
			var pre *forkableWriter
			pre, out = out.fork()
			if err = marshalField(pre, value.Index(i), elemParams); err != nil {
				return err
			}
		}
		return nil

	case reflect.String:
		str := value.String()
		switch params.stringType {
		case tagIA5String:
			return marshalIA5String(out, str)
		case tagPrintableString:
			return marshalPrintableString(out, str)
		default:
			return marshalUTF8String(out, str)
		}
	}

	return StructuralError{"unknown Go type"}
}
// marshalField writes the complete encoding of v to out: tag, length and
// contents, adjusted by the options in params.
//
// Nothing is written (and nil is returned) when v is an empty slice and
// params.omitEmpty is set, or when the field is optional and equal to its
// default value (the explicit default if one is given, otherwise the zero
// value of its type). A RawValue is written from its FullBytes if present,
// otherwise from its Class, Tag, IsCompound and Bytes.
//
// It returns an error if v is not a valid value, if its Go type has no known
// universal type, or if params carries a time, string or set option that
// does not fit that type. Errors from the underlying writes are returned as
// well.
func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
	// An invalid Value (for example the element of a nil interface) has no
	// type to inspect; report it instead of panicking inside reflect.
	if !v.IsValid() {
		return errors.New("asn1: cannot marshal nil value")
	}

	// If the field is an interface{} then recurse into it.
	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
		return marshalField(out, v.Elem(), params)
	}

	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
		return
	}

	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
		defaultValue := reflect.New(v.Type()).Elem()
		defaultValue.SetInt(*params.defaultValue)

		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
			return
		}
	}

	// If no default value is given then the zero value for the type is
	// assumed to be the default value. This isn't obviously the correct
	// behaviour, but it's what Go has traditionally done.
	if params.optional && params.defaultValue == nil {
		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
			return
		}
	}

	if v.Type() == rawValueType {
		rv := v.Interface().(RawValue)
		if len(rv.FullBytes) != 0 {
			_, err = out.Write(rv.FullBytes)
		} else {
			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
			if err != nil {
				return
			}
			_, err = out.Write(rv.Bytes)
		}
		return
	}

	tag, isCompound, ok := getUniversalType(v.Type())
	if !ok {
		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
		return
	}
	class := classUniversal

	if params.timeType != 0 && tag != tagUTCTime {
		return StructuralError{"explicit time type given to non-time member"}
	}

	if params.stringType != 0 && tag != tagPrintableString {
		return StructuralError{"explicit string type given to non-string member"}
	}

	switch tag {
	case tagPrintableString:
		if params.stringType == 0 {
			// This is a string without an explicit string type. We'll use
			// a PrintableString if the character set in the string is
			// sufficiently limited, otherwise we'll use a UTF8String.
			for _, r := range v.String() {
				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
					if !utf8.ValidString(v.String()) {
						return errors.New("asn1: string not valid UTF-8")
					}
					tag = tagUTF8String
					break
				}
			}
		} else {
			tag = params.stringType
		}
	case tagUTCTime:
		if params.timeType == tagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
			tag = tagGeneralizedTime
		}
	}

	if params.set {
		if tag != tagSequence {
			return StructuralError{"non sequence tagged as set"}
		}
		tag = tagSet
	}

	// The body is written first so that its length is known when the tag
	// and length header, which precedes it in the output, is filled in.
	tags, body := out.fork()

	err = marshalBody(body, v, params)
	if err != nil {
		return
	}

	bodyLen := body.Len()

	// With an explicit tag the output is: explicit header, inner header,
	// body. Fork again so the explicit header can be written after the
	// inner header's length is known.
	var explicitTag *forkableWriter
	if params.explicit {
		explicitTag, tags = tags.fork()
	}

	if !params.explicit && params.tag != nil {
		// implicit tag.
		tag = *params.tag
		class = classContextSpecific
	}

	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
	if err != nil {
		return
	}

	if params.explicit {
		// The explicit wrapper covers the inner header plus the body.
		err = marshalTagAndLength(explicitTag, tagAndLength{
			class:      classContextSpecific,
			tag:        *params.tag,
			length:     bodyLen + tags.Len(),
			isCompound: true,
		})
	}

	// Propagate a failure from writing the explicit header instead of
	// discarding it.
	return err
}
// Marshal returns the ASN.1 encoding of val.
//
// In addition to the struct tags recognised by Unmarshal, the following can be
// used:
//
//	ia5:		causes strings to be marshaled as ASN.1, IA5 strings
//	omitempty:	causes empty slices to be skipped
//	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
//	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
//
// If encoding val fails, or if the encoded pieces cannot be flattened into
// the result buffer, Marshal returns nil bytes and the error.
func Marshal(val interface{}) ([]byte, error) {
	var out bytes.Buffer
	v := reflect.ValueOf(val)
	f := newForkableWriter()
	if err := marshalField(f, v, fieldParameters{}); err != nil {
		return nil, err
	}
	// Flatten the tree of forked writers into one contiguous buffer; a
	// failure here must be reported rather than returning partial output.
	if _, err := f.writeTo(&out); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}