io.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. // Copyright ©2015 The Gonum Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mat
  5. import (
  6. "bytes"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "math"
  12. )
  13. // version is the current on-disk codec version.
  14. const version uint32 = 0x1
  15. // maxLen is the biggest slice/array len one can create on a 32/64b platform.
  16. const maxLen = int64(int(^uint(0) >> 1))
  17. var (
  18. headerSize = binary.Size(storage{})
  19. sizeFloat64 = binary.Size(float64(0))
  20. errWrongType = errors.New("mat: wrong data type")
  21. errTooBig = errors.New("mat: resulting data slice too big")
  22. errTooSmall = errors.New("mat: input slice too small")
  23. errBadBuffer = errors.New("mat: data buffer size mismatch")
  24. errBadSize = errors.New("mat: invalid dimension")
  25. )
// Type encoding scheme:
//
//	Type             Form        Packing     Uplo        Unit        Rows   Columns  kU     kL
//	                 uint8 [GST] uint8 [BPF] uint8 [AUL] bool        int64  int64    int64  int64
//	General          'G'         'F'         'A'         false       r      c        0      0
//	Band             'G'         'B'         'A'         false       r      c        kU     kL
//	Symmetric        'S'         'F'         ul          false       n      n        0      0
//	SymmetricBand    'S'         'B'         ul          false       n      n        k      k
//	SymmetricPacked  'S'         'P'         ul          false       n      n        0      0
//	Triangular       'T'         'F'         ul          Diag==Unit  n      n        0      0
//	TriangularBand   'T'         'B'         ul          Diag==Unit  n      n        k      k
//	TriangularPacked 'T'         'P'         ul          Diag==Unit  n      n        0      0
//
// G - general, S - symmetric, T - triangular
// F - full, B - band, P - packed
// A - all, U - upper, L - lower
  42. // MarshalBinary encodes the receiver into a binary form and returns the result.
  43. //
  44. // Dense is little-endian encoded as follows:
  45. // 0 - 3 Version = 1 (uint32)
  46. // 4 'G' (byte)
  47. // 5 'F' (byte)
  48. // 6 'A' (byte)
  49. // 7 0 (byte)
  50. // 8 - 15 number of rows (int64)
  51. // 16 - 23 number of columns (int64)
  52. // 24 - 31 0 (int64)
  53. // 32 - 39 0 (int64)
  54. // 40 - .. matrix data elements (float64)
  55. // [0,0] [0,1] ... [0,ncols-1]
  56. // [1,0] [1,1] ... [1,ncols-1]
  57. // ...
  58. // [nrows-1,0] ... [nrows-1,ncols-1]
  59. func (m Dense) MarshalBinary() ([]byte, error) {
  60. bufLen := int64(headerSize) + int64(m.mat.Rows)*int64(m.mat.Cols)*int64(sizeFloat64)
  61. if bufLen <= 0 {
  62. // bufLen is too big and has wrapped around.
  63. return nil, errTooBig
  64. }
  65. header := storage{
  66. Form: 'G', Packing: 'F', Uplo: 'A',
  67. Rows: int64(m.mat.Rows), Cols: int64(m.mat.Cols),
  68. Version: version,
  69. }
  70. buf := make([]byte, bufLen)
  71. n, err := header.marshalBinaryTo(bytes.NewBuffer(buf[:0]))
  72. if err != nil {
  73. return buf[:n], err
  74. }
  75. p := headerSize
  76. r, c := m.Dims()
  77. for i := 0; i < r; i++ {
  78. for j := 0; j < c; j++ {
  79. binary.LittleEndian.PutUint64(buf[p:p+sizeFloat64], math.Float64bits(m.at(i, j)))
  80. p += sizeFloat64
  81. }
  82. }
  83. return buf, nil
  84. }
  85. // MarshalBinaryTo encodes the receiver into a binary form and writes it into w.
  86. // MarshalBinaryTo returns the number of bytes written into w and an error, if any.
  87. //
  88. // See MarshalBinary for the on-disk layout.
  89. func (m Dense) MarshalBinaryTo(w io.Writer) (int, error) {
  90. header := storage{
  91. Form: 'G', Packing: 'F', Uplo: 'A',
  92. Rows: int64(m.mat.Rows), Cols: int64(m.mat.Cols),
  93. Version: version,
  94. }
  95. n, err := header.marshalBinaryTo(w)
  96. if err != nil {
  97. return n, err
  98. }
  99. r, c := m.Dims()
  100. var b [8]byte
  101. for i := 0; i < r; i++ {
  102. for j := 0; j < c; j++ {
  103. binary.LittleEndian.PutUint64(b[:], math.Float64bits(m.at(i, j)))
  104. nn, err := w.Write(b[:])
  105. n += nn
  106. if err != nil {
  107. return n, err
  108. }
  109. }
  110. }
  111. return n, nil
  112. }
  113. // UnmarshalBinary decodes the binary form into the receiver.
  114. // It panics if the receiver is a non-empty Dense matrix.
  115. //
  116. // See MarshalBinary for the on-disk layout.
  117. //
  118. // Limited checks on the validity of the binary input are performed:
  119. // - matrix.ErrShape is returned if the number of rows or columns is negative,
  120. // - an error is returned if the resulting Dense matrix is too
  121. // big for the current architecture (e.g. a 16GB matrix written by a
  122. // 64b application and read back from a 32b application.)
  123. // UnmarshalBinary does not limit the size of the unmarshaled matrix, and so
  124. // it should not be used on untrusted data.
  125. func (m *Dense) UnmarshalBinary(data []byte) error {
  126. if !m.IsEmpty() {
  127. panic("mat: unmarshal into non-empty matrix")
  128. }
  129. if len(data) < headerSize {
  130. return errTooSmall
  131. }
  132. var header storage
  133. err := header.unmarshalBinary(data[:headerSize])
  134. if err != nil {
  135. return err
  136. }
  137. rows := header.Rows
  138. cols := header.Cols
  139. header.Version = 0
  140. header.Rows = 0
  141. header.Cols = 0
  142. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  143. return errWrongType
  144. }
  145. if rows < 0 || cols < 0 {
  146. return errBadSize
  147. }
  148. size := rows * cols
  149. if size == 0 {
  150. return ErrZeroLength
  151. }
  152. if int(size) < 0 || size > maxLen {
  153. return errTooBig
  154. }
  155. if len(data) != headerSize+int(rows*cols)*sizeFloat64 {
  156. return errBadBuffer
  157. }
  158. p := headerSize
  159. m.reuseAsNonZeroed(int(rows), int(cols))
  160. for i := range m.mat.Data {
  161. m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[p : p+sizeFloat64]))
  162. p += sizeFloat64
  163. }
  164. return nil
  165. }
  166. // UnmarshalBinaryFrom decodes the binary form into the receiver and returns
  167. // the number of bytes read and an error if any.
  168. // It panics if the receiver is a non-empty Dense matrix.
  169. //
  170. // See MarshalBinary for the on-disk layout.
  171. //
  172. // Limited checks on the validity of the binary input are performed:
  173. // - matrix.ErrShape is returned if the number of rows or columns is negative,
  174. // - an error is returned if the resulting Dense matrix is too
  175. // big for the current architecture (e.g. a 16GB matrix written by a
  176. // 64b application and read back from a 32b application.)
  177. // UnmarshalBinary does not limit the size of the unmarshaled matrix, and so
  178. // it should not be used on untrusted data.
  179. func (m *Dense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
  180. if !m.IsEmpty() {
  181. panic("mat: unmarshal into non-empty matrix")
  182. }
  183. var header storage
  184. n, err := header.unmarshalBinaryFrom(r)
  185. if err != nil {
  186. return n, err
  187. }
  188. rows := header.Rows
  189. cols := header.Cols
  190. header.Version = 0
  191. header.Rows = 0
  192. header.Cols = 0
  193. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  194. return n, errWrongType
  195. }
  196. if rows < 0 || cols < 0 {
  197. return n, errBadSize
  198. }
  199. size := rows * cols
  200. if size == 0 {
  201. return n, ErrZeroLength
  202. }
  203. if int(size) < 0 || size > maxLen {
  204. return n, errTooBig
  205. }
  206. m.reuseAsNonZeroed(int(rows), int(cols))
  207. var b [8]byte
  208. for i := range m.mat.Data {
  209. nn, err := readFull(r, b[:])
  210. n += nn
  211. if err != nil {
  212. if err == io.EOF {
  213. return n, io.ErrUnexpectedEOF
  214. }
  215. return n, err
  216. }
  217. m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(b[:]))
  218. }
  219. return n, nil
  220. }
  221. // MarshalBinary encodes the receiver into a binary form and returns the result.
  222. //
  223. // VecDense is little-endian encoded as follows:
  224. //
  225. // 0 - 3 Version = 1 (uint32)
  226. // 4 'G' (byte)
  227. // 5 'F' (byte)
  228. // 6 'A' (byte)
  229. // 7 0 (byte)
  230. // 8 - 15 number of elements (int64)
  231. // 16 - 23 1 (int64)
  232. // 24 - 31 0 (int64)
  233. // 32 - 39 0 (int64)
  234. // 40 - .. vector's data elements (float64)
  235. func (v VecDense) MarshalBinary() ([]byte, error) {
  236. bufLen := int64(headerSize) + int64(v.mat.N)*int64(sizeFloat64)
  237. if bufLen <= 0 {
  238. // bufLen is too big and has wrapped around.
  239. return nil, errTooBig
  240. }
  241. header := storage{
  242. Form: 'G', Packing: 'F', Uplo: 'A',
  243. Rows: int64(v.mat.N), Cols: 1,
  244. Version: version,
  245. }
  246. buf := make([]byte, bufLen)
  247. n, err := header.marshalBinaryTo(bytes.NewBuffer(buf[:0]))
  248. if err != nil {
  249. return buf[:n], err
  250. }
  251. p := headerSize
  252. for i := 0; i < v.mat.N; i++ {
  253. binary.LittleEndian.PutUint64(buf[p:p+sizeFloat64], math.Float64bits(v.at(i)))
  254. p += sizeFloat64
  255. }
  256. return buf, nil
  257. }
  258. // MarshalBinaryTo encodes the receiver into a binary form, writes it to w and
  259. // returns the number of bytes written and an error if any.
  260. //
  261. // See MarshalBainry for the on-disk format.
  262. func (v VecDense) MarshalBinaryTo(w io.Writer) (int, error) {
  263. header := storage{
  264. Form: 'G', Packing: 'F', Uplo: 'A',
  265. Rows: int64(v.mat.N), Cols: 1,
  266. Version: version,
  267. }
  268. n, err := header.marshalBinaryTo(w)
  269. if err != nil {
  270. return n, err
  271. }
  272. var buf [8]byte
  273. for i := 0; i < v.mat.N; i++ {
  274. binary.LittleEndian.PutUint64(buf[:], math.Float64bits(v.at(i)))
  275. nn, err := w.Write(buf[:])
  276. n += nn
  277. if err != nil {
  278. return n, err
  279. }
  280. }
  281. return n, nil
  282. }
  283. // UnmarshalBinary decodes the binary form into the receiver.
  284. // It panics if the receiver is a non-empty VecDense.
  285. //
  286. // See MarshalBinary for the on-disk layout.
  287. //
  288. // Limited checks on the validity of the binary input are performed:
  289. // - matrix.ErrShape is returned if the number of rows is negative,
  290. // - an error is returned if the resulting VecDense is too
  291. // big for the current architecture (e.g. a 16GB vector written by a
  292. // 64b application and read back from a 32b application.)
  293. // UnmarshalBinary does not limit the size of the unmarshaled vector, and so
  294. // it should not be used on untrusted data.
  295. func (v *VecDense) UnmarshalBinary(data []byte) error {
  296. if !v.IsEmpty() {
  297. panic("mat: unmarshal into non-empty vector")
  298. }
  299. if len(data) < headerSize {
  300. return errTooSmall
  301. }
  302. var header storage
  303. err := header.unmarshalBinary(data[:headerSize])
  304. if err != nil {
  305. return err
  306. }
  307. if header.Cols != 1 {
  308. return ErrShape
  309. }
  310. n := header.Rows
  311. header.Version = 0
  312. header.Rows = 0
  313. header.Cols = 0
  314. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  315. return errWrongType
  316. }
  317. if n == 0 {
  318. return ErrZeroLength
  319. }
  320. if n < 0 {
  321. return errBadSize
  322. }
  323. if int64(maxLen) < n {
  324. return errTooBig
  325. }
  326. if len(data) != headerSize+int(n)*sizeFloat64 {
  327. return errBadBuffer
  328. }
  329. p := headerSize
  330. v.reuseAsNonZeroed(int(n))
  331. for i := range v.mat.Data {
  332. v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[p : p+sizeFloat64]))
  333. p += sizeFloat64
  334. }
  335. return nil
  336. }
  337. // UnmarshalBinaryFrom decodes the binary form into the receiver, from the
  338. // io.Reader and returns the number of bytes read and an error if any.
  339. // It panics if the receiver is a non-empty VecDense.
  340. //
  341. // See MarshalBinary for the on-disk layout.
  342. // See UnmarshalBinary for the list of sanity checks performed on the input.
  343. func (v *VecDense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
  344. if !v.IsEmpty() {
  345. panic("mat: unmarshal into non-empty vector")
  346. }
  347. var header storage
  348. n, err := header.unmarshalBinaryFrom(r)
  349. if err != nil {
  350. return n, err
  351. }
  352. if header.Cols != 1 {
  353. return n, ErrShape
  354. }
  355. l := header.Rows
  356. header.Version = 0
  357. header.Rows = 0
  358. header.Cols = 0
  359. if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
  360. return n, errWrongType
  361. }
  362. if l == 0 {
  363. return n, ErrZeroLength
  364. }
  365. if l < 0 {
  366. return n, errBadSize
  367. }
  368. if int64(maxLen) < l {
  369. return n, errTooBig
  370. }
  371. v.reuseAsNonZeroed(int(l))
  372. var b [8]byte
  373. for i := range v.mat.Data {
  374. nn, err := readFull(r, b[:])
  375. n += nn
  376. if err != nil {
  377. if err == io.EOF {
  378. return n, io.ErrUnexpectedEOF
  379. }
  380. return n, err
  381. }
  382. v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(b[:]))
  383. }
  384. return n, nil
  385. }
  386. // storage is the internal representation of the storage format of a
  387. // serialised matrix.
  388. type storage struct {
  389. Version uint32 // Keep this first.
  390. Form byte // [GST]
  391. Packing byte // [BPF]
  392. Uplo byte // [AUL]
  393. Unit bool
  394. Rows int64
  395. Cols int64
  396. KU int64
  397. KL int64
  398. }
  399. // TODO(kortschak): Consider replacing these with calls to direct
  400. // encoding/decoding of fields rather than to binary.Write/binary.Read.
  401. func (s storage) marshalBinaryTo(w io.Writer) (int, error) {
  402. buf := bytes.NewBuffer(make([]byte, 0, headerSize))
  403. err := binary.Write(buf, binary.LittleEndian, s)
  404. if err != nil {
  405. return 0, err
  406. }
  407. return w.Write(buf.Bytes())
  408. }
  409. func (s *storage) unmarshalBinary(buf []byte) error {
  410. err := binary.Read(bytes.NewReader(buf), binary.LittleEndian, s)
  411. if err != nil {
  412. return err
  413. }
  414. if s.Version != version {
  415. return fmt.Errorf("mat: incorrect version: %d", s.Version)
  416. }
  417. return nil
  418. }
  419. func (s *storage) unmarshalBinaryFrom(r io.Reader) (int, error) {
  420. buf := make([]byte, headerSize)
  421. n, err := readFull(r, buf)
  422. if err != nil {
  423. return n, err
  424. }
  425. return n, s.unmarshalBinary(buf[:n])
  426. }
  427. // readFull reads from r into buf until it has read len(buf).
  428. // It returns the number of bytes copied and an error if fewer bytes were read.
  429. // If an EOF happens after reading fewer than len(buf) bytes, io.ErrUnexpectedEOF is returned.
  430. func readFull(r io.Reader, buf []byte) (int, error) {
  431. var n int
  432. var err error
  433. for n < len(buf) && err == nil {
  434. var nn int
  435. nn, err = r.Read(buf[n:])
  436. n += nn
  437. }
  438. if n == len(buf) {
  439. return n, nil
  440. }
  441. if err == io.EOF {
  442. return n, io.ErrUnexpectedEOF
  443. }
  444. return n, err
  445. }