response.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package protocol
  2. import (
  3. "fmt"
  4. "io"
  5. )
  6. func ReadResponse(r io.Reader, apiKey ApiKey, apiVersion int16) (correlationID int32, msg Message, err error) {
  7. if i := int(apiKey); i < 0 || i >= len(apiTypes) {
  8. err = fmt.Errorf("unsupported api key: %d", i)
  9. return
  10. }
  11. t := &apiTypes[apiKey]
  12. if t == nil {
  13. err = fmt.Errorf("unsupported api: %s", apiNames[apiKey])
  14. return
  15. }
  16. minVersion := t.minVersion()
  17. maxVersion := t.maxVersion()
  18. if apiVersion < minVersion || apiVersion > maxVersion {
  19. err = fmt.Errorf("unsupported %s version: v%d not in range v%d-v%d", apiKey, apiVersion, minVersion, maxVersion)
  20. return
  21. }
  22. d := &decoder{reader: r, remain: 4}
  23. size := d.readInt32()
  24. if err = d.err; err != nil {
  25. err = dontExpectEOF(err)
  26. return
  27. }
  28. d.remain = int(size)
  29. correlationID = d.readInt32()
  30. res := &t.responses[apiVersion-minVersion]
  31. if res.flexible {
  32. // In the flexible case, there's a tag buffer at the end of the response header
  33. taggedCount := int(d.readUnsignedVarInt())
  34. for i := 0; i < taggedCount; i++ {
  35. d.readUnsignedVarInt() // tagID
  36. size := d.readUnsignedVarInt()
  37. // Just throw away the values for now
  38. d.read(int(size))
  39. }
  40. }
  41. msg = res.new()
  42. res.decode(d, valueOf(msg))
  43. d.discardAll()
  44. if err = d.err; err != nil {
  45. err = dontExpectEOF(err)
  46. }
  47. return
  48. }
  49. func WriteResponse(w io.Writer, apiVersion int16, correlationID int32, msg Message) error {
  50. apiKey := msg.ApiKey()
  51. if i := int(apiKey); i < 0 || i >= len(apiTypes) {
  52. return fmt.Errorf("unsupported api key: %d", i)
  53. }
  54. t := &apiTypes[apiKey]
  55. if t == nil {
  56. return fmt.Errorf("unsupported api: %s", apiNames[apiKey])
  57. }
  58. minVersion := t.minVersion()
  59. maxVersion := t.maxVersion()
  60. if apiVersion < minVersion || apiVersion > maxVersion {
  61. return fmt.Errorf("unsupported %s version: v%d not in range v%d-v%d", apiKey, apiVersion, minVersion, maxVersion)
  62. }
  63. r := &t.responses[apiVersion-minVersion]
  64. v := valueOf(msg)
  65. b := newPageBuffer()
  66. defer b.unref()
  67. e := &encoder{writer: b}
  68. e.writeInt32(0) // placeholder for the response size
  69. e.writeInt32(correlationID)
  70. r.encode(e, v)
  71. err := e.err
  72. if err == nil {
  73. size := packUint32(uint32(b.Size()) - 4)
  74. b.WriteAt(size[:], 0)
  75. _, err = b.WriteTo(w)
  76. }
  77. return err
  78. }