Skip to content

Commit a8dbe3e

Browse files
authored
Merge pull request #672 from benluddy/text-un-marshaler
Add options to support TextMarshaler and TextUnmarshaler
2 parents d81767d + db9afc5 commit a8dbe3e

File tree

4 files changed

+411
-7
lines changed

4 files changed

+411
-7
lines changed

decode.go

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,25 @@ func (bum BinaryUnmarshalerMode) valid() bool {
749749
return bum >= 0 && bum < maxBinaryUnmarshalerMode
750750
}
751751

752+
// TextUnmarshalerMode specifies how to decode into types that implement
753+
// encoding.TextUnmarshaler.
754+
type TextUnmarshalerMode int
755+
756+
const (
757+
// TextUnmarshalerNone does not recognize TextUnmarshaler implementations during decode. This is the default behavior.
758+
TextUnmarshalerNone TextUnmarshalerMode = iota
759+
760+
// TextUnmarshalerTextString will invoke UnmarshalText on the contents of a CBOR text
761+
// string when decoding into a value that implements TextUnmarshaler.
762+
TextUnmarshalerTextString
763+
764+
maxTextUnmarshalerMode
765+
)
766+
767+
func (tum TextUnmarshalerMode) valid() bool {
768+
return tum >= 0 && tum < maxTextUnmarshalerMode
769+
}
770+
752771
// DecOptions specifies decoding options.
753772
type DecOptions struct {
754773
// DupMapKey specifies whether to enforce duplicate map key detection.
@@ -883,6 +902,10 @@ type DecOptions struct {
883902
// BinaryUnmarshaler specifies how to decode into types that implement
884903
// encoding.BinaryUnmarshaler.
885904
BinaryUnmarshaler BinaryUnmarshalerMode
905+
906+
// TextUnmarshaler specifies how to decode into types that implement
907+
// encoding.TextUnmarshaler.
908+
TextUnmarshaler TextUnmarshalerMode
886909
}
887910

888911
// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
@@ -1095,6 +1118,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
10951118
return nil, errors.New("cbor: invalid BinaryUnmarshaler " + strconv.Itoa(int(opts.BinaryUnmarshaler)))
10961119
}
10971120

1121+
if !opts.TextUnmarshaler.valid() {
1122+
return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler)))
1123+
}
1124+
10981125
dm := decMode{
10991126
dupMapKey: opts.DupMapKey,
11001127
timeTag: opts.TimeTag,
@@ -1122,6 +1149,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
11221149
byteStringExpectedFormat: opts.ByteStringExpectedFormat,
11231150
bignumTag: opts.BignumTag,
11241151
binaryUnmarshaler: opts.BinaryUnmarshaler,
1152+
textUnmarshaler: opts.TextUnmarshaler,
11251153
}
11261154

11271155
return &dm, nil
@@ -1201,6 +1229,7 @@ type decMode struct {
12011229
byteStringExpectedFormat ByteStringExpectedFormatMode
12021230
bignumTag BignumTagMode
12031231
binaryUnmarshaler BinaryUnmarshalerMode
1232+
textUnmarshaler TextUnmarshalerMode
12041233
}
12051234

12061235
var defaultDecMode, _ = DecOptions{}.decMode()
@@ -1241,6 +1270,7 @@ func (dm *decMode) DecOptions() DecOptions {
12411270
ByteStringExpectedFormat: dm.byteStringExpectedFormat,
12421271
BignumTag: dm.bignumTag,
12431272
BinaryUnmarshaler: dm.binaryUnmarshaler,
1273+
TextUnmarshaler: dm.textUnmarshaler,
12441274
}
12451275
}
12461276

@@ -1530,7 +1560,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
15301560
if err != nil {
15311561
return err
15321562
}
1533-
return fillTextString(t, b, v)
1563+
return fillTextString(t, b, v, d.dm.textUnmarshaler)
15341564

15351565
case cborTypePrimitives:
15361566
_, ai, val := d.getHead()
@@ -2995,6 +3025,7 @@ var (
29953025
typeUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
29963026
typeUnexportedUnmarshaler = reflect.TypeOf((*unmarshaler)(nil)).Elem()
29973027
typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
3028+
typeTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
29983029
typeString = reflect.TypeOf("")
29993030
typeByteSlice = reflect.TypeOf([]byte(nil))
30003031
)
@@ -3147,11 +3178,28 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts B
31473178
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
31483179
}
31493180

