@@ -510,6 +510,24 @@ func (bmm BinaryMarshalerMode) valid() bool {
510510 return bmm >= 0 && bmm < maxBinaryMarshalerMode
511511}
512512
513+ // TextMarshalerMode specifies how to encode types that implement encoding.TextMarshaler.
514+ type TextMarshalerMode int
515+
516+ const (
517+ // TextMarshalerNone does not recognize TextMarshaler implementations during encode.
518+ // This is the default behavior.
519+ TextMarshalerNone TextMarshalerMode = iota
520+
521+ // TextMarshalerTextString encodes the output of MarshalText to a CBOR text string.
522+ TextMarshalerTextString
523+
524+ maxTextMarshalerMode
525+ )
526+
527+ func (tmm TextMarshalerMode ) valid () bool {
528+ return tmm >= 0 && tmm < maxTextMarshalerMode
529+ }
530+
513531// EncOptions specifies encoding options.
514532type EncOptions struct {
515533 // Sort specifies sorting order.
@@ -567,6 +585,9 @@ type EncOptions struct {
567585
568586 // BinaryMarshaler specifies how to encode types that implement encoding.BinaryMarshaler.
569587 BinaryMarshaler BinaryMarshalerMode
588+
589+ // TextMarshaler specifies how to encode types that implement encoding.TextMarshaler.
590+ TextMarshaler TextMarshalerMode
570591}
571592
572593// CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding,
@@ -777,6 +798,9 @@ func (opts EncOptions) encMode() (*encMode, error) { //nolint:gocritic // ignore
777798 if ! opts .BinaryMarshaler .valid () {
778799 return nil , errors .New ("cbor: invalid BinaryMarshaler " + strconv .Itoa (int (opts .BinaryMarshaler )))
779800 }
801+ if ! opts .TextMarshaler .valid () {
802+ return nil , errors .New ("cbor: invalid TextMarshaler " + strconv .Itoa (int (opts .TextMarshaler )))
803+ }
780804 em := encMode {
781805 sort : opts .Sort ,
782806 shortestFloat : opts .ShortestFloat ,
@@ -796,6 +820,7 @@ func (opts EncOptions) encMode() (*encMode, error) { //nolint:gocritic // ignore
796820 byteSliceLaterEncodingTag : byteSliceLaterEncodingTag ,
797821 byteArray : opts .ByteArray ,
798822 binaryMarshaler : opts .BinaryMarshaler ,
823+ textMarshaler : opts .TextMarshaler ,
799824 }
800825 return & em , nil
801826}
@@ -841,6 +866,7 @@ type encMode struct {
841866 byteSliceLaterEncodingTag uint64
842867 byteArray ByteArrayMode
843868 binaryMarshaler BinaryMarshalerMode
869+ textMarshaler TextMarshalerMode
844870}
845871
846872var defaultEncMode , _ = EncOptions {}.encMode ()
@@ -933,6 +959,7 @@ func (em *encMode) EncOptions() EncOptions {
933959 ByteSliceLaterFormat : em .byteSliceLaterFormat ,
934960 ByteArray : em .byteArray ,
935961 BinaryMarshaler : em .binaryMarshaler ,
962+ TextMarshaler : em .textMarshaler ,
936963 }
937964}
938965
@@ -1704,6 +1731,54 @@ func (bme binaryMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, e
17041731 return len (data ) == 0 , nil
17051732}
17061733
1734+ type textMarshalerEncoder struct {
1735+ alternateEncode encodeFunc
1736+ alternateIsEmpty isEmptyFunc
1737+ }
1738+
1739+ func (tme textMarshalerEncoder ) encode (e * bytes.Buffer , em * encMode , v reflect.Value ) error {
1740+ if em .textMarshaler == TextMarshalerNone {
1741+ return tme .alternateEncode (e , em , v )
1742+ }
1743+
1744+ vt := v .Type ()
1745+ m , ok := v .Interface ().(encoding.TextMarshaler )
1746+ if ! ok {
1747+ pv := reflect .New (vt )
1748+ pv .Elem ().Set (v )
1749+ m = pv .Interface ().(encoding.TextMarshaler )
1750+ }
1751+ data , err := m .MarshalText ()
1752+ if err != nil {
1753+ return fmt .Errorf ("cbor: cannot marshal text for %s: %w" , vt , err )
1754+ }
1755+ if b := em .encTagBytes (vt ); b != nil {
1756+ e .Write (b )
1757+ }
1758+
1759+ encodeHead (e , byte (cborTypeTextString ), uint64 (len (data )))
1760+ e .Write (data )
1761+ return nil
1762+ }
1763+
1764+ func (tme textMarshalerEncoder ) isEmpty (em * encMode , v reflect.Value ) (bool , error ) {
1765+ if em .textMarshaler == TextMarshalerNone {
1766+ return tme .alternateIsEmpty (em , v )
1767+ }
1768+
1769+ m , ok := v .Interface ().(encoding.TextMarshaler )
1770+ if ! ok {
1771+ pv := reflect .New (v .Type ())
1772+ pv .Elem ().Set (v )
1773+ m = pv .Interface ().(encoding.TextMarshaler )
1774+ }
1775+ data , err := m .MarshalText ()
1776+ if err != nil {
1777+ return false , fmt .Errorf ("cbor: cannot marshal text for %s: %w" , v .Type (), err )
1778+ }
1779+ return len (data ) == 0 , nil
1780+ }
1781+
17071782func encodeMarshalerType (e * bytes.Buffer , em * encMode , v reflect.Value ) error {
17081783 if em .tagsMd == TagsForbidden && v .Type () == typeRawTag {
17091784 return errors .New ("cbor: cannot encode cbor.RawTag when TagsMd is TagsForbidden" )
@@ -1810,6 +1885,7 @@ func encodeHead(e *bytes.Buffer, t byte, n uint64) int {
18101885var (
18111886 typeMarshaler = reflect .TypeOf ((* Marshaler )(nil )).Elem ()
18121887 typeBinaryMarshaler = reflect .TypeOf ((* encoding .BinaryMarshaler )(nil )).Elem ()
1888+ typeTextMarshaler = reflect .TypeOf ((* encoding .TextMarshaler )(nil )).Elem ()
18131889 typeRawMessage = reflect .TypeOf (RawMessage (nil ))
18141890 typeByteString = reflect .TypeOf (ByteString ("" ))
18151891)
@@ -1852,6 +1928,17 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc, izf
18521928 ief = bme .isEmpty
18531929 }()
18541930 }
1931+ if reflect .PointerTo (t ).Implements (typeTextMarshaler ) {
1932+ defer func () {
1933+ // capture encoding method used for modes that disable TextMarshaler
1934+ tme := textMarshalerEncoder {
1935+ alternateEncode : ef ,
1936+ alternateIsEmpty : ief ,
1937+ }
1938+ ef = tme .encode
1939+ ief = tme .isEmpty
1940+ }()
1941+ }
18551942 switch k {
18561943 case reflect .Bool :
18571944 return encodeBool , isEmptyBool , getIsZeroFunc (t )
0 commit comments