encoder.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. rdebug "runtime/debug"
  10. "sync"
  11. "github.com/klauspost/compress/zstd/internal/xxhash"
  12. )
  // Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
//
// NOTE(review): the streaming methods (Write, ReadFrom, Flush, Close, Reset)
// share e.state without a mutex, so they presumably must not be called
// concurrently with each other; EncodeAll uses its own pooled encoders.
type Encoder struct {
	// o holds the options applied by NewWriter.
	o encoderOptions
	// encoders is the pool borrowed by EncodeAll: a buffered channel filled
	// with o.concurrent encoders by initialize.
	encoders chan encoder
	// state is the streaming state used by Write, ReadFrom, Flush and Close.
	state encoderState
	// init guards the one-time call to initialize.
	init sync.Once
}
  25. type encoder interface {
  26. Encode(blk *blockEnc, src []byte)
  27. EncodeNoHist(blk *blockEnc, src []byte)
  28. Block() *blockEnc
  29. CRC() *xxhash.Digest
  30. AppendCRC([]byte) []byte
  31. WindowSize(size int) int32
  32. UseBlock(*blockEnc)
  33. Reset()
  34. }
  // encoderState is the per-stream state used by Write, ReadFrom, Flush and Close.
// Blocks flow through two goroutine stages: one encodes a block (tracked by
// wg) and then starts another that finishes the block into its output and
// writes it to w (tracked by wWg).
type encoderState struct {
	// w receives the compressed stream; set by Reset.
	w io.Writer
	// filling accumulates input until it holds a full block.
	filling []byte
	// current is the buffer most recently handed to the encode goroutine.
	current []byte
	// previous is the buffer before current; it is recycled as the next
	// filling buffer. NOTE(review): presumably kept intact because the
	// encoder's match history may still refer to it — confirm.
	previous []byte
	// encoder is the single encoder used for this stream.
	encoder encoder
	// writing is the block most recently handed to the write goroutine.
	writing *blockEnc
	// err is the sticky stream error; nextBlock returns it once set.
	err error
	// writeErr records failures from the write goroutine; the next encode
	// copies it into err, and Flush/Close report it.
	writeErr error
	// nWritten counts bytes written to w; Close uses it to size padding.
	nWritten int64
	// headerWritten is set once the frame header has been written.
	headerWritten bool
	// eofWritten is set once a block marked as last has been produced, so the
	// last block is only emitted once.
	eofWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
  52. // NewWriter will create a new Zstandard encoder.
  53. // If the encoder will be used for encoding blocks a nil writer can be used.
  54. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  55. initPredefined()
  56. var e Encoder
  57. e.o.setDefault()
  58. for _, o := range opts {
  59. err := o(&e.o)
  60. if err != nil {
  61. return nil, err
  62. }
  63. }
  64. if w != nil {
  65. e.Reset(w)
  66. } else {
  67. e.init.Do(func() {
  68. e.initialize()
  69. })
  70. }
  71. return &e, nil
  72. }
  73. func (e *Encoder) initialize() {
  74. e.encoders = make(chan encoder, e.o.concurrent)
  75. for i := 0; i < e.o.concurrent; i++ {
  76. e.encoders <- e.o.encoder()
  77. }
  78. }
  79. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  80. // as a new, independent stream.
  81. func (e *Encoder) Reset(w io.Writer) {
  82. e.init.Do(func() {
  83. e.initialize()
  84. })
  85. s := &e.state
  86. s.wg.Wait()
  87. s.wWg.Wait()
  88. if cap(s.filling) == 0 {
  89. s.filling = make([]byte, 0, e.o.blockSize)
  90. }
  91. if cap(s.current) == 0 {
  92. s.current = make([]byte, 0, e.o.blockSize)
  93. }
  94. if cap(s.previous) == 0 {
  95. s.previous = make([]byte, 0, e.o.blockSize)
  96. }
  97. if s.encoder == nil {
  98. s.encoder = e.o.encoder()
  99. }
  100. if s.writing == nil {
  101. s.writing = &blockEnc{}
  102. s.writing.init()
  103. }
  104. s.writing.initNewEncode()
  105. s.filling = s.filling[:0]
  106. s.current = s.current[:0]
  107. s.previous = s.previous[:0]
  108. s.encoder.Reset()
  109. s.headerWritten = false
  110. s.eofWritten = false
  111. s.w = w
  112. s.err = nil
  113. s.nWritten = 0
  114. s.writeErr = nil
  115. }
  116. // Write data to the encoder.
  117. // Input data will be buffered and as the buffer fills up
  118. // content will be compressed and written to the output.
  119. // When done writing, use Close to flush the remaining output
  120. // and write CRC if requested.
  121. func (e *Encoder) Write(p []byte) (n int, err error) {
  122. s := &e.state
  123. for len(p) > 0 {
  124. if len(p)+len(s.filling) < e.o.blockSize {
  125. if e.o.crc {
  126. _, _ = s.encoder.CRC().Write(p)
  127. }
  128. s.filling = append(s.filling, p...)
  129. return n + len(p), nil
  130. }
  131. add := p
  132. if len(p)+len(s.filling) > e.o.blockSize {
  133. add = add[:e.o.blockSize-len(s.filling)]
  134. }
  135. if e.o.crc {
  136. _, _ = s.encoder.CRC().Write(add)
  137. }
  138. s.filling = append(s.filling, add...)
  139. p = p[len(add):]
  140. n += len(add)
  141. if len(s.filling) < e.o.blockSize {
  142. return n, nil
  143. }
  144. err := e.nextBlock(false)
  145. if err != nil {
  146. return n, err
  147. }
  148. if debug && len(s.filling) > 0 {
  149. panic(len(s.filling))
  150. }
  151. }
  152. return n, nil
  153. }
  // nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// It waits for the previous encode to finish, writes the frame header on first