3150-
func fillTextString(t cborType, val []byte, v reflect.Value) error {
3181+
func fillTextString(t cborType, val []byte, v reflect.Value, tum TextUnmarshalerMode) error {
3182+
// Use UnmarshalText only if the mode allows it and a pointer to the value's type implements TextUnmarshaler
3183+
if tum == TextUnmarshalerTextString && reflect.PointerTo(v.Type()).Implements(typeTextUnmarshaler) {
3184+
if v.CanAddr() {
3185+
v = v.Addr()
3186+
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
3187+
// The contract of TextUnmarshaler forbids retaining the input
3188+
// bytes, so no copying is required even if val is shared.
3189+
if err := u.UnmarshalText(val); err != nil {
3190+
return fmt.Errorf("cbor: cannot unmarshal text for %s: %w", v.Type(), err)
3191+
}
3192+
return nil
3193+
}
3194+
}
3195+
return errors.New("cbor: cannot set new value for " + v.Type().String())
3196+
}
3197+
31513198
if v.Kind() == reflect.String {
31523199
v.SetString(string(val))
31533200
return nil
31543201
}
3202+
31553203
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
31563204
}
31573205

decode_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4925,6 +4925,7 @@ func TestDecOptions(t *testing.T) {
49254925
ByteStringExpectedFormat: ByteStringExpectedBase64URL,
49264926
BignumTag: BignumTagForbidden,
49274927
BinaryUnmarshaler: BinaryUnmarshalerNone,
4928+
TextUnmarshaler: TextUnmarshalerTextString,
49284929
}
49294930
ov := reflect.ValueOf(opts1)
49304931
for i := 0; i < ov.NumField(); i++ {
@@ -10132,3 +10133,105 @@ func TestBignumTagMode(t *testing.T) {
1013210133
})
1013310134
}
1013410135
}
10136+
10137+
func TestDecModeInvalidTextUnmarshaler(t *testing.T) {
10138+
for _, tc := range []struct {
10139+
name string
10140+
opts DecOptions
10141+
wantErrorMsg string
10142+
}{
10143+
{
10144+
name: "below range of valid modes",
10145+
opts: DecOptions{TextUnmarshaler: -1},
10146+
wantErrorMsg: "cbor: invalid TextUnmarshaler -1",
10147+
},
10148+
{
10149+
name: "above range of valid modes",
10150+
opts: DecOptions{TextUnmarshaler: 101},
10151+
wantErrorMsg: "cbor: invalid TextUnmarshaler 101",
10152+
},
10153+
} {
10154+
t.Run(tc.name, func(t *testing.T) {
10155+
_, err := tc.opts.DecMode()
10156+
if err == nil {
10157+
t.Errorf("DecMode() didn't return an error")
10158+
} else if err.Error() != tc.wantErrorMsg {
10159+
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
10160+
}
10161+
})
10162+
}
10163+
}
10164+
10165+
type testTextUnmarshaler string
10166+
10167+
func (tu *testTextUnmarshaler) UnmarshalText(_ []byte) error {
10168+
*tu = "UnmarshalText"
10169+
return nil
10170+
}
10171+
10172+
func TestTextUnmarshalerMode(t *testing.T) {
10173+
for _, tc := range []struct {
10174+
name string
10175+
opts DecOptions
10176+
in []byte
10177+
want any
10178+
}{
10179+
{
10180+
name: "UnmarshalText is not called by default",
10181+
opts: DecOptions{},
10182+
in: []byte("\x65hello"), // "hello"
10183+
want: testTextUnmarshaler("hello"),
10184+
},
10185+
{
10186+
name: "UnmarshalText is called with TextUnmarshalerTextString",
10187+
opts: DecOptions{TextUnmarshaler: TextUnmarshalerTextString},
10188+
in: []byte("\x65hello"), // "hello"
10189+
want: testTextUnmarshaler("UnmarshalText"),
10190+
},
10191+
{
10192+
name: "default text string unmarshaling behavior is used with TextUnmarshalerNone",
10193+
opts: DecOptions{TextUnmarshaler: TextUnmarshalerNone},
10194+
in: []byte("\x65hello"), // "hello"
10195+
want: testTextUnmarshaler("hello"),
10196+
},
10197+
} {
10198+
t.Run(tc.name, func(t *testing.T) {
10199+
dm, err := tc.opts.DecMode()
10200+
if err != nil {
10201+
t.Fatal(err)
10202+
}
10203+
10204+
gotrv := reflect.New(reflect.TypeOf(tc.want))
10205+
if err := dm.Unmarshal(tc.in, gotrv.Interface()); err != nil {
10206+
t.Fatal(err)
10207+
}
10208+
10209+
got := gotrv.Elem().Interface()
10210+
if !reflect.DeepEqual(tc.want, got) {
10211+
t.Errorf("want: %v, got: %v", tc.want, got)
10212+
}
10213+
})
10214+
}
10215+
}
10216+
10217+
type errorTextUnmarshaler struct{}
10218+
10219+
func (u *errorTextUnmarshaler) UnmarshalText([]byte) error {
10220+
return errors.New("test")
10221+
}
10222+
10223+
func TestTextUnmarshalerModeError(t *testing.T) {
10224+
dec, err := DecOptions{TextUnmarshaler: TextUnmarshalerTextString}.DecMode()
10225+
if err != nil {
10226+
t.Fatal(err)
10227+
}
10228+
10229+
err = dec.Unmarshal([]byte{0x61, 'a'}, new(errorTextUnmarshaler))
10230+
if err == nil {
10231+
t.Fatal("expected non-nil error")
10232+
}
10233+
10234+
if got, want := err.Error(), "cbor: cannot unmarshal text for *cbor.errorTextUnmarshaler: test"; got != want {
10235+
t.Errorf("want: %q, got: %q", want, got)
10236+
}
10237+
}

