1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.collections4.bag;
18
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;

import org.apache.commons.collections4.Bag;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.set.UnmodifiableSet;
34
35
36
37
38
39
40
41
42
43
44
45
46
47 public abstract class AbstractMapBag<E> implements Bag<E> {
48
49
50
51
52 static class BagIterator<E> implements Iterator<E> {
53 private final AbstractMapBag<E> parent;
54 private final Iterator<Map.Entry<E, MutableInteger>> entryIterator;
55 private Map.Entry<E, MutableInteger> current;
56 private int itemCount;
57 private final int mods;
58 private boolean canRemove;
59
60
61
62
63
64
65 BagIterator(final AbstractMapBag<E> parent) {
66 this.parent = parent;
67 this.entryIterator = parent.map.entrySet().iterator();
68 this.current = null;
69 this.mods = parent.modCount;
70 this.canRemove = false;
71 }
72
73
74 @Override
75 public boolean hasNext() {
76 return itemCount > 0 || entryIterator.hasNext();
77 }
78
79
80 @Override
81 public E next() {
82 if (parent.modCount != mods) {
83 throw new ConcurrentModificationException();
84 }
85 if (itemCount == 0) {
86 current = entryIterator.next();
87 itemCount = current.getValue().value;
88 }
89 canRemove = true;
90 itemCount--;
91 return current.getKey();
92 }
93
94
95 @Override
96 public void remove() {
97 if (parent.modCount != mods) {
98 throw new ConcurrentModificationException();
99 }
100 if (!canRemove) {
101 throw new IllegalStateException();
102 }
103 final MutableInteger mut = current.getValue();
104 if (mut.value > 1) {
105 mut.value--;
106 } else {
107 entryIterator.remove();
108 }
109 parent.size--;
110 canRemove = false;
111 }
112 }
113
114
115
116
117 protected static class MutableInteger {
118
119
120 protected int value;
121
122
123
124
125
126 MutableInteger(final int value) {
127 this.value = value;
128 }
129
130 @Override
131 public boolean equals(final Object obj) {
132 if (!(obj instanceof MutableInteger)) {
133 return false;
134 }
135 return ((MutableInteger) obj).value == value;
136 }
137
138 @Override
139 public int hashCode() {
140 return value;
141 }
142 }
143
144
145 private transient Map<E, MutableInteger> map;
146
147
148 private int size;
149
150
151 private transient int modCount;
152
153
154 private transient Set<E> uniqueSet;
155
156
157
158
159 protected AbstractMapBag() {
160 }
161
162
163
164
165
166
167
168 protected AbstractMapBag(final Map<E, MutableInteger> map) {
169 this.map = Objects.requireNonNull(map, "map");
170 }
171
172
173
174
175
176
177
178
179 protected AbstractMapBag(final Map<E, MutableInteger> map, final Iterable<? extends E> iterable) {
180 this(map);
181 iterable.forEach(this::add);
182 }
183
184
185
186
187
188
189
190 @Override
191 public boolean add(final E object) {
192 return add(object, 1);
193 }
194
195
196
197
198
199
200
201
202 @Override
203 public boolean add(final E object, final int nCopies) {
204 modCount++;
205 if (nCopies > 0) {
206 final MutableInteger mut = map.get(object);
207 size += nCopies;
208 if (mut == null) {
209 map.put(object, new MutableInteger(nCopies));
210 return true;
211 }
212 mut.value += nCopies;
213 }
214 return false;
215 }
216
217
218
219
220
221
222
223 @Override
224 public boolean addAll(final Collection<? extends E> coll) {
225 boolean changed = false;
226 for (final E current : coll) {
227 final boolean added = add(current);
228 changed = changed || added;
229 }
230 return changed;
231 }
232
233
234
235
236 @Override
237 public void clear() {
238 modCount++;
239 map.clear();
240 size = 0;
241 }
242
243
244
245
246
247
248
249
250 @Override
251 public boolean contains(final Object object) {
252 return map.containsKey(object);
253 }
254
255
256
257
258
259
260
261
262 boolean containsAll(final Bag<?> other) {
263 for (final Object current : other.uniqueSet()) {
264 if (getCount(current) < other.getCount(current)) {
265 return false;
266 }
267 }
268 return true;
269 }
270
271
272
273
274
275
276
277 @Override
278 public boolean containsAll(final Collection<?> coll) {
279 if (coll instanceof Bag) {
280 return containsAll((Bag<?>) coll);
281 }
282 return containsAll(new HashBag<>(coll));
283 }
284
285
286
287
288
289
290
291
292
293 protected void doReadObject(final Map<E, MutableInteger> map, final ObjectInputStream in)
294 throws IOException, ClassNotFoundException {
295 this.map = map;
296 final int entrySize = in.readInt();
297 for (int i = 0; i < entrySize; i++) {
298 @SuppressWarnings("unchecked")
299 final E obj = (E) in.readObject();
300 final int count = in.readInt();
301 map.put(obj, new MutableInteger(count));
302 size += count;
303 }
304 }
305
306
307
308
309
310
311 protected void doWriteObject(final ObjectOutputStream out) throws IOException {
312 out.writeInt(map.size());
313 for (final Entry<E, MutableInteger> entry : map.entrySet()) {
314 out.writeObject(entry.getKey());
315 out.writeInt(entry.getValue().value);
316 }
317 }
318
319
320
321
322
323
324
325
326 @Override
327 public boolean equals(final Object object) {
328 if (object == this) {
329 return true;
330 }
331 if (!(object instanceof Bag)) {
332 return false;
333 }
334 final Bag<?> other = (Bag<?>) object;
335 if (other.size() != size()) {
336 return false;
337 }
338 for (final E element : map.keySet()) {
339 if (other.getCount(element) != getCount(element)) {
340 return false;
341 }
342 }
343 return true;
344 }
345
346
347
348
349
350
351
352
353 @Override
354 public int getCount(final Object object) {
355 final MutableInteger count = map.get(object);
356 if (count != null) {
357 return count.value;
358 }
359 return 0;
360 }
361
362
363
364
365
366
367
368 protected Map<E, MutableInteger> getMap() {
369 return map;
370 }
371
372
373
374
375
376
377
378
379
380
381 @Override
382 public int hashCode() {
383 int total = 0;
384 for (final Entry<E, MutableInteger> entry : map.entrySet()) {
385 final E element = entry.getKey();
386 final MutableInteger count = entry.getValue();
387 total += (element == null ? 0 : element.hashCode()) ^ count.value;
388 }
389 return total;
390 }
391
392
393
394
395
396
397 @Override
398 public boolean isEmpty() {
399 return map.isEmpty();
400 }
401
402
403
404
405
406
407
408 @Override
409 public Iterator<E> iterator() {
410 return new BagIterator<>(this);
411 }
412
413
414
415
416
417
418
419 @Override
420 public boolean remove(final Object object) {
421 final MutableInteger mut = map.get(object);
422 if (mut == null) {
423 return false;
424 }
425 modCount++;
426 map.remove(object);
427 size -= mut.value;
428 return true;
429 }
430
431
432
433
434
435
436
437
438 @Override
439 public boolean remove(final Object object, final int nCopies) {
440 final MutableInteger mut = map.get(object);
441 if (mut == null) {
442 return false;
443 }
444 if (nCopies <= 0) {
445 return false;
446 }
447 modCount++;
448 if (nCopies < mut.value) {
449 mut.value -= nCopies;
450 size -= nCopies;
451 } else {
452 map.remove(object);
453 size -= mut.value;
454 }
455 return true;
456 }
457
458
459
460
461
462
463
464
465 @Override
466 public boolean removeAll(final Collection<?> coll) {
467 boolean result = false;
468 if (coll != null) {
469 for (final Object current : coll) {
470 final boolean changed = remove(current, 1);
471 result = result || changed;
472 }
473 }
474 return result;
475 }
476
477
478
479
480
481
482
483
484 boolean retainAll(final Bag<?> other) {
485 boolean result = false;
486 final Bag<E> excess = new HashBag<>();
487 for (final E current : uniqueSet()) {
488 final int myCount = getCount(current);
489 final int otherCount = other.getCount(current);
490 if (1 <= otherCount && otherCount <= myCount) {
491 excess.add(current, myCount - otherCount);
492 } else {
493 excess.add(current, myCount);
494 }
495 }
496 if (!excess.isEmpty()) {
497 result = removeAll(excess);
498 }
499 return result;
500 }
501
502
503
504
505
506
507
508
509 @Override
510 public boolean retainAll(final Collection<?> coll) {
511 if (coll instanceof Bag) {
512 return retainAll((Bag<?>) coll);
513 }
514 return retainAll(new HashBag<>(coll));
515 }
516
517
518
519
520
521
522 @Override
523 public int size() {
524 return size;
525 }
526
527
528
529
530
531
532 @Override
533 public Object[] toArray() {
534 final Object[] result = new Object[size()];
535 int i = 0;
536 for (final E current : map.keySet()) {
537 for (int index = getCount(current); index > 0; index--) {
538 result[i++] = current;
539 }
540 }
541 return result;
542 }
543
544
545
546
547
548
549
550
551
552
553
554
555
556 @Override
557 public <T> T[] toArray(T[] array) {
558 final int size = size();
559 if (array.length < size) {
560 @SuppressWarnings("unchecked")
561 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), size);
562 array = unchecked;
563 }
564
565 int i = 0;
566 for (final E current : map.keySet()) {
567 for (int index = getCount(current); index > 0; index--) {
568
569 @SuppressWarnings("unchecked")
570 final T unchecked = (T) current;
571 array[i++] = unchecked;
572 }
573 }
574 while (i < array.length) {
575 array[i++] = null;
576 }
577 return array;
578 }
579
580
581
582
583
584
585 @Override
586 public String toString() {
587 if (isEmpty()) {
588 return "[]";
589 }
590 final StringBuilder buf = new StringBuilder();
591 buf.append(CollectionUtils.DEFAULT_TOSTRING_PREFIX);
592 final Iterator<E> it = uniqueSet().iterator();
593 while (it.hasNext()) {
594 final Object current = it.next();
595 final int count = getCount(current);
596 buf.append(count);
597 buf.append(CollectionUtils.COLON);
598 buf.append(current);
599 if (it.hasNext()) {
600 buf.append(CollectionUtils.COMMA);
601 }
602 }
603 buf.append(CollectionUtils.DEFAULT_TOSTRING_SUFFIX);
604 return buf.toString();
605 }
606
607
608
609
610
611
612 @Override
613 public Set<E> uniqueSet() {
614 if (uniqueSet == null) {
615 uniqueSet = UnmodifiableSet.<E>unmodifiableSet(map.keySet());
616 }
617 return uniqueSet;
618 }
619
620 }