// use, rotates the filling/current/previous buffers and starts a goroutine
// that encodes s.current. That goroutine waits for the previous write, then
// starts a second goroutine that finishes the block into blk.output and
// writes it to s.w. nextBlock returns as soon as the encode goroutine has been
// started, so failures of this block are reported by a later call (or by
// Flush/Close).
//
// When final is true the block is marked as the last block of the frame; with
// no buffered data an empty raw last block is written synchronously instead.
// The last block is only emitted once per stream (see eofWritten).
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	// NOTE(review): the message mentions maxStoreBlockSize, but the check is
	// against the configured e.o.blockSize.
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// Write the frame header synchronously before the first block.
		// ContentSize is left at 0 here (EncodeAll, by contrast, records len(src)).
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   0,
			WindowSize:    uint32(s.encoder.WindowSize(0)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        0,
		}
		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		// No block write may be in flight while we write to s.w.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		// NOTE(review): unlike the encode goroutine below, this path does not
		// check s.writeErr before writing; Close still reports it afterwards.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	// The old previous buffer becomes the new (empty) filling buffer, filling
	// becomes current (to be encoded) and the old current becomes previous.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.wg.Add(1)
	go func(src []byte) {
		if debug {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		// Turn panics into a stream error instead of crashing the process.
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			// (Only a full block with all-literal output skips blk.encode.)
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(e.o.noEntropy)
			}
			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to
				// deal with changing them. NOTE(review): the original comment
				// was truncated here; confirm intent.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
  272. // ReadFrom reads data from r until EOF or error.
  273. // The return value n is the number of bytes read.
  274. // Any error except io.EOF encountered during the read is also returned.
  275. //
  276. // The Copy function uses ReaderFrom if available.
  277. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  278. if debug {
  279. println("Using ReadFrom")
  280. }
  281. // Maybe handle stuff queued?
  282. e.state.filling = e.state.filling[:e.o.blockSize]
  283. src := e.state.filling
  284. for {
  285. n2, err := r.Read(src)
  286. _, _ = e.state.encoder.CRC().Write(src[:n2])
  287. // src is now the unfilled part...
  288. src = src[n2:]
  289. n += int64(n2)
  290. switch err {
  291. case io.EOF:
  292. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  293. if debug {
  294. println("ReadFrom: got EOF final block:", len(e.state.filling))
  295. }
  296. return n, e.nextBlock(true)
  297. default:
  298. if debug {
  299. println("ReadFrom: got error:", err)
  300. }
  301. e.state.err = err
  302. return n, err
  303. case nil:
  304. }
  305. if len(src) > 0 {
  306. if debug {
  307. println("ReadFrom: got space left in source:", len(src))
  308. }
  309. continue
  310. }
  311. err = e.nextBlock(false)
  312. if err != nil {
  313. return n, err
  314. }
  315. e.state.filling = e.state.filling[:e.o.blockSize]
  316. src = e.state.filling
  317. }
  318. }
  319. // Flush will send the currently written data to output
  320. // and block until everything has been written.
  321. // This should only be used on rare occasions where pushing the currently queued data is critical.
  322. func (e *Encoder) Flush() error {
  323. s := &e.state
  324. if len(s.filling) > 0 {
  325. err := e.nextBlock(false)
  326. if err != nil {
  327. return err
  328. }
  329. }
  330. s.wg.Wait()
  331. s.wWg.Wait()
  332. if s.err != nil {
  333. return s.err
  334. }
  335. return s.writeErr
  336. }
  337. // Close will flush the final output and close the stream.
  338. // The function will block until everything has been written.
  339. // The Encoder can still be re-used after calling this.
  340. func (e *Encoder) Close() error {
  341. s := &e.state
  342. if s.encoder == nil {
  343. return nil
  344. }
  345. err := e.nextBlock(true)
  346. if err != nil {
  347. return err
  348. }
  349. s.wg.Wait()
  350. s.wWg.Wait()
  351. if s.err != nil {
  352. return s.err
  353. }
  354. if s.writeErr != nil {
  355. return s.writeErr
  356. }
  357. // Write CRC
  358. if e.o.crc && s.err == nil {
  359. // heap alloc.
  360. var tmp [4]byte
  361. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  362. s.nWritten += 4
  363. }
  364. // Add padding with content from crypto/rand.Reader
  365. if s.err == nil && e.o.pad > 0 {
  366. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  367. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  368. if err != nil {
  369. return err
  370. }
  371. _, s.err = s.w.Write(frame)
  372. }
  373. return s.err
  374. }
  // EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If src is empty, dst is returned unchanged unless WithZeroFrames is specified,