encode.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,24 @@ func (bmm BinaryMarshalerMode) valid() bool {
510510
return bmm >= 0 && bmm < maxBinaryMarshalerMode
511511
}
512512

513+
// TextMarshalerMode specifies how to encode types that implement encoding.TextMarshaler.
514+
type TextMarshalerMode int
515+
516+
const (
517+
// TextMarshalerNone does not recognize TextMarshaler implementations during encode.
518+
// This is the default behavior.
519+
TextMarshalerNone TextMarshalerMode = iota
520+
521+
// TextMarshalerTextString encodes the output of MarshalText to a CBOR text string.
522+
TextMarshalerTextString
523+
524+
maxTextMarshalerMode
525+
)
526+
527+
func (tmm TextMarshalerMode) valid() bool {
528+
return tmm >= 0 && tmm < maxTextMarshalerMode
529+
}
530+
513531
// EncOptions specifies encoding options.
514532
type EncOptions struct {
515533
// Sort specifies sorting order.
@@ -567,6 +585,9 @@ type EncOptions struct {
567585

568586
// BinaryMarshaler specifies how to encode types that implement encoding.BinaryMarshaler.
569587
BinaryMarshaler BinaryMarshalerMode
588+
589+
// TextMarshaler specifies how to encode types that implement encoding.TextMarshaler.
590+
TextMarshaler TextMarshalerMode
570591
}
571592

572593
// CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding,
@@ -777,6 +798,9 @@ func (opts EncOptions) encMode() (*encMode, error) { //nolint:gocritic // ignore
777798
if !opts.BinaryMarshaler.valid() {
778799
return nil, errors.New("cbor: invalid BinaryMarshaler " + strconv.Itoa(int(opts.BinaryMarshaler)))
779800
}
801+
if !opts.TextMarshaler.valid() {
802+
return nil, errors.New("cbor: invalid TextMarshaler " + strconv.Itoa(int(opts.TextMarshaler)))
803+
}
780804
em := encMode{
781805
sort: opts.Sort,
782806
shortestFloat: opts.ShortestFloat,
@@ -796,6 +820,7 @@ func (opts EncOptions) encMode() (*encMode, error) { //nolint:gocritic // ignore
796820
byteSliceLaterEncodingTag: byteSliceLaterEncodingTag,
797821
byteArray: opts.ByteArray,
798822
binaryMarshaler: opts.BinaryMarshaler,
823+
textMarshaler: opts.TextMarshaler,
799824
}
800825
return &em, nil
801826
}
@@ -841,6 +866,7 @@ type encMode struct {
841866
byteSliceLaterEncodingTag uint64
842867
byteArray ByteArrayMode
843868
binaryMarshaler BinaryMarshalerMode
869+
textMarshaler TextMarshalerMode
844870
}
845871

846872
var defaultEncMode, _ = EncOptions{}.encMode()
@@ -933,6 +959,7 @@ func (em *encMode) EncOptions() EncOptions {
933959
ByteSliceLaterFormat: em.byteSliceLaterFormat,
934960
ByteArray: em.byteArray,
935961
BinaryMarshaler: em.binaryMarshaler,
962+
TextMarshaler: em.textMarshaler,
936963
}
937964
}
938965

@@ -1704,6 +1731,54 @@ func (bme binaryMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, e
17041731
return len(data) == 0, nil
17051732
}
17061733

1734+
type textMarshalerEncoder struct {
1735+
alternateEncode encodeFunc
1736+
alternateIsEmpty isEmptyFunc
1737+
}
1738+
1739+
func (tme textMarshalerEncoder) encode(e *bytes.Buffer, em *encMode, v reflect.Value) error {
1740+
if em.textMarshaler == TextMarshalerNone {
1741+
return tme.alternateEncode(e, em, v)
1742+
}
1743+
1744+
vt := v.Type()
1745+
m, ok := v.Interface().(encoding.TextMarshaler)
1746+
if !ok {
1747+
pv := reflect.New(vt)
1748+
pv.Elem().Set(v)
1749+
m = pv.Interface().(encoding.TextMarshaler)
1750+
}
1751+
data, err := m.MarshalText()
1752+
if err != nil {
1753+
return fmt.Errorf("cbor: cannot marshal text for %s: %w", vt, err)
1754+
}
1755+
if b := em.encTagBytes(vt); b != nil {
1756+
e.Write(b)
1757+
}
1758+
1759+
encodeHead(e, byte(cborTypeTextString), uint64(len(data)))
1760+
e.Write(data)
1761+
return nil
1762+
}
1763+
1764+
func (tme textMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, error) {
1765+
if em.textMarshaler == TextMarshalerNone {
1766+
return tme.alternateIsEmpty(em, v)
1767+
}
1768+
1769+
m, ok := v.Interface().(encoding.TextMarshaler)
1770+
if !ok {
1771+
pv := reflect.New(v.Type())
1772+
pv.Elem().Set(v)
1773+
m = pv.Interface().(encoding.TextMarshaler)
1774+
}
1775+
data, err := m.MarshalText()
1776+
if err != nil {
1777+
return false, fmt.Errorf("cbor: cannot marshal text for %s: %w", v.Type(), err)
1778+
}
1779+
return len(data) == 0, nil
1780+
}
1781+
17071782
func encodeMarshalerType(e *bytes.Buffer, em *encMode, v reflect.Value) error {
17081783
if em.tagsMd == TagsForbidden && v.Type() == typeRawTag {
17091784
return errors.New("cbor: cannot encode cbor.RawTag when TagsMd is TagsForbidden")
@@ -1810,6 +1885,7 @@ func encodeHead(e *bytes.Buffer, t byte, n uint64) int {
18101885
var (
18111886
typeMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
18121887
typeBinaryMarshaler = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem()
1888+
typeTextMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
18131889
typeRawMessage = reflect.TypeOf(RawMessage(nil))
18141890
typeByteString = reflect.TypeOf(ByteString(""))
18151891
)
@@ -1852,6 +1928,17 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc, izf
18521928
ief = bme.isEmpty
18531929
}()
18541930
}
1931+
if reflect.PointerTo(t).Implements(typeTextMarshaler) {
1932+
defer func() {
1933+
// capture encoding method used for modes that disable TextMarshaler
1934+
tme := textMarshalerEncoder{
1935+
alternateEncode: ef,
1936+
alternateIsEmpty: ief,
1937+
}
1938+
ef = tme.encode
1939+
ief = tme.isEmpty
1940+
}()
1941+
}
18551942
switch k {
18561943
case reflect.Bool:
18571944
return encodeBool, isEmptyBool, getIsZeroFunc(t)

0 commit comments

Comments
 (0)