// in which case an empty frame is appended.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
//
// Each call borrows one encoder from the pool for its duration and appends a
// complete frame: a header recording len(src) as the content size, the
// blocks, and optionally the checksum and a padding frame.
// Internal encoding errors cause a panic.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	// NewWriter always runs e.init.Do itself, so this closure only runs for an
	// Encoder not created by NewWriter (e.g. a zero value); applying defaults
	// here therefore does not override user options.
	e.init.Do(func() {
		e.o.setDefault()
		e.initialize()
	})
	// Borrow an encoder; this blocks while all pooled encoders are in use.
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		enc.Reset()
		e.encoders <- enc
	}()
	enc.Reset()
	blk := enc.Block()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(len(src))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        0,
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	if len(src) <= e.o.blockSize && len(src) <= maxBlockSize {
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk.reset(nil)
		blk.last = true
		enc.EncodeNoHist(blk, src)

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			// (blk.output aliases dst; the original buffer is restored below).
			blk.output = dst
			err = blk.encode(e.o.noEntropy)
		}

		switch err {
		case errIncompressible:
			if debug {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		// Split src into blocks of at most e.o.blockSize, keeping history.
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.reset(nil)
			// Save offsets so they can be restored with popOffsets if the block
			// is stored raw. NOTE(review): presumably because a raw block does
			// not update the repeat offsets — confirm in blockEnc.
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(e.o.noEntropy)
			}

			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}