diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e5c07875150ff25a502146a8f1cc5dcf751bc2d6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,47 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +Examples/Example3/11074.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10986.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10989.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10990.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10992.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10993.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10994.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10995.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10996.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10997.png filter=lfs diff=lfs merge=lfs -text +Examples/Example3/ISPP/10998.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/11074.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10986.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10989.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10990.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10992.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10993.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10994.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10995.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10996.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10997.png filter=lfs diff=lfs merge=lfs -text +Examples/Example4/ISPP/10998.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/11074.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10986.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10989.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10990.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10992.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10993.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10994.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10995.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10996.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10997.png filter=lfs diff=lfs merge=lfs -text +Examples/Example5/ISPP/10998.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/11074.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10986.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10989.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10990.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10992.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10993.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10994.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10995.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10996.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10997.png filter=lfs diff=lfs merge=lfs -text +Examples/Example6/ISPP/10998.png filter=lfs diff=lfs merge=lfs -text diff --git a/Examples/Example1/ISPP/1600.AWGN.1.png b/Examples/Example1/ISPP/1600.AWGN.1.png new file mode 
100644 index 0000000000000000000000000000000000000000..563106c2361a4cf608f469f9e08b7d5dd55926e8 Binary files /dev/null and b/Examples/Example1/ISPP/1600.AWGN.1.png differ diff --git a/Examples/Example1/ISPP/1600.AWGN.2.png b/Examples/Example1/ISPP/1600.AWGN.2.png new file mode 100644 index 0000000000000000000000000000000000000000..a55a0953b3791b24914a86d6bd601963734c6748 Binary files /dev/null and b/Examples/Example1/ISPP/1600.AWGN.2.png differ diff --git a/Examples/Example1/ISPP/1600.AWGN.3.png b/Examples/Example1/ISPP/1600.AWGN.3.png new file mode 100644 index 0000000000000000000000000000000000000000..5eb1f33395c9bb2fd515a3a6e26dce600b30738b Binary files /dev/null and b/Examples/Example1/ISPP/1600.AWGN.3.png differ diff --git a/Examples/Example1/ISPP/1600.AWGN.4.png b/Examples/Example1/ISPP/1600.AWGN.4.png new file mode 100644 index 0000000000000000000000000000000000000000..cf9fbbf20b7ea0c183eadec824b749c66e897ec8 Binary files /dev/null and b/Examples/Example1/ISPP/1600.AWGN.4.png differ diff --git a/Examples/Example1/ISPP/1600.AWGN.5.png b/Examples/Example1/ISPP/1600.AWGN.5.png new file mode 100644 index 0000000000000000000000000000000000000000..58a397e65ab1bb7acf30dde9d7bbe62abf57cf5b Binary files /dev/null and b/Examples/Example1/ISPP/1600.AWGN.5.png differ diff --git a/Examples/Example1/ISPP/1600.BLUR.1.png b/Examples/Example1/ISPP/1600.BLUR.1.png new file mode 100644 index 0000000000000000000000000000000000000000..cab65b4700b944dce245694a2aa8562f7e76b543 Binary files /dev/null and b/Examples/Example1/ISPP/1600.BLUR.1.png differ diff --git a/Examples/Example1/ISPP/1600.BLUR.2.png b/Examples/Example1/ISPP/1600.BLUR.2.png new file mode 100644 index 0000000000000000000000000000000000000000..8e0ee2352905db582a019b35106a816146f76710 Binary files /dev/null and b/Examples/Example1/ISPP/1600.BLUR.2.png differ diff --git a/Examples/Example1/ISPP/1600.BLUR.3.png b/Examples/Example1/ISPP/1600.BLUR.3.png new file mode 100644 index 0000000000000000000000000000000000000000..92aebca28f88f5f654341f45defad5052d19748a Binary files /dev/null and b/Examples/Example1/ISPP/1600.BLUR.3.png differ diff --git a/Examples/Example1/ISPP/1600.BLUR.4.png b/Examples/Example1/ISPP/1600.BLUR.4.png new file mode 100644 index 0000000000000000000000000000000000000000..3bbf90fd9199db0a703509605624a4361710e6a7 Binary files /dev/null and b/Examples/Example1/ISPP/1600.BLUR.4.png differ diff --git a/Examples/Example1/ISPP/1600.BLUR.5.png b/Examples/Example1/ISPP/1600.BLUR.5.png new file mode 100644 index 0000000000000000000000000000000000000000..3ce305f8cfb3016205aa0ac7adc716094bda690b Binary files /dev/null and b/Examples/Example1/ISPP/1600.BLUR.5.png differ diff --git a/Examples/Example1/cactus.png b/Examples/Example1/cactus.png new file mode 100644 index 0000000000000000000000000000000000000000..786ba9bd13ffb48c04925b541f9fe03f03c0ec7a Binary files /dev/null and b/Examples/Example1/cactus.png differ diff --git a/Examples/Example2/ISPP/1600.AWGN.1.png b/Examples/Example2/ISPP/1600.AWGN.1.png new file mode 100644 index 0000000000000000000000000000000000000000..563106c2361a4cf608f469f9e08b7d5dd55926e8 Binary files /dev/null and b/Examples/Example2/ISPP/1600.AWGN.1.png differ diff --git a/Examples/Example2/ISPP/1600.AWGN.2.png b/Examples/Example2/ISPP/1600.AWGN.2.png new file mode 100644 index 0000000000000000000000000000000000000000..a55a0953b3791b24914a86d6bd601963734c6748 Binary files /dev/null and b/Examples/Example2/ISPP/1600.AWGN.2.png differ diff --git a/Examples/Example2/ISPP/1600.AWGN.3.png 
b/Examples/Example2/ISPP/1600.AWGN.3.png new file mode 100644 index 0000000000000000000000000000000000000000..5eb1f33395c9bb2fd515a3a6e26dce600b30738b Binary files /dev/null and b/Examples/Example2/ISPP/1600.AWGN.3.png differ diff --git a/Examples/Example2/ISPP/1600.AWGN.4.png b/Examples/Example2/ISPP/1600.AWGN.4.png new file mode 100644 index 0000000000000000000000000000000000000000..cf9fbbf20b7ea0c183eadec824b749c66e897ec8 Binary files /dev/null and b/Examples/Example2/ISPP/1600.AWGN.4.png differ diff --git a/Examples/Example2/ISPP/1600.AWGN.5.png b/Examples/Example2/ISPP/1600.AWGN.5.png new file mode 100644 index 0000000000000000000000000000000000000000..58a397e65ab1bb7acf30dde9d7bbe62abf57cf5b Binary files /dev/null and b/Examples/Example2/ISPP/1600.AWGN.5.png differ diff --git a/Examples/Example2/ISPP/1600.BLUR.1.png b/Examples/Example2/ISPP/1600.BLUR.1.png new file mode 100644 index 0000000000000000000000000000000000000000..cab65b4700b944dce245694a2aa8562f7e76b543 Binary files /dev/null and b/Examples/Example2/ISPP/1600.BLUR.1.png differ diff --git a/Examples/Example2/ISPP/1600.BLUR.2.png b/Examples/Example2/ISPP/1600.BLUR.2.png new file mode 100644 index 0000000000000000000000000000000000000000..8e0ee2352905db582a019b35106a816146f76710 Binary files /dev/null and b/Examples/Example2/ISPP/1600.BLUR.2.png differ diff --git a/Examples/Example2/ISPP/1600.BLUR.3.png b/Examples/Example2/ISPP/1600.BLUR.3.png new file mode 100644 index 0000000000000000000000000000000000000000..92aebca28f88f5f654341f45defad5052d19748a Binary files /dev/null and b/Examples/Example2/ISPP/1600.BLUR.3.png differ diff --git a/Examples/Example2/ISPP/1600.BLUR.4.png b/Examples/Example2/ISPP/1600.BLUR.4.png new file mode 100644 index 0000000000000000000000000000000000000000..3bbf90fd9199db0a703509605624a4361710e6a7 Binary files /dev/null and b/Examples/Example2/ISPP/1600.BLUR.4.png differ diff --git a/Examples/Example2/ISPP/1600.BLUR.5.png b/Examples/Example2/ISPP/1600.BLUR.5.png new file mode 100644 index 0000000000000000000000000000000000000000..3ce305f8cfb3016205aa0ac7adc716094bda690b Binary files /dev/null and b/Examples/Example2/ISPP/1600.BLUR.5.png differ diff --git a/Examples/Example2/cactus.png b/Examples/Example2/cactus.png new file mode 100644 index 0000000000000000000000000000000000000000..786ba9bd13ffb48c04925b541f9fe03f03c0ec7a Binary files /dev/null and b/Examples/Example2/cactus.png differ diff --git a/Examples/Example3/11074.png b/Examples/Example3/11074.png new file mode 100644 index 0000000000000000000000000000000000000000..49c28b4aadb30844863a450ee105912514f462af --- /dev/null +++ b/Examples/Example3/11074.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cfb1bd5866fd1e5dccedd491dcfe0f4cb6d1a45684eab7f164e1f722d0f7aac +size 3209086 diff --git a/Examples/Example3/ISPP/10986.png b/Examples/Example3/ISPP/10986.png new file mode 100644 index 0000000000000000000000000000000000000000..4df8b29526755a1dcca8d77120c950dda24c22d8 --- /dev/null +++ b/Examples/Example3/ISPP/10986.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a99a55367e7edf4abc4acc474869a15a8cc0cfc95f20efba776f7627c68244bf +size 3069716 diff --git a/Examples/Example3/ISPP/10989.png b/Examples/Example3/ISPP/10989.png new file mode 100644 index 0000000000000000000000000000000000000000..42a673e06bbf154b2b602e0495fe4fa1ad7da8e4 --- /dev/null +++ b/Examples/Example3/ISPP/10989.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:56b8b2be506f15727f0e0b083bdf841ea953cd79575540805fbe7c37986049bd +size 5464892 diff --git a/Examples/Example3/ISPP/10990.png b/Examples/Example3/ISPP/10990.png new file mode 100644 index 0000000000000000000000000000000000000000..512bbe5fddafe153e006dbba749bbcfc306e867f --- /dev/null +++ b/Examples/Example3/ISPP/10990.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e4a646089f2e092cb75d7f50138be6001ecb9e1e4e4d398564fb9b4243566a1 +size 3719887 diff --git a/Examples/Example3/ISPP/10992.png b/Examples/Example3/ISPP/10992.png new file mode 100644 index 0000000000000000000000000000000000000000..58442c181e88df1f3a86c46583414c2041208fb0 --- /dev/null +++ b/Examples/Example3/ISPP/10992.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fef090ba9f287bd9a9846f4b1089aede83e8a0016e7ece0ae89ce3d1f5b6e07e +size 1974713 diff --git a/Examples/Example3/ISPP/10993.png b/Examples/Example3/ISPP/10993.png new file mode 100644 index 0000000000000000000000000000000000000000..33dc7a3dc63e7e024c39ce4d821339e97ad30d17 --- /dev/null +++ b/Examples/Example3/ISPP/10993.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9e289b2625af82aa614592ee5f94a45fc1634014d9a54e6667bd7c024223a12 +size 6528198 diff --git a/Examples/Example3/ISPP/10994.png b/Examples/Example3/ISPP/10994.png new file mode 100644 index 0000000000000000000000000000000000000000..b075459362a74a717b220671be4afa7d40ae16b3 --- /dev/null +++ b/Examples/Example3/ISPP/10994.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b1e5bd808c0c5ff1b5ef79c53c28437bba632ea7eb5d903e5ecb1f8b04c3b2c +size 4222359 diff --git a/Examples/Example3/ISPP/10995.png b/Examples/Example3/ISPP/10995.png new file mode 100644 index 0000000000000000000000000000000000000000..05be6a1acb06984bee535badbf744691e91aa77d --- /dev/null +++ b/Examples/Example3/ISPP/10995.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1331c1d539b1395aabf0a504ae20bb7c1a40d35a93df4772d9007e50a5696fe5 +size 3529341 diff --git a/Examples/Example3/ISPP/10996.png b/Examples/Example3/ISPP/10996.png new file mode 100644 index 0000000000000000000000000000000000000000..e8202738921a3b15d037c411ade1b1a388326395 --- /dev/null +++ b/Examples/Example3/ISPP/10996.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebabd04872b52332f90a0e614d4aa79ac1c9fdcfec9b51cbaa1519f2cabfd0b4 +size 2348264 diff --git a/Examples/Example3/ISPP/10997.png b/Examples/Example3/ISPP/10997.png new file mode 100644 index 0000000000000000000000000000000000000000..3fc21b37d160d774443aa438e0a80617fb7ce85b --- /dev/null +++ b/Examples/Example3/ISPP/10997.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36c2f89633aafebbe546cb16e1ab1567a07bb9e6e511b59ace9d1b9b35130cdb +size 3097184 diff --git a/Examples/Example3/ISPP/10998.png b/Examples/Example3/ISPP/10998.png new file mode 100644 index 0000000000000000000000000000000000000000..11c74c8bc72e8daeda411b360cfa748df20e19ba --- /dev/null +++ b/Examples/Example3/ISPP/10998.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f06a3bbb4627d3c27857c88feb0f8a35dabfc1fa34f713320a9b7462e5a23f1 +size 3385941 diff --git a/Examples/Example4/11074.png b/Examples/Example4/11074.png new file mode 100644 index 0000000000000000000000000000000000000000..49c28b4aadb30844863a450ee105912514f462af --- /dev/null +++ b/Examples/Example4/11074.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:5cfb1bd5866fd1e5dccedd491dcfe0f4cb6d1a45684eab7f164e1f722d0f7aac +size 3209086 diff --git a/Examples/Example4/ISPP/10986.png b/Examples/Example4/ISPP/10986.png new file mode 100644 index 0000000000000000000000000000000000000000..4df8b29526755a1dcca8d77120c950dda24c22d8 --- /dev/null +++ b/Examples/Example4/ISPP/10986.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a99a55367e7edf4abc4acc474869a15a8cc0cfc95f20efba776f7627c68244bf +size 3069716 diff --git a/Examples/Example4/ISPP/10989.png b/Examples/Example4/ISPP/10989.png new file mode 100644 index 0000000000000000000000000000000000000000..42a673e06bbf154b2b602e0495fe4fa1ad7da8e4 --- /dev/null +++ b/Examples/Example4/ISPP/10989.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56b8b2be506f15727f0e0b083bdf841ea953cd79575540805fbe7c37986049bd +size 5464892 diff --git a/Examples/Example4/ISPP/10990.png b/Examples/Example4/ISPP/10990.png new file mode 100644 index 0000000000000000000000000000000000000000..512bbe5fddafe153e006dbba749bbcfc306e867f --- /dev/null +++ b/Examples/Example4/ISPP/10990.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e4a646089f2e092cb75d7f50138be6001ecb9e1e4e4d398564fb9b4243566a1 +size 3719887 diff --git a/Examples/Example4/ISPP/10992.png b/Examples/Example4/ISPP/10992.png new file mode 100644 index 0000000000000000000000000000000000000000..58442c181e88df1f3a86c46583414c2041208fb0 --- /dev/null +++ b/Examples/Example4/ISPP/10992.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fef090ba9f287bd9a9846f4b1089aede83e8a0016e7ece0ae89ce3d1f5b6e07e +size 1974713 diff --git a/Examples/Example4/ISPP/10993.png b/Examples/Example4/ISPP/10993.png new file mode 100644 index 0000000000000000000000000000000000000000..33dc7a3dc63e7e024c39ce4d821339e97ad30d17 --- /dev/null +++ b/Examples/Example4/ISPP/10993.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9e289b2625af82aa614592ee5f94a45fc1634014d9a54e6667bd7c024223a12 +size 6528198 diff --git a/Examples/Example4/ISPP/10994.png b/Examples/Example4/ISPP/10994.png new file mode 100644 index 0000000000000000000000000000000000000000..b075459362a74a717b220671be4afa7d40ae16b3 --- /dev/null +++ b/Examples/Example4/ISPP/10994.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b1e5bd808c0c5ff1b5ef79c53c28437bba632ea7eb5d903e5ecb1f8b04c3b2c +size 4222359 diff --git a/Examples/Example4/ISPP/10995.png b/Examples/Example4/ISPP/10995.png new file mode 100644 index 0000000000000000000000000000000000000000..05be6a1acb06984bee535badbf744691e91aa77d --- /dev/null +++ b/Examples/Example4/ISPP/10995.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1331c1d539b1395aabf0a504ae20bb7c1a40d35a93df4772d9007e50a5696fe5 +size 3529341 diff --git a/Examples/Example4/ISPP/10996.png b/Examples/Example4/ISPP/10996.png new file mode 100644 index 0000000000000000000000000000000000000000..e8202738921a3b15d037c411ade1b1a388326395 --- /dev/null +++ b/Examples/Example4/ISPP/10996.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebabd04872b52332f90a0e614d4aa79ac1c9fdcfec9b51cbaa1519f2cabfd0b4 +size 2348264 diff --git a/Examples/Example4/ISPP/10997.png b/Examples/Example4/ISPP/10997.png new file mode 100644 index 0000000000000000000000000000000000000000..3fc21b37d160d774443aa438e0a80617fb7ce85b --- /dev/null +++ b/Examples/Example4/ISPP/10997.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:36c2f89633aafebbe546cb16e1ab1567a07bb9e6e511b59ace9d1b9b35130cdb +size 3097184 diff --git a/Examples/Example4/ISPP/10998.png b/Examples/Example4/ISPP/10998.png new file mode 100644 index 0000000000000000000000000000000000000000..11c74c8bc72e8daeda411b360cfa748df20e19ba --- /dev/null +++ b/Examples/Example4/ISPP/10998.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f06a3bbb4627d3c27857c88feb0f8a35dabfc1fa34f713320a9b7462e5a23f1 +size 3385941 diff --git a/Examples/Example5/11074.png b/Examples/Example5/11074.png new file mode 100644 index 0000000000000000000000000000000000000000..49c28b4aadb30844863a450ee105912514f462af --- /dev/null +++ b/Examples/Example5/11074.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cfb1bd5866fd1e5dccedd491dcfe0f4cb6d1a45684eab7f164e1f722d0f7aac +size 3209086 diff --git a/Examples/Example5/ISPP/10986.png b/Examples/Example5/ISPP/10986.png new file mode 100644 index 0000000000000000000000000000000000000000..4df8b29526755a1dcca8d77120c950dda24c22d8 --- /dev/null +++ b/Examples/Example5/ISPP/10986.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a99a55367e7edf4abc4acc474869a15a8cc0cfc95f20efba776f7627c68244bf +size 3069716 diff --git a/Examples/Example5/ISPP/10989.png b/Examples/Example5/ISPP/10989.png new file mode 100644 index 0000000000000000000000000000000000000000..42a673e06bbf154b2b602e0495fe4fa1ad7da8e4 --- /dev/null +++ b/Examples/Example5/ISPP/10989.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56b8b2be506f15727f0e0b083bdf841ea953cd79575540805fbe7c37986049bd +size 5464892 diff --git a/Examples/Example5/ISPP/10990.png b/Examples/Example5/ISPP/10990.png new file mode 100644 index 0000000000000000000000000000000000000000..512bbe5fddafe153e006dbba749bbcfc306e867f --- /dev/null +++ b/Examples/Example5/ISPP/10990.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e4a646089f2e092cb75d7f50138be6001ecb9e1e4e4d398564fb9b4243566a1 +size 3719887 diff --git a/Examples/Example5/ISPP/10992.png b/Examples/Example5/ISPP/10992.png new file mode 100644 index 0000000000000000000000000000000000000000..58442c181e88df1f3a86c46583414c2041208fb0 --- /dev/null +++ b/Examples/Example5/ISPP/10992.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fef090ba9f287bd9a9846f4b1089aede83e8a0016e7ece0ae89ce3d1f5b6e07e +size 1974713 diff --git a/Examples/Example5/ISPP/10993.png b/Examples/Example5/ISPP/10993.png new file mode 100644 index 0000000000000000000000000000000000000000..33dc7a3dc63e7e024c39ce4d821339e97ad30d17 --- /dev/null +++ b/Examples/Example5/ISPP/10993.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9e289b2625af82aa614592ee5f94a45fc1634014d9a54e6667bd7c024223a12 +size 6528198 diff --git a/Examples/Example5/ISPP/10994.png b/Examples/Example5/ISPP/10994.png new file mode 100644 index 0000000000000000000000000000000000000000..b075459362a74a717b220671be4afa7d40ae16b3 --- /dev/null +++ b/Examples/Example5/ISPP/10994.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b1e5bd808c0c5ff1b5ef79c53c28437bba632ea7eb5d903e5ecb1f8b04c3b2c +size 4222359 diff --git a/Examples/Example5/ISPP/10995.png b/Examples/Example5/ISPP/10995.png new file mode 100644 index 0000000000000000000000000000000000000000..05be6a1acb06984bee535badbf744691e91aa77d --- /dev/null +++ b/Examples/Example5/ISPP/10995.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:1331c1d539b1395aabf0a504ae20bb7c1a40d35a93df4772d9007e50a5696fe5 +size 3529341 diff --git a/Examples/Example5/ISPP/10996.png b/Examples/Example5/ISPP/10996.png new file mode 100644 index 0000000000000000000000000000000000000000..e8202738921a3b15d037c411ade1b1a388326395 --- /dev/null +++ b/Examples/Example5/ISPP/10996.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebabd04872b52332f90a0e614d4aa79ac1c9fdcfec9b51cbaa1519f2cabfd0b4 +size 2348264 diff --git a/Examples/Example5/ISPP/10997.png b/Examples/Example5/ISPP/10997.png new file mode 100644 index 0000000000000000000000000000000000000000..3fc21b37d160d774443aa438e0a80617fb7ce85b --- /dev/null +++ b/Examples/Example5/ISPP/10997.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36c2f89633aafebbe546cb16e1ab1567a07bb9e6e511b59ace9d1b9b35130cdb +size 3097184 diff --git a/Examples/Example5/ISPP/10998.png b/Examples/Example5/ISPP/10998.png new file mode 100644 index 0000000000000000000000000000000000000000..11c74c8bc72e8daeda411b360cfa748df20e19ba --- /dev/null +++ b/Examples/Example5/ISPP/10998.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f06a3bbb4627d3c27857c88feb0f8a35dabfc1fa34f713320a9b7462e5a23f1 +size 3385941 diff --git a/Examples/Example6/11074.png b/Examples/Example6/11074.png new file mode 100644 index 0000000000000000000000000000000000000000..49c28b4aadb30844863a450ee105912514f462af --- /dev/null +++ b/Examples/Example6/11074.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cfb1bd5866fd1e5dccedd491dcfe0f4cb6d1a45684eab7f164e1f722d0f7aac +size 3209086 diff --git a/Examples/Example6/ISPP/10986.png b/Examples/Example6/ISPP/10986.png new file mode 100644 index 0000000000000000000000000000000000000000..4df8b29526755a1dcca8d77120c950dda24c22d8 --- /dev/null +++ b/Examples/Example6/ISPP/10986.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a99a55367e7edf4abc4acc474869a15a8cc0cfc95f20efba776f7627c68244bf +size 3069716 diff --git a/Examples/Example6/ISPP/10989.png b/Examples/Example6/ISPP/10989.png new file mode 100644 index 0000000000000000000000000000000000000000..42a673e06bbf154b2b602e0495fe4fa1ad7da8e4 --- /dev/null +++ b/Examples/Example6/ISPP/10989.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56b8b2be506f15727f0e0b083bdf841ea953cd79575540805fbe7c37986049bd +size 5464892 diff --git a/Examples/Example6/ISPP/10990.png b/Examples/Example6/ISPP/10990.png new file mode 100644 index 0000000000000000000000000000000000000000..512bbe5fddafe153e006dbba749bbcfc306e867f --- /dev/null +++ b/Examples/Example6/ISPP/10990.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e4a646089f2e092cb75d7f50138be6001ecb9e1e4e4d398564fb9b4243566a1 +size 3719887 diff --git a/Examples/Example6/ISPP/10992.png b/Examples/Example6/ISPP/10992.png new file mode 100644 index 0000000000000000000000000000000000000000..58442c181e88df1f3a86c46583414c2041208fb0 --- /dev/null +++ b/Examples/Example6/ISPP/10992.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fef090ba9f287bd9a9846f4b1089aede83e8a0016e7ece0ae89ce3d1f5b6e07e +size 1974713 diff --git a/Examples/Example6/ISPP/10993.png b/Examples/Example6/ISPP/10993.png new file mode 100644 index 0000000000000000000000000000000000000000..33dc7a3dc63e7e024c39ce4d821339e97ad30d17 --- /dev/null +++ b/Examples/Example6/ISPP/10993.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:b9e289b2625af82aa614592ee5f94a45fc1634014d9a54e6667bd7c024223a12 +size 6528198 diff --git a/Examples/Example6/ISPP/10994.png b/Examples/Example6/ISPP/10994.png new file mode 100644 index 0000000000000000000000000000000000000000..b075459362a74a717b220671be4afa7d40ae16b3 --- /dev/null +++ b/Examples/Example6/ISPP/10994.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b1e5bd808c0c5ff1b5ef79c53c28437bba632ea7eb5d903e5ecb1f8b04c3b2c +size 4222359 diff --git a/Examples/Example6/ISPP/10995.png b/Examples/Example6/ISPP/10995.png new file mode 100644 index 0000000000000000000000000000000000000000..05be6a1acb06984bee535badbf744691e91aa77d --- /dev/null +++ b/Examples/Example6/ISPP/10995.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1331c1d539b1395aabf0a504ae20bb7c1a40d35a93df4772d9007e50a5696fe5 +size 3529341 diff --git a/Examples/Example6/ISPP/10996.png b/Examples/Example6/ISPP/10996.png new file mode 100644 index 0000000000000000000000000000000000000000..e8202738921a3b15d037c411ade1b1a388326395 --- /dev/null +++ b/Examples/Example6/ISPP/10996.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebabd04872b52332f90a0e614d4aa79ac1c9fdcfec9b51cbaa1519f2cabfd0b4 +size 2348264 diff --git a/Examples/Example6/ISPP/10997.png b/Examples/Example6/ISPP/10997.png new file mode 100644 index 0000000000000000000000000000000000000000..3fc21b37d160d774443aa438e0a80617fb7ce85b --- /dev/null +++ b/Examples/Example6/ISPP/10997.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36c2f89633aafebbe546cb16e1ab1567a07bb9e6e511b59ace9d1b9b35130cdb +size 3097184 diff --git a/Examples/Example6/ISPP/10998.png b/Examples/Example6/ISPP/10998.png new file mode 100644 index 0000000000000000000000000000000000000000..11c74c8bc72e8daeda411b360cfa748df20e19ba --- /dev/null +++ b/Examples/Example6/ISPP/10998.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f06a3bbb4627d3c27857c88feb0f8a35dabfc1fa34f713320a9b7462e5a23f1 +size 3385941 diff --git a/Examples/Example7/198.bmp b/Examples/Example7/198.bmp new file mode 100644 index 0000000000000000000000000000000000000000..8fc5475e46e8a4ca929de766de3c692ba19afba7 Binary files /dev/null and b/Examples/Example7/198.bmp differ diff --git a/Examples/Example7/ISPP/188.bmp b/Examples/Example7/ISPP/188.bmp new file mode 100644 index 0000000000000000000000000000000000000000..50e3eaa302acab66cff85cd2aa708dfee0fa6d8f Binary files /dev/null and b/Examples/Example7/ISPP/188.bmp differ diff --git a/Examples/Example7/ISPP/189.bmp b/Examples/Example7/ISPP/189.bmp new file mode 100644 index 0000000000000000000000000000000000000000..c3858007525ccc1c90b5f615d6198ec1850ece69 Binary files /dev/null and b/Examples/Example7/ISPP/189.bmp differ diff --git a/Examples/Example7/ISPP/190.bmp b/Examples/Example7/ISPP/190.bmp new file mode 100644 index 0000000000000000000000000000000000000000..7edda3a337581a31eefc932186a1c21afed5abc6 Binary files /dev/null and b/Examples/Example7/ISPP/190.bmp differ diff --git a/Examples/Example7/ISPP/191.bmp b/Examples/Example7/ISPP/191.bmp new file mode 100644 index 0000000000000000000000000000000000000000..d83172e05c6c0f5341ffa486da9454ce6c65f681 Binary files /dev/null and b/Examples/Example7/ISPP/191.bmp differ diff --git a/Examples/Example7/ISPP/192.bmp b/Examples/Example7/ISPP/192.bmp new file mode 100644 index 0000000000000000000000000000000000000000..039619b375806d0d5a2c533852acb2f2c4aceddc Binary files /dev/null and 
b/Examples/Example7/ISPP/192.bmp differ diff --git a/Examples/Example7/ISPP/193.bmp b/Examples/Example7/ISPP/193.bmp new file mode 100644 index 0000000000000000000000000000000000000000..e12fee15c8b6f7fd042c14c01cdff584b4f25e85 Binary files /dev/null and b/Examples/Example7/ISPP/193.bmp differ diff --git a/Examples/Example7/ISPP/194.bmp b/Examples/Example7/ISPP/194.bmp new file mode 100644 index 0000000000000000000000000000000000000000..96ed99f29f5fb70b497a4eaef8bee9e9259488ac Binary files /dev/null and b/Examples/Example7/ISPP/194.bmp differ diff --git a/Examples/Example7/ISPP/195.bmp b/Examples/Example7/ISPP/195.bmp new file mode 100644 index 0000000000000000000000000000000000000000..8dcc71a496fd0cfac35935ccab1b12614087556b Binary files /dev/null and b/Examples/Example7/ISPP/195.bmp differ diff --git a/Examples/Example7/ISPP/196.bmp b/Examples/Example7/ISPP/196.bmp new file mode 100644 index 0000000000000000000000000000000000000000..1b6c0d6624f361a616377c1c8d5fef35d2f2129a Binary files /dev/null and b/Examples/Example7/ISPP/196.bmp differ diff --git a/Examples/Example7/ISPP/197.bmp b/Examples/Example7/ISPP/197.bmp new file mode 100644 index 0000000000000000000000000000000000000000..07b24d13af4a5a4c66a3191d073de31e0371db2c Binary files /dev/null and b/Examples/Example7/ISPP/197.bmp differ diff --git a/PromptIQA/__pycache__/run_promptIQA.cpython-38.pyc b/PromptIQA/__pycache__/run_promptIQA.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc59853327b8e56db9bad4c2238fe5e9cba726a6 Binary files /dev/null and b/PromptIQA/__pycache__/run_promptIQA.cpython-38.pyc differ diff --git a/PromptIQA/checkpoints/best_model_five_22.pth.tar b/PromptIQA/checkpoints/best_model_five_22.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..1a1e49eb3329f8a202d14ee4e710469d7f9f9765 --- /dev/null +++ b/PromptIQA/checkpoints/best_model_five_22.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:993555b9efaeae660d2dd6f4056f13c6957628ca592a2ce74ff2e8eb5a4a2280 +size 1272842308 diff --git a/PromptIQA/models/__pycache__/gc_loss.cpython-38.pyc b/PromptIQA/models/__pycache__/gc_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19b12e10565089b6885d014c77d178319d13240d Binary files /dev/null and b/PromptIQA/models/__pycache__/gc_loss.cpython-38.pyc differ diff --git a/PromptIQA/models/__pycache__/monet.cpython-37.pyc b/PromptIQA/models/__pycache__/monet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4981f40217bf658496b52316346921df9ceba3f Binary files /dev/null and b/PromptIQA/models/__pycache__/monet.cpython-37.pyc differ diff --git a/PromptIQA/models/__pycache__/monet.cpython-38.pyc b/PromptIQA/models/__pycache__/monet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..770191ec2151812ca4a843d9947533bca623595d Binary files /dev/null and b/PromptIQA/models/__pycache__/monet.cpython-38.pyc differ diff --git a/PromptIQA/models/__pycache__/monet.cpython-39.pyc b/PromptIQA/models/__pycache__/monet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b4fcd2a8e81b19787d0fb4e8fd3532a427c77f Binary files /dev/null and b/PromptIQA/models/__pycache__/monet.cpython-39.pyc differ diff --git a/PromptIQA/models/__pycache__/monet_IPF.cpython-38.pyc b/PromptIQA/models/__pycache__/monet_IPF.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1311a33a93d1728d9ebc864bead0b8472e1c90a8 
Binary files /dev/null and b/PromptIQA/models/__pycache__/monet_IPF.cpython-38.pyc differ diff --git a/PromptIQA/models/__pycache__/monet_test.cpython-38.pyc b/PromptIQA/models/__pycache__/monet_test.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4d3722f0d41ed5a89aeebcc6a61bd71c8f1397f Binary files /dev/null and b/PromptIQA/models/__pycache__/monet_test.cpython-38.pyc differ diff --git a/PromptIQA/models/__pycache__/monet_wo_prompt.cpython-37.pyc b/PromptIQA/models/__pycache__/monet_wo_prompt.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a6537b1e8f9bdf84613f38d2b03d3be97fa4a2 Binary files /dev/null and b/PromptIQA/models/__pycache__/monet_wo_prompt.cpython-37.pyc differ diff --git a/PromptIQA/models/__pycache__/monet_wo_prompt.cpython-38.pyc b/PromptIQA/models/__pycache__/monet_wo_prompt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03e183d81f249ed953a8156f9fd4d9d4f0b7eaf7 Binary files /dev/null and b/PromptIQA/models/__pycache__/monet_wo_prompt.cpython-38.pyc differ diff --git a/PromptIQA/models/__pycache__/vit_base.cpython-37.pyc b/PromptIQA/models/__pycache__/vit_base.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..965087a8258e7fdf2089beb4a8b636dc3f160356 Binary files /dev/null and b/PromptIQA/models/__pycache__/vit_base.cpython-37.pyc differ diff --git a/PromptIQA/models/__pycache__/vit_base.cpython-38.pyc b/PromptIQA/models/__pycache__/vit_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e92e5a8180bc11d32e194889797f7d739bceac1c Binary files /dev/null and b/PromptIQA/models/__pycache__/vit_base.cpython-38.pyc differ diff --git a/PromptIQA/models/__pycache__/vit_large.cpython-37.pyc b/PromptIQA/models/__pycache__/vit_large.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b3c0cb6e6d8685edfbcc673dbe059f271c6e1a Binary files /dev/null and b/PromptIQA/models/__pycache__/vit_large.cpython-37.pyc differ diff --git a/PromptIQA/models/__pycache__/vit_large.cpython-38.pyc b/PromptIQA/models/__pycache__/vit_large.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae4be164ecf3153037d34f0ae810137673dcf705 Binary files /dev/null and b/PromptIQA/models/__pycache__/vit_large.cpython-38.pyc differ diff --git a/PromptIQA/models/gc_loss.py b/PromptIQA/models/gc_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a91fa673b04c5ca5a1d759b1de4f1daef0792e07 --- /dev/null +++ b/PromptIQA/models/gc_loss.py @@ -0,0 +1,99 @@ +import torch.nn as nn +import torch +import numpy as np + + +class GC_Loss(nn.Module): + def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1): + super(GC_Loss, self).__init__() + self.pred_queue = list() + self.gt_queue = list() + self.queue_len = 0 + + self.queue_max_len = queue_len + print('The queue length is: ', self.queue_max_len) + self.mse = torch.nn.MSELoss().cuda() + + self.alpha, self.beta, self.gamma = alpha, beta, gamma + + def consistency(self, pred_data, gt_data): + pred_one_batch, pred_queue = pred_data + gt_one_batch, gt_queue = gt_data + + pred_mean = torch.mean(pred_queue) + gt_mean = torch.mean(gt_queue) + + diff_pred = pred_one_batch - pred_mean + diff_gt = gt_one_batch - gt_mean + + x1 = torch.sum(torch.mul(diff_pred, diff_gt)) + x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred))) + x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt))) + + return x1 / (x2_1 * x2_2) 
+ + def ppra(self, x): + """ + Pairwise Preference-based Rank Approximation + """ + + x_bar, x_std = torch.mean(x), torch.std(x) + x_n = (x - x_bar) / x_std + x_n_T = x_n.reshape(-1, 1) + + rank_x = x_n_T - x_n_T.transpose(1, 0) + rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x / torch.sqrt(torch.tensor(2, dtype=torch.float)))), dim=1) + + return rank_x + + @torch.no_grad() + def enqueue(self, pred, gt): + bs = pred.shape[0] + self.queue_len = self.queue_len + bs + + self.pred_queue = self.pred_queue + pred.tolist() + self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist() + + if self.queue_len > self.queue_max_len: + self.dequeue(self.queue_len - self.queue_max_len) + self.queue_len = self.queue_max_len + + @torch.no_grad() + def dequeue(self, n): + for _ in range(n): + self.pred_queue.pop(0) + self.gt_queue.pop(0) + + def clear(self): + self.pred_queue.clear() + self.gt_queue.clear() + + def forward(self, x, y): + x_queue = self.pred_queue.copy() + y_queue = self.gt_queue.copy() + + x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0) + y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0) + + PLCC = self.consistency((x, x_all), (y, y_all)) + PGC = 1 - PLCC + + rank_x = self.ppra(x_all) + rank_y = self.ppra(y_all) + SROCC = self.consistency((rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y)) + SGC = 1 - SROCC + + GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y) + self.enqueue(x, y) + + return GC + + +if __name__ == '__main__': + gc = GC_Loss().cuda() + x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda() + y = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda() + + res = gc(x, y) + + print(res) diff --git a/PromptIQA/models/monet.py b/PromptIQA/models/monet.py new file mode 100644 index 0000000000000000000000000000000000000000..b83608bebe007b48de778b773cb84250d998d171 --- /dev/null +++ b/PromptIQA/models/monet.py @@ -0,0 +1,393 @@ +""" + The completion for Mean-opinion Network(MoNet) +""" +import torch +import torch.nn as nn +import timm + +from timm.models.vision_transformer import Block +from einops import rearrange +from itertools import combinations + +from tqdm import tqdm + +class Attention_Block(nn.Module): + def __init__(self, dim, drop=0.1): + super().__init__() + self.c_q = nn.Linear(dim, dim) + self.c_k = nn.Linear(dim, dim) + self.c_v = nn.Linear(dim, dim) + self.norm_fact = dim ** -0.5 + self.softmax = nn.Softmax(dim=-1) + self.proj_drop = nn.Dropout(drop) + + def forward(self, x): + _x = x + B, C, N = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x + + +class Self_Attention(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + + self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, inFeature): + bs, C, w, h = inFeature.size() + + proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous() + proj_key = self.kConv(inFeature).view(bs, -1, w * h) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = 
self.vConv(inFeature).view(bs, -1, w * h) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous()) + out = out.view(bs, C, w, h) + + out = self.gamma * out + inFeature + + return out + + +class MAL(nn.Module): + """ + Multi-view Attention Learning (MAL) module + """ + + def __init__(self, in_dim=768, feature_num=4, feature_size=28): + super().__init__() + + self.channel_attention = Attention_Block(in_dim * feature_num) # Channel-wise self attention + self.feature_attention = Attention_Block(feature_size ** 2 * feature_num) # Pixel-wise self attention + + # Self attention module for each input feature + self.attention_module = nn.ModuleList() + for _ in range(feature_num): + self.attention_module.append(Self_Attention(in_dim)) + + self.feature_num = feature_num + self.in_dim = in_dim + + def forward(self, features): + feature = torch.tensor([]).cuda() + for index, _ in enumerate(features): + feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0) + features = feature + + input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)') # bs, 768 * feature_num, 28 * 28 + bs, _, _ = input_tensor.shape # [2, 3072, 784] + + in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, + c=self.feature_num) # bs, 768, 28 * 28 * feature_num + feature_weight_sum = self.feature_attention(in_feature) # bs, 768, 768 + + in_channel = input_tensor.permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 * feature_num + channel_weight_sum = self.channel_attention(in_channel) # bs, 28 * 28, 28 * 28 + + weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim, + c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2 # [2, 3072, 784] + + weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1) + + return weight_sum_res # bs, 768, 28 * 28 + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +from functools import partial +class MoNet(nn.Module): + def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224): + super().__init__() + self.img_size = img_size + self.input_size = img_size // patch_size + self.dim_mlp = dim_mlp + + self.vit = timm.create_model('vit_base_patch8_224', pretrained=True) + self.vit.norm = nn.Identity() + self.vit.head = nn.Identity() + + self.save_output = SaveOutput() + + # Register Hooks + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.MALs = nn.ModuleList() + for _ in range(3): + self.MALs.append(MAL()) + + # Image Quality Score Regression + self.fusion_mal = MAL(feature_num=3) + self.block = Block(dim_mlp, 12) + self.cnn = nn.Sequential( + nn.Conv2d(dim_mlp, 256, 5), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(256, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(128, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((3, 3)), + ) + + self.i_p_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + self.mlp = nn.Sequential( + nn.Linear(128, 64), + nn.GELU(), + nn.Linear(64, 128), + ) + + self.prompt_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + + dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0, + attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU) + for i in range(8)]) + self.norm = nn.LayerNorm(128) + + self.score_block = nn.Sequential( + nn.Linear(128, 128 // 2), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(128 // 2, 1), + nn.Sigmoid() + ) + + self.prompt_feature = {} + + @torch.no_grad() + def clear(self): + self.prompt_feature = {} + + @torch.no_grad() + def inference(self, x, data_type): + prompt_feature = self.prompt_feature[data_type] # 1, n, 128 + + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 
768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128 + prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128 + + fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + return iq_res + + @torch.no_grad() + def check_prompt(self, data_type): + return data_type in self.prompt_feature + + @torch.no_grad() + def forward_prompt(self, x, score, data_type): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # Linearly transform the score into a 128-dim feature + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 128) + + # Fuse img_feature and score_feature to obtain funsion_feature + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128 + + # print('Load Prompt For Testing.', funsion_feature.shape) + # self.prompt_feature = funsion_feature.clone() + self.prompt_feature[data_type] = funsion_feature.clone() + + def forward(self, x, score): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)),
dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # Linearly transform the score into a 128-dim feature + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 128) # bs, 128 + + # Fuse img_feature and score_feature to obtain funsion_feature + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) # bs, 128 + funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128 + funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128 + + fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1 + # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1) + + gt_res = score.view(-1) + # diff_gt_res = 1 - score.view(-1) + + return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res') + + def extract_feature(self, save_output, block_index=[2, 5, 8, 11]): + x1 = save_output.outputs[block_index[0]][:, 1:] + x2 = save_output.outputs[block_index[1]][:, 1:] + x3 = save_output.outputs[block_index[2]][:, 1:] + x4 = save_output.outputs[block_index[3]][:, 1:] + x = torch.cat((x1, x2, x3, x4), dim=2) + return x + + def expand(self, A): + A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1) + + B = None + for index, i in enumerate(A_expanded): + rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0) + if B is None: + B = rmv + else: + B = torch.cat((B, rmv), dim=0) + + return B + +if __name__ == '__main__': + in_feature = torch.zeros((10, 3, 224, 224)).cuda() + gt_feature = torch.tensor([[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=torch.float).cuda() + model = MoNet().cuda() + + iq_res, gt_res = model(in_feature, gt_feature) + + print(iq_res.shape) + print(gt_res.shape) diff --git a/PromptIQA/models/monet_IPF.py b/PromptIQA/models/monet_IPF.py new file mode 100644 index 0000000000000000000000000000000000000000..4da63207080ea12974eb6cb538f1a63c602d1109 --- /dev/null +++ b/PromptIQA/models/monet_IPF.py @@ -0,0 +1,397 @@ +""" + The completion for Mean-opinion Network(MoNet) +""" +import torch +import torch.nn as nn +import timm + +from timm.models.vision_transformer import Block +from einops import rearrange +from itertools import combinations + +from tqdm import tqdm + +class Attention_Block(nn.Module): + def __init__(self, dim, drop=0.1): + super().__init__() + self.c_q = nn.Linear(dim, dim) + self.c_k = nn.Linear(dim, dim) + self.c_v = nn.Linear(dim, dim) + self.norm_fact = dim ** -0.5 + self.softmax = nn.Softmax(dim=-1) + self.proj_drop = nn.Dropout(drop) + + def forward(self, x): + _x = x + B, C, N = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x +
+ +class Self_Attention(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + + self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, inFeature): + bs, C, w, h = inFeature.size() + + proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous() + proj_key = self.kConv(inFeature).view(bs, -1, w * h) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = self.vConv(inFeature).view(bs, -1, w * h) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous()) + out = out.view(bs, C, w, h) + + out = self.gamma * out + inFeature + + return out + + +class MAL(nn.Module): + """ + Multi-view Attention Learning (MAL) module + """ + + def __init__(self, in_dim=768, feature_num=4, feature_size=28): + super().__init__() + + self.channel_attention = Attention_Block(in_dim * feature_num) # Channel-wise self attention + self.feature_attention = Attention_Block(feature_size ** 2 * feature_num) # Pixel-wise self attention + + # Self attention module for each input feature + self.attention_module = nn.ModuleList() + for _ in range(feature_num): + self.attention_module.append(Self_Attention(in_dim)) + + self.feature_num = feature_num + self.in_dim = in_dim + + def forward(self, features): + feature = torch.tensor([]).cuda() + for index, _ in enumerate(features): + feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0) + features = feature + + input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)') # bs, 768 * feature_num, 28 * 28 + bs, _, _ = input_tensor.shape # [2, 3072, 784] + + in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, + c=self.feature_num) # bs, 768, 28 * 28 * feature_num + feature_weight_sum = self.feature_attention(in_feature) # bs, 768, 768 + + in_channel = input_tensor.permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 * feature_num + channel_weight_sum = self.channel_attention(in_channel) # bs, 28 * 28, 28 * 28 + + weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim, + c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2 # [2, 3072, 784] + + weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1) + + return weight_sum_res # bs, 768, 28 * 28 + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x +import torch +from functools import partial +class MoNet(nn.Module): + def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224): + super().__init__() + self.img_size = img_size + self.input_size = img_size // patch_size + self.dim_mlp = dim_mlp + + self.vit = timm.create_model('vit_base_patch8_224', pretrained=True) + self.vit.norm = nn.Identity() + self.vit.head = nn.Identity() + + self.save_output = SaveOutput() + + # Register Hooks + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.MALs = nn.ModuleList() + for _ in range(3): + self.MALs.append(MAL()) + + # Image Quality Score Regression + self.fusion_mal = MAL(feature_num=3) + self.block = Block(dim_mlp, 12) + self.cnn = nn.Sequential( + nn.Conv2d(dim_mlp, 256, 5), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(256, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(128, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((3, 3)), + ) + + self.i_p_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + self.mlp = nn.Sequential( + nn.Linear(128, 64), + nn.GELU(), + nn.Linear(64, 128), + ) + + self.prompt_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + + dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0, + attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU) + for i in range(8)]) + self.norm = nn.LayerNorm(128) + + self.score_block = nn.Sequential( + nn.Linear(128, 128 // 2), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(128 // 2, 1), + nn.Sigmoid() + ) + + self.prompt_feature = {} + + + + @torch.no_grad() + def clear(self): + self.prompt_feature = {} + + @torch.no_grad() + def inference(self, x, data_type): + prompt_feature = self.prompt_feature[data_type] # 1, n, 128 + + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, 
h=self.input_size) # bs, 4, 768, 28, 28
+        x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
+
+        # Different Opinion Features (DOF)
+        DOF = torch.tensor([]).cuda()
+        for index, _ in enumerate(self.MALs):
+            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
+        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
+
+        # Image Quality Score Regression
+        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768
+        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
+        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
+        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
+
+        prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128
+        prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128
+
+        fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1))[:, 0, :] # bs, 128
+        # fusion = self.norm(fusion)[:, 0, :]
+        # fusion = self.score_block(fusion)
+
+        # # iq_res = torch.mean(fusion, dim=1).view(-1)
+        # iq_res = fusion[:, 0].view(-1)
+
+        return fusion
+
+    @torch.no_grad()
+    def check_prompt(self, data_type):
+        return data_type in self.prompt_feature
+
+    @torch.no_grad()
+    def forward_prompt(self, x, score, data_type):
+        if data_type in self.prompt_feature:
+            return
+        _x = self.vit(x)
+        x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
+        self.save_output.outputs.clear()
+
+        x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
+        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
+        x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
+
+        # Different Opinion Features (DOF)
+        DOF = torch.tensor([]).cuda()
+        for index, _ in enumerate(self.MALs):
+            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
+        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
+
+        # Image Quality Score Regression
+        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768
+        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
+        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
+        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
+
+        # Map the score to a 128-dimensional feature
+        # score_feature = self.score_projection(score) # bs, 128
+        score_feature = score.expand(-1, 128)
+
+        # Fuse img_feature and score_feature into funsion_feature
+        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
+        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128
+
+        # print('Load Prompt For Testing.', funsion_feature.shape)
+        # self.prompt_feature = funsion_feature.clone()
+        self.prompt_feature[data_type] = funsion_feature.clone()
+
+    def forward(self, x, score):
+        _x = self.vit(x)
+        x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4
+        self.save_output.outputs.clear()
+
+        x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28
+        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28
+        x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28
+
+        # Different Opinion Features (DOF)
+        DOF = torch.tensor([]).cuda()
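+        # Each MAL branch below gives an independent "opinion" of the stacked multi-layer ViT
+        # features; the M opinions are concatenated along a new leading dimension and merged
+        # into a single representation by fusion_mal further down.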
+        for index, _ in enumerate(self.MALs):
+            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
+        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28
+
+        # Image Quality Score Regression
+        fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28, 768
+        IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28
+        IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28
+        img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128
+
+        # Map the score to a 128-dimensional feature
+        # score_feature = self.score_projection(score) # bs, 128
+        score_feature = score.expand(-1, 128) # bs, 128
+
+        # Fuse img_feature and score_feature into funsion_feature
+        funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128
+        funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) # bs, 128
+        funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128
+        funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128
+
+        fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, bs, 128
+        fusion = self.norm(fusion)
+        fusion = self.score_block(fusion)
+        # iq_res = torch.mean(fusion, dim=1).view(-1)
+        iq_res = fusion[:, 0].view(-1)
+
+        # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1
+        # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1)
+
+        gt_res = score.view(-1)
+        # diff_gt_res = 1 - score.view(-1)
+
+        return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res')
+
+    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
+        x1 = save_output.outputs[block_index[0]][:, 1:]
+        x2 = save_output.outputs[block_index[1]][:, 1:]
+        x3 = save_output.outputs[block_index[2]][:, 1:]
+        x4 = save_output.outputs[block_index[3]][:, 1:]
+        x = torch.cat((x1, x2, x3, x4), dim=2)
+        return x
+
+    def expand(self, A):
+        A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1)
+
+        B = None
+        for index, i in enumerate(A_expanded):
+            rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0)
+            if B is None:
+                B = rmv
+            else:
+                B = torch.cat((B, rmv), dim=0)
+
+        return B
+
+if __name__ == '__main__':
+    in_feature = torch.zeros((10, 3, 224, 224)).cuda()
+    gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda()
+    model = MoNet().cuda()
+
+    (iq_res, _), (gt_res, _) = model(in_feature, gt_feature[:, -1].reshape(-1, 1))
+
+    print(iq_res.shape)
+    print(gt_res.shape)
diff --git a/PromptIQA/models/monet_test.py b/PromptIQA/models/monet_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4bffd26f1cacb6b31740d31f6a4bc6db8407e0c
--- /dev/null
+++ b/PromptIQA/models/monet_test.py
@@ -0,0 +1,389 @@
+"""
+    The completion for Mean-opinion Network(MoNet)
+"""
+import torch
+import torch.nn as nn
+import timm
+
+from timm.models.vision_transformer import Block
+from einops import rearrange
+from itertools import combinations
+
+from tqdm import tqdm
+
+class Attention_Block(nn.Module):
+    def __init__(self, dim, drop=0.1):
+        super().__init__()
+        self.c_q = nn.Linear(dim, dim)
+        self.c_k = nn.Linear(dim, dim)
+        self.c_v = nn.Linear(dim, dim)
+        self.norm_fact = dim ** -0.5
+        self.softmax = nn.Softmax(dim=-1)
+        self.proj_drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        _x = x
+        B, C, N = x.shape
+        q = self.c_q(x)
+        k = self.c_k(x)
+        v =
self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x + + +class Self_Attention(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + + self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, inFeature): + bs, C, w, h = inFeature.size() + + proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous() + proj_key = self.kConv(inFeature).view(bs, -1, w * h) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = self.vConv(inFeature).view(bs, -1, w * h) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous()) + out = out.view(bs, C, w, h) + + out = self.gamma * out + inFeature + + return out + + +class MAL(nn.Module): + """ + Multi-view Attention Learning (MAL) module + """ + + def __init__(self, in_dim=768, feature_num=4, feature_size=28): + super().__init__() + + self.channel_attention = Attention_Block(in_dim * feature_num) # Channel-wise self attention + self.feature_attention = Attention_Block(feature_size ** 2 * feature_num) # Pixel-wise self attention + + # Self attention module for each input feature + self.attention_module = nn.ModuleList() + for _ in range(feature_num): + self.attention_module.append(Self_Attention(in_dim)) + + self.feature_num = feature_num + self.in_dim = in_dim + + def forward(self, features): + feature = torch.tensor([]).cuda() + for index, _ in enumerate(features): + feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0) + features = feature + + input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)') # bs, 768 * feature_num, 28 * 28 + bs, _, _ = input_tensor.shape # [2, 3072, 784] + + in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, + c=self.feature_num) # bs, 768, 28 * 28 * feature_num + feature_weight_sum = self.feature_attention(in_feature) # bs, 768, 768 + + in_channel = input_tensor.permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 * feature_num + channel_weight_sum = self.channel_attention(in_channel) # bs, 28 * 28, 28 * 28 + + weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim, + c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2 # [2, 3072, 784] + + weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1) + + return weight_sum_res # bs, 768, 28 * 28 + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x +class MoNet(nn.Module): + def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224): + super().__init__() + self.img_size = img_size + self.input_size = img_size // patch_size + self.dim_mlp = dim_mlp + + self.vit = timm.create_model('vit_base_patch8_224', pretrained=True) + self.vit.norm = nn.Identity() + self.vit.head = nn.Identity() + + self.save_output = SaveOutput() + + # Register Hooks + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.MALs = nn.ModuleList() + for _ in range(3): + self.MALs.append(MAL()) + + # Image Quality Score Regression + self.fusion_mal = MAL(feature_num=3) + self.block = Block(dim_mlp, 12) + self.cnn = nn.Sequential( + nn.Conv2d(dim_mlp, 256, 5), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(256, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(128, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((3, 3)), + ) + + # self.score_projection = nn.Sequential( + # nn.Linear(1, 64), + # nn.GELU(), + # nn.Linear(64, 128), + # ) + + # self.i_p_fusion = nn.Sequential( + # Block(128, 8), + # Block(128, 8), + # Block(128, 8), + # ) + self.i_p_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + self.mlp = nn.Sequential( + nn.Linear(128, 64), + nn.GELU(), + nn.Linear(64, 128), + ) + + self.score_block = nn.Sequential( + Block(128, 4), + Block(128, 4), + # Block(128, 4), + nn.Linear(128, 128 // 2), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(128 // 2, 1), + nn.Sigmoid() + ) + + # self.diff_block = nn.Sequential( + # Block(128, 8), + # Block(128, 8), + # Block(128, 8), + # nn.Linear(128, 64), + # nn.GELU(), + # nn.Linear(64, 1), + # ) + self.prompt_feature = None + + @torch.no_grad() + def clear(self): + self.prompt_feature = None + + @torch.no_grad() + def inference(self, x): + prompt_feature = self.prompt_feature # 1, n, 128 + + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() 
# bs, 4, 768, 28 * 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128 + + fusion = self.score_block(torch.cat((img_feature, prompt_feature), dim=1)) # bs, n, 1 + + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + return iq_res + + def extract_feature(self, save_output, block_index=[2, 5, 8, 11]): + x1 = save_output.outputs[block_index[0]][:, 1:] + x2 = save_output.outputs[block_index[1]][:, 1:] + x3 = save_output.outputs[block_index[2]][:, 1:] + x4 = save_output.outputs[block_index[3]][:, 1:] + x = torch.cat((x1, x2, x3, x4), dim=2) + return x + + @torch.no_grad() + def forward_prompt(self, x, score): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 128) + + # img_feature 和 score_feature融合得到 funsion_feature + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128 + + print('Load Prompt For Testing.', funsion_feature.shape) + self.prompt_feature = funsion_feature.clone() + + def expand(self, A): + A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1) + + B = None + for index, i in enumerate(A_expanded): + rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0) + if B is None: + B = rmv + else: + B = torch.cat((B, rmv), dim=0) + + return B + + def forward(self, x, score): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 
28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 128) # bs, 128 + + # img_feature 和 score_feature融合得到 funsion_feature + funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) #bs, 128 + funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128 + + fusion = self.score_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1 + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1 + # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1) + + gt_res = score.view(-1) + # diff_gt_res = 1 - score.view(-1) + + return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res') + + +if __name__ == '__main__': + in_feature = torch.zeros((10, 3, 224, 224)).cuda() + gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2], [0, 100, 3], [0, 100, 4], [0, 100, 5], [0, 100, 6], [0, 100, 7], [0, 100, 8], [0, 100, 9], [0, 100, 10]], dtype=torch.float).cuda() + model = MoNet().cuda() + + iq_res, gt_res = model(in_feature, gt_feature) + + print(iq_res.shape) + print(gt_res.shape) diff --git a/PromptIQA/models/monet_wo_prompt.py b/PromptIQA/models/monet_wo_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..bb934209cbd010fbade4b46038e0731c516e3729 --- /dev/null +++ b/PromptIQA/models/monet_wo_prompt.py @@ -0,0 +1,392 @@ +""" + The completion for Mean-opinion Network(MoNet) +""" +import torch +import torch.nn as nn +import timm + +from timm.models.vision_transformer import Block +from einops import rearrange +from itertools import combinations + +from tqdm import tqdm +import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '7' + +class Attention_Block(nn.Module): + def __init__(self, dim, drop=0.1): + super().__init__() + self.c_q = nn.Linear(dim, dim) + self.c_k = nn.Linear(dim, dim) + self.c_v = nn.Linear(dim, dim) + self.norm_fact = dim ** -0.5 + self.softmax = nn.Softmax(dim=-1) + self.proj_drop = nn.Dropout(drop) + + def forward(self, x): + _x = x + B, C, N = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x + + +class Self_Attention(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + + self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + 
self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, inFeature): + bs, C, w, h = inFeature.size() + + proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous() + proj_key = self.kConv(inFeature).view(bs, -1, w * h) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = self.vConv(inFeature).view(bs, -1, w * h) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous()) + out = out.view(bs, C, w, h) + + out = self.gamma * out + inFeature + + return out + + +class MAL(nn.Module): + """ + Multi-view Attention Learning (MAL) module + """ + + def __init__(self, in_dim=768, feature_num=4, feature_size=28): + super().__init__() + + self.channel_attention = Attention_Block(in_dim * feature_num) # Channel-wise self attention + self.feature_attention = Attention_Block(feature_size ** 2 * feature_num) # Pixel-wise self attention + + # Self attention module for each input feature + self.attention_module = nn.ModuleList() + for _ in range(feature_num): + self.attention_module.append(Self_Attention(in_dim)) + + self.feature_num = feature_num + self.in_dim = in_dim + + def forward(self, features): + feature = torch.tensor([]).cuda() + for index, _ in enumerate(features): + feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0) + features = feature + + input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)') # bs, 768 * feature_num, 28 * 28 + bs, _, _ = input_tensor.shape # [2, 3072, 784] + + in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, + c=self.feature_num) # bs, 768, 28 * 28 * feature_num + feature_weight_sum = self.feature_attention(in_feature) # bs, 768, 768 + + in_channel = input_tensor.permute(0, 2, 1).contiguous() # bs, 28 * 28, 768 * feature_num + channel_weight_sum = self.channel_attention(in_channel) # bs, 28 * 28, 28 * 28 + + weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim, + c=self.feature_num) + channel_weight_sum.permute(0, 2, 1).contiguous()) / 2 # [2, 3072, 784] + + weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1) + + return weight_sum_res # bs, 768, 28 * 28 + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +from functools import partial +class MoNet(nn.Module): + def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224): + super().__init__() + self.img_size = img_size + self.input_size = img_size // patch_size + self.dim_mlp = dim_mlp + + self.vit = timm.create_model('vit_base_patch8_224', pretrained=True) + self.vit.norm = nn.Identity() + self.vit.head = nn.Identity() + + self.save_output = SaveOutput() + + # Register Hooks + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.MALs = nn.ModuleList() + for _ in range(3): + self.MALs.append(MAL()) + + # Image Quality Score Regression + self.fusion_mal = MAL(feature_num=3) + self.block = Block(dim_mlp, 12) + self.cnn = nn.Sequential( + nn.Conv2d(dim_mlp, 256, 5), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(256, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(128, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((3, 3)), + ) + + # self.i_p_fusion = nn.Sequential( + # Block(128, 4), + # Block(128, 4), + # Block(128, 4), + # ) + # self.mlp = nn.Sequential( + # nn.Linear(128, 64), + # nn.GELU(), + # nn.Linear(64, 128), + # ) + + dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0, + attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU) + for i in range(8)]) + self.norm = nn.LayerNorm(128) + + self.score_block = nn.Sequential( + nn.Linear(128, 128 // 2), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(128 // 2, 1), + nn.Sigmoid() + ) + + self.prompt_feature = {} + + @torch.no_grad() + def clear(self): + self.prompt_feature = {} + + @torch.no_grad() + def inference(self, x, data_type): + # prompt_feature = self.prompt_feature[data_type] # 1, n, 128 + + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # bs, 4, 768, 28 
* 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128 + # prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128 + + fusion = self.blocks(img_feature) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + return iq_res + + @torch.no_grad() + def check_prompt(self, data_type): + return data_type in self.prompt_feature + + @torch.no_grad() + def forward_prompt(self, x, score, data_type): + pass + # if data_type in self.prompt_feature: + # return + # _x = self.vit(x) + # x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + # self.save_output.outputs.clear() + + # x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + # x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + # x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # # Different Opinion Features (DOF) + # DOF = torch.tensor([]).cuda() + # for index, _ in enumerate(self.MALs): + # DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + # DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # # Image Quality Score Regression + # fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + # IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + # IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + # img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # # 分数线性变换为128维 + # # score_feature = self.score_projection(score) # bs, 128 + # score_feature = score.expand(-1, 128) + + # # img_feature 和 score_feature融合得到 funsion_feature + # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128 + + # # print('Load Prompt For Testing.', funsion_feature.shape) + # # self.prompt_feature = funsion_feature.clone() + # self.prompt_feature[data_type] = funsion_feature.clone() + + def forward(self, x, score): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, 
self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + # score_feature = score.expand(-1, 128) # bs, 128 + + # # img_feature 和 score_feature融合得到 funsion_feature + # funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + # funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) #bs, 128 + # funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128 + # funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128 + + fusion = self.blocks(img_feature) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1 + # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1) + + gt_res = score.view(-1) + # diff_gt_res = 1 - score.view(-1) + + return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res') + + def extract_feature(self, save_output, block_index=[2, 5, 8, 11]): + x1 = save_output.outputs[block_index[0]][:, 1:] + x2 = save_output.outputs[block_index[1]][:, 1:] + x3 = save_output.outputs[block_index[2]][:, 1:] + x4 = save_output.outputs[block_index[3]][:, 1:] + x = torch.cat((x1, x2, x3, x4), dim=2) + return x + + def expand(self, A): + A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1) + + B = None + for index, i in enumerate(A_expanded): + rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0) + if B is None: + B = rmv + else: + B = torch.cat((B, rmv), dim=0) + + return B + +if __name__ == '__main__': + in_feature = torch.zeros((2, 3, 224, 224)).cuda() + gt_feature = torch.tensor([[0, 100, 1], [0, 100, 2]], dtype=torch.float).cuda() + model = MoNet().cuda() + + iq_res, gt_res = model(in_feature, gt_feature) + + print(iq_res) + print(gt_res.shape) diff --git a/PromptIQA/models/vit_base.py b/PromptIQA/models/vit_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fd49706addb282b4437a7af7c9cdd2eb6cb882 --- /dev/null +++ b/PromptIQA/models/vit_base.py @@ -0,0 +1,402 @@ +""" + The completion for Mean-opinion Network(MoNet) +""" +import torch +import torch.nn as nn +import timm + +from timm.models.vision_transformer import Block +from einops import rearrange +from itertools import combinations + +from tqdm import tqdm + + +class Attention_Block(nn.Module): + def __init__(self, dim, drop=0.1): + super().__init__() + self.c_q = nn.Linear(dim, dim) + self.c_k = nn.Linear(dim, dim) + self.c_v = nn.Linear(dim, dim) + self.norm_fact = dim ** -0.5 + self.softmax = nn.Softmax(dim=-1) + self.proj_drop = nn.Dropout(drop) + + def forward(self, x): + _x = x + B, C, N = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x + + +class 
Self_Attention(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + + self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, inFeature): + bs, C, w, h = inFeature.size() + + proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous() + proj_key = self.kConv(inFeature).view(bs, -1, w * h) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = self.vConv(inFeature).view(bs, -1, w * h) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous()) + out = out.view(bs, C, w, h) + + out = self.gamma * out + inFeature + + return out + + +class three_cnn(nn.Module): + def __init__(self, in_dim) -> None: + super().__init__() + + self.three_cnn = nn.Sequential( + nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, input): + return self.three_cnn(input) + + +class MAL(nn.Module): + def __init__(self, in_dim=768, feature_num=4, feature_size=28): + super().__init__() + self.attention_module = nn.ModuleList() + for i in range(feature_num): + self.attention_module.append(three_cnn(in_dim)) + + self.feature_num = feature_num + self.in_dim = in_dim + self.feature_size = feature_size + + def forward(self, features): + feature = torch.tensor([]).cuda() + for index, _ in enumerate(features): + feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1) + feature = torch.mean(feature, dim=1) + features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size) + + return features # bs, 768, 28 * 28 + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +from functools import partial + + +class MoNet(nn.Module): + def __init__(self, patch_size=8, drop=0.1, dim_mlp=768, img_size=224): + super().__init__() + self.img_size = img_size + self.input_size = img_size // patch_size + self.dim_mlp = dim_mlp + + self.vit = timm.create_model('vit_base_patch8_224', pretrained=True) + self.vit.norm = nn.Identity() + self.vit.head = nn.Identity() + + self.save_output = SaveOutput() + + # Register Hooks + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.MALs = nn.ModuleList() + for _ in range(1): + self.MALs.append(MAL()) + + # Image Quality Score Regression + self.fusion_mal = MAL(feature_num=1) + self.block = Block(dim_mlp, 12) + self.cnn = nn.Sequential( + nn.Conv2d(dim_mlp, 256, 5), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(256, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + nn.Conv2d(128, 128, 3), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AvgPool2d((3, 3)), + ) + + self.i_p_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + self.mlp = nn.Sequential( + nn.Linear(128, 64), + nn.GELU(), + nn.Linear(64, 128), + ) + + self.prompt_fusion = nn.Sequential( + Block(128, 4), + Block(128, 4), + Block(128, 4), + ) + + dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=128, num_heads=4, mlp_ratio=4, qkv_bias=True, drop=0, + attn_drop=0, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU) + for i in range(8)]) + self.norm = nn.LayerNorm(128) + + self.score_block = nn.Sequential( + nn.Linear(128, 128 // 2), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(128 // 2, 1), + nn.Sigmoid() + ) + + self.prompt_feature = {} + + @torch.no_grad() + def clear(self): + self.prompt_feature = {} + + @torch.no_grad() + def inference(self, x, data_type): + prompt_feature = self.prompt_feature[data_type] # 1, n, 128 + + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) 
# bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # bs, 4, 768, 28 * 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128 + prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128 + + fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + return iq_res + + @torch.no_grad() + def check_prompt(self, data_type): + return data_type in self.prompt_feature + + @torch.no_grad() + def forward_prompt(self, x, score, data_type): + if data_type in self.prompt_feature: + return + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 128) + + # img_feature 和 score_feature融合得到 funsion_feature + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128 + + # print('Load Prompt For Testing.', funsion_feature.shape) + # self.prompt_feature = funsion_feature.clone() + self.prompt_feature[data_type] = funsion_feature.clone() + + def forward(self, x, score): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + 
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 128) # bs, 128 + + # img_feature 和 score_feature融合得到 funsion_feature + # funsion_feature = self.i_p_fusion(torch.cat((img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 128 + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) #bs, 128 + funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128 + funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128 + + fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1 + # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1) + + gt_res = score.view(-1) + # diff_gt_res = 1 - score.view(-1) + + return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res') + + def extract_feature(self, save_output, block_index=None): + block_index = [2, 5, 8, 11] + x1 = save_output.outputs[block_index[0]][:, 1:] + x2 = save_output.outputs[block_index[1]][:, 1:] + x3 = save_output.outputs[block_index[2]][:, 1:] + x4 = save_output.outputs[block_index[3]][:, 1:] + x = torch.cat((x1, x2, x3, x4), dim=2) + return x + + def expand(self, A): + A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1) + + B = None + for index, i in enumerate(A_expanded): + rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0) + if B is None: + B = rmv + else: + B = torch.cat((B, rmv), dim=0) + + return B + + +if __name__ == '__main__': + in_feature = torch.zeros((11, 3, 384, 384)).cuda() + gt_feature = torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda() + gt_feature = gt_feature.reshape(-1, 1) + model = MoNet().cuda() + + (iq_res, _), (_, _) = model(in_feature, gt_feature) + + print(iq_res.shape) + # print(gt_res.shape) diff --git a/PromptIQA/models/vit_large.py b/PromptIQA/models/vit_large.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1bd68bd7fd452e557e97b25d9b1cc96610d18a --- /dev/null +++ b/PromptIQA/models/vit_large.py @@ -0,0 +1,405 @@ +""" + The completion for Mean-opinion Network(MoNet) +""" +import torch +import torch.nn as nn +import timm + +from timm.models.vision_transformer import Block +from einops import rearrange +from itertools import combinations + +from tqdm import tqdm + + +class Attention_Block(nn.Module): + def __init__(self, dim, drop=0.1): + super().__init__() + self.c_q = nn.Linear(dim, dim) + self.c_k = nn.Linear(dim, dim) + self.c_v = nn.Linear(dim, dim) + self.norm_fact = dim ** -0.5 + self.softmax = nn.Softmax(dim=-1) + self.proj_drop = nn.Dropout(drop) + + def forward(self, x): + _x = x + B, 
C, N = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x + + +class Self_Attention(nn.Module): + """ Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + + self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, inFeature): + bs, C, w, h = inFeature.size() + + proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1).contiguous() + proj_key = self.kConv(inFeature).view(bs, -1, w * h) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = self.vConv(inFeature).view(bs, -1, w * h) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1).contiguous()) + out = out.view(bs, C, w, h) + + out = self.gamma * out + inFeature + + return out + + +class three_cnn(nn.Module): + def __init__(self, in_dim) -> None: + super().__init__() + + self.three_cnn = nn.Sequential( + nn.Conv2d(in_dim, in_dim // 2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(in_dim // 2, in_dim // 2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(in_dim // 2, in_dim, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, input): + return self.three_cnn(input) + + +class MAL(nn.Module): + def __init__(self, in_dim=768, feature_num=4, feature_size=28): + super().__init__() + self.attention_module = nn.ModuleList() + for i in range(feature_num): + self.attention_module.append(three_cnn(in_dim)) + + self.feature_num = feature_num + self.in_dim = in_dim + self.feature_size = feature_size + + def forward(self, features): + feature = torch.tensor([]).cuda() + for index, _ in enumerate(features): + feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(1)), dim=1) + feature = torch.mean(feature, dim=1) + features = feature.view(-1, self.in_dim, self.feature_size * self.feature_size) + + return features # bs, 768, 28 * 28 + + +class SaveOutput: + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +from functools import partial + + +class MoNet(nn.Module): + def __init__(self, patch_size=32, drop=0.1, dim_mlp=1024, img_size=384): + super().__init__() + self.img_size = img_size + self.input_size = img_size // patch_size + self.dim_mlp = dim_mlp + + self.vit = timm.create_model('vit_large_patch32_384', pretrained=True) + self.vit.norm = nn.Identity() + self.vit.head = nn.Identity() + self.vit.head_drop = nn.Identity() + + self.save_output = SaveOutput() + + # Register Hooks + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.MALs = nn.ModuleList() + for _ in range(3): + self.MALs.append(MAL(in_dim=dim_mlp, feature_size=self.input_size)) + + # Image Quality Score Regression + self.fusion_mal = MAL(in_dim=dim_mlp, feature_num=3, feature_size=self.input_size) + self.block = Block(dim_mlp, 16) + self.cnn = nn.Sequential( + nn.Conv2d(dim_mlp, 512, 5), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), # 4 + nn.Conv2d(512, 256, 3, 1), # 2 + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AvgPool2d((2, 2)), + ) + + self.i_p_fusion = nn.Sequential( + Block(256, 8), + Block(256, 8), + Block(256, 8), + ) + self.mlp = nn.Sequential( + nn.Linear(256, 128), + nn.GELU(), + nn.Linear(128, 256), + ) + + self.prompt_fusion = nn.Sequential( + Block(256, 8), + Block(256, 8), + Block(256, 8), + ) + + dpr = [x.item() for x in torch.linspace(0, 0, 8)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block(dim=256, num_heads=8, mlp_ratio=4, qkv_bias=True, attn_drop=0, drop_path=dpr[i]) + for i in range(8)]) + self.norm = nn.LayerNorm(256) + + self.score_block = nn.Sequential( + nn.Linear(256, 256 // 2), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(256 // 2, 1), + nn.Sigmoid() + ) + self.prompt_feature = {} + + @torch.no_grad() + def clear(self): + self.prompt_feature = {} + + @torch.no_grad() + def inference(self, x, data_type): + prompt_feature = self.prompt_feature[data_type] # 1, n, 128 + + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, 
n=self.dim_mlp, w=self.input_size, + h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # bs, 4, 768, 28 * 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, + h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + prompt_feature = prompt_feature.repeat(img_feature.shape[0], 1, 1) # bs, n, 128 + prompt_feature = self.prompt_fusion(prompt_feature) # bs, n, 128 + + fusion = self.blocks(torch.cat((img_feature, prompt_feature), dim=1)) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + return iq_res + + @torch.no_grad() + def check_prompt(self, data_type): + return data_type in self.prompt_feature + + @torch.no_grad() + def forward_prompt(self, x, score, data_type): + if data_type in self.prompt_feature: + return + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, + h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, + h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 256) + + # img_feature 和 score_feature融合得到 funsion_feature + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)).unsqueeze(0) # 1, n, 128 + + # print('Load Prompt For Testing.', funsion_feature.shape) + # self.prompt_feature = funsion_feature.clone() + self.prompt_feature[data_type] = funsion_feature.clone() + + def forward(self, x, score): + _x = self.vit(x) + x = self.extract_feature(self.save_output) # bs, 28 * 28, 768 * 4 + self.save_output.outputs.clear() + + x = x.permute(0, 2, 1).contiguous() # bs, 768 * 4, 28 * 28 + x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, + h=self.input_size) # bs, 4, 768, 28, 28 + x = x.permute(1, 0, 2, 3, 4).contiguous() # 4, bs, 768, 28, 28 + + # Different Opinion Features (DOF) + DOF = 
torch.tensor([]).cuda() + for index, _ in enumerate(self.MALs): + DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0) + DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # M, bs, 768, 28, 28 + # Image Quality Score Regression + fusion_mal = self.fusion_mal(DOF).permute(0, 2, 1).contiguous() # bs, 28 * 28 768 + IQ_feature = self.block(fusion_mal).permute(0, 2, 1).contiguous() # bs, 768, 28 * 28 + IQ_feature = rearrange(IQ_feature, 'c d (w h) -> c d w h', w=self.input_size, + h=self.input_size) # bs, 768, 28, 28 + img_feature = self.cnn(IQ_feature).squeeze(-1).squeeze(-1).unsqueeze(1) # bs, 1, 128 + + # 分数线性变换为128维 + # score_feature = self.score_projection(score) # bs, 128 + score_feature = score.expand(-1, 256) # bs, 128 + + # img_feature 和 score_feature融合得到 funsion_feature funsion_feature = self.i_p_fusion(torch.cat(( + # img_feature.detach(), score_feature.unsqueeze(1).detach()), dim=1)) # bs, 2, 128 + funsion_feature = self.i_p_fusion(torch.cat((img_feature, score_feature.unsqueeze(1)), dim=1)) # bs, 2, 128 + funsion_feature = self.mlp(torch.mean(funsion_feature, dim=1)) # bs, 128 + funsion_feature = self.expand(funsion_feature) # bs, bs - 1, 128 + funsion_feature = self.prompt_fusion(funsion_feature) # bs, bs - 1, 128 + + fusion = self.blocks(torch.cat((img_feature, funsion_feature), dim=1)) # bs, 2, 1 + fusion = self.norm(fusion) + fusion = self.score_block(fusion) + # iq_res = torch.mean(fusion, dim=1).view(-1) + iq_res = fusion[:, 0].view(-1) + + # differ_fusion = self.diff_block(torch.cat((img_feature, funsion_feature), dim=1)) # bs, n, 1 + # differ_iq_res = torch.mean(differ_fusion, dim=1).view(-1) + + gt_res = score.view(-1) + # diff_gt_res = 1 - score.view(-1) + + return (iq_res, 'differ_iq_res'), (gt_res, 'diff_gt_res') + + def extract_feature(self, save_output, block_index=None): + if block_index is None: + block_index = [5, 11, 17, 23] + x1 = save_output.outputs[block_index[0]][:, 1:] + x2 = save_output.outputs[block_index[1]][:, 1:] + x3 = save_output.outputs[block_index[2]][:, 1:] + x4 = save_output.outputs[block_index[3]][:, 1:] + x = torch.cat((x1, x2, x3, x4), dim=2) + return x + + def expand(self, A): + A_expanded = A.unsqueeze(0).expand(A.size(0), -1, -1) + + B = None + for index, i in enumerate(A_expanded): + rmv = torch.cat((i[:index], i[index + 1:])).unsqueeze(0) + if B is None: + B = rmv + else: + B = torch.cat((B, rmv), dim=0) + + return B + + +if __name__ == '__main__': + in_feature = torch.zeros((11, 3, 384, 384)).cuda() + gt_feature = torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.float).cuda() + gt_feature = gt_feature.reshape(-1, 1) + model = MoNet().cuda() + + (iq_res, _), (_, _) = model(in_feature, gt_feature) + + print(iq_res.shape) + # print(gt_res.shape) diff --git a/PromptIQA/run_promptIQA copy.py b/PromptIQA/run_promptIQA copy.py new file mode 100644 index 0000000000000000000000000000000000000000..82167217b029eabb1025780328ba3a9b79c87f0b --- /dev/null +++ b/PromptIQA/run_promptIQA copy.py @@ -0,0 +1,109 @@ +import os +import random +import torchvision +import cv2 +import torch +from models import monet as MoNet +import numpy as np +from utils.dataset.process import ToTensor, Normalize +from utils.toolkit import * +import warnings +warnings.filterwarnings('ignore') + +import sys +sys.path.append(os.path.dirname(__file__)) + +class PromptIQA(): + def __init__(self) -> None: + pass + +def load_image(img_path, size=224): + try: + d_img = cv2.imread(img_path, cv2.IMREAD_COLOR) + d_img = 
cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC) + d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB) + d_img = np.array(d_img).astype('float32') / 255 + d_img = np.transpose(d_img, (2, 0, 1)) + except: + print(img_path) + + return d_img + +def load_model(pkl_path): + + model = MoNet.MoNet() + dict_pkl = {} + # prompt_num = torch.load(pkl_path, map_location='cpu').get('prompt_num') + for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items(): + dict_pkl[key[7:]] = value + model.load_state_dict(dict_pkl) + print('Load Model From ', pkl_path) + + return model + +def get_an_img_score(img_path, target): + transform=torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()]) + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + target = np.insert(target, position_to_insert, values_to_insert) + + sample = load_image(img_path) + samples = {'img': sample, 'gt': target} + samples = transform(samples) + + return samples + +import random +if __name__ == '__main__': + pkl_path = "./checkpoints/best_model_five_22.pth.tar" + model = load_model(pkl_path).cuda() + model.eval() + + img_path = '/mnt/storage/PromptIQA_Demo/CSIQ/dst_src' + + img_tensor, gt_tensor = None, None + img_list = os.listdir(img_path) + random.shuffle(img_list) + for idx, img_name in enumerate(img_list): + if idx == 10: + break + + img_name = os.path.join(img_path, img_name) + score = np.array(idx / 10) + samples = get_an_img_score(img_name, score) + + if img_tensor is None: + img_tensor = samples['img'].unsqueeze(0) + gt_tensor = samples['gt'].type(torch.FloatTensor).unsqueeze(0) + else: + img_tensor = torch.cat((img_tensor, samples['img'].unsqueeze(0)), dim=0) + gt_tensor = torch.cat((gt_tensor, samples['gt'].type(torch.FloatTensor).unsqueeze(0)), dim=0) + + print(img_tensor.shape) + print(gt_tensor.shape) + print(gt_tensor) + + img = img_tensor.squeeze(0).cuda() + label = gt_tensor.squeeze(0).cuda() + reverse = False + if reverse == 2: + label = torch.rand_like(label[:, -1]).cuda() + print(label) + elif reverse == 3: + print('Total Random') + label = torch.rand_like(label[:, -1]).cuda() + img = torch.rand_like(img).cuda() + else: + label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda()) + print('input label: ', label) + model.forward_prompt(img, label.reshape(-1, 1), 'livec') + + img_name = '/mnt/storage/PromptIQA_Demo/CSIQ/src_imgs/1600.png' + score = np.array(random.random()) + samples = get_an_img_score(img_name, score) + + img = samples['img'].unsqueeze(0).cuda() + print(img.shape) + pred = model.inference(img, 'livec') + + print(pred) \ No newline at end of file diff --git a/PromptIQA/run_promptIQA.py b/PromptIQA/run_promptIQA.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea46d6e4624cf56a2150ed422e22549d94f407d --- /dev/null +++ b/PromptIQA/run_promptIQA.py @@ -0,0 +1,73 @@ +import os +import random +import torchvision +import cv2 +import torch +from PromptIQA.models import monet as MoNet +import numpy as np +from PromptIQA.utils.dataset.process import ToTensor, Normalize +from PromptIQA.utils.toolkit import * +import warnings +warnings.filterwarnings('ignore') + +import sys +sys.path.append(os.path.dirname(__file__)) + +def load_model(pkl_path): + model = MoNet.MoNet() + dict_pkl = {} + for key, value in torch.load(pkl_path, map_location='cpu')['state_dict'].items(): + dict_pkl[key[7:]] = value + model.load_state_dict(dict_pkl) + print('Load Model From ', pkl_path) + return model + +class PromptIQA(): + def __init__(self) -> 
None: + self.pkl_path = "./PromptIQA/checkpoints/best_model_five_22.pth.tar" + self.model = load_model(self.pkl_path).cuda() + self.model.eval() + + self.transform = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()]) + + def get_an_img_score(self, img_path, target=0): + def load_image(img_path, size=224): + if isinstance(img_path, str): + d_img = cv2.imread(img_path, cv2.IMREAD_COLOR) + else: + d_img = img_path + d_img = cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC) + d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB) + d_img = np.array(d_img).astype('float32') / 255 + d_img = np.transpose(d_img, (2, 0, 1)) + return d_img + + sample = load_image(img_path) + samples = {'img': sample, 'gt': target} + samples = self.transform(samples) + + return samples + + def run(self, ISPP_I, ISPP_S, image): + img_tensor, gt_tensor = None, None + + for isp_i, isp_s in zip(ISPP_I, ISPP_S): + score = np.array(isp_s) + samples = self.get_an_img_score(isp_i, score) + + if img_tensor is None: + img_tensor = samples['img'].unsqueeze(0) + gt_tensor = samples['gt'].type(torch.FloatTensor).unsqueeze(0) + else: + img_tensor = torch.cat((img_tensor, samples['img'].unsqueeze(0)), dim=0) + gt_tensor = torch.cat((gt_tensor, samples['gt'].type(torch.FloatTensor).unsqueeze(0)), dim=0) + + img = img_tensor.squeeze(0).cuda() + label = gt_tensor.squeeze(0).cuda() + self.model.forward_prompt(img, label.reshape(-1, 1), 'example') + + samples = self.get_an_img_score(image) + img = samples['img'].unsqueeze(0).cuda() + pred = self.model.inference(img, 'example') + + return round(pred.item(), 4) \ No newline at end of file diff --git a/PromptIQA/t.py b/PromptIQA/t.py new file mode 100644 index 0000000000000000000000000000000000000000..de91e410072455664c11503abc2c956e29336aa0 --- /dev/null +++ b/PromptIQA/t.py @@ -0,0 +1,2 @@ +a = "(1+1)**(2**2)" +print(eval(a)) \ No newline at end of file diff --git a/PromptIQA/test.py b/PromptIQA/test.py new file mode 100644 index 0000000000000000000000000000000000000000..4855e057ff53407e2567aff91b2296d43623985b --- /dev/null +++ b/PromptIQA/test.py @@ -0,0 +1,429 @@ +import sys + +from utils import log_writer + +import argparse +import builtins +import os +import random +import shutil +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +# from models import monet as MoNet +from torch.utils.data import ConcatDataset +from utils.dataset import data_loader + +from utils.toolkit import * + +loger_path = None + + +def init(config): + global loger_path + if config.dist_url == "env://" and config.world_size == -1: + config.world_size = int(os.environ["WORLD_SIZE"]) + + config.distributed = config.world_size > 1 or config.multiprocessing_distributed + + print("config.distributed", config.distributed) + + loger_path = os.path.join(config.save_path, "inference_log") + if not os.path.isdir(loger_path): + os.makedirs(loger_path) + + print("----------------------------------") + print( + "Begin Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) + ) + printArgs(config, loger_path) + # os.environ["CUDA_VISIBLE_DEVICES"] = '2,3,4,5,6,7' + os.environ["CUDA_VISIBLE_DEVICES"] = '0' + # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5' + # os.environ["CUDA_VISIBLE_DEVICES"] = '6,7' + # os.environ["CUDA_VISIBLE_DEVICES"] = '6' + # setup_seed(config.seed) + + +def main(config): + init(config) 
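+    # Launch path for inference: spawn one main_worker process per visible GPU when --multiprocessing-distributed is set; otherwise run a single main_worker on config.gpu.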
+ ngpus_per_node = torch.cuda.device_count() + if config.multiprocessing_distributed: + config.world_size = ngpus_per_node * config.world_size + + print(config.world_size, ngpus_per_node, ngpus_per_node) + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config)) + else: + # Simply call main_worker function + main_worker(config.gpu, ngpus_per_node, config) + + print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))) + + +@torch.no_grad() +def gather_together(data): # 封装成一个函数,,用于收集各个gpu上的data数据,并返回一个list + dist.barrier() + world_size = dist.get_world_size() + gather_data = [None for _ in range(world_size)] + dist.all_gather_object(gather_data, data) + return gather_data + +import importlib.util +def main_worker(gpu, ngpus_per_node, args): + models_path = os.path.join(args.save_path, "training_files", 'models', 'monet.py') + spec = importlib.util.spec_from_file_location("monet_module", models_path) + monet_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(monet_module) + MoNet = monet_module + + loger_path = os.path.join(args.save_path, "inference_log") + if gpu == 0: + sys.stdout = log_writer.Logger(os.path.join(loger_path, f"inference_log_{args.prompt_type}_{args.reverse}.log")) + args.gpu = gpu + + # suppress printing if not master + if args.multiprocessing_distributed and args.gpu != 0: + def print_pass(*args): + pass + + builtins.print = print_pass + + if args.gpu is not None: + print("Use GPU: {} for testing".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + # create model + model = MoNet.MoNet() + dict_pkl = {} + prompt_num = torch.load(args.pkl_path, map_location='cpu').get('prompt_num') + for key, value in torch.load(args.pkl_path, map_location='cpu')['state_dict'].items(): + dict_pkl[key[7:]] = value + model.load_state_dict(dict_pkl) + print('Load Model From ', args.pkl_path) + + if args.distributed: + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu] + ) + print("Model Distribute.") + else: + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model) + + if prompt_num is None: + prompt_num = args.batch_size - 1 + prompt_num = 10 + print('prompt_num', prompt_num) + + test_prompt_list, test_data_list = {}, [] + # fix_prompt = None + for dataset in args.dataset: + print('---Load ', dataset) + path, train_index, test_index = get_data(dataset=dataset, split_seed=args.seed) + # if dataset == 'spaq' and False: + if dataset == 'spaq': + for column in range(2, 8): + print('sapq column train', column) + test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, test_index, istrain=False, column=column) + test_data_list.append(test_dataset.get_samples()) + + train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, column=column) + test_prompt_list[dataset+f'_{column}'] = train_dataset.get_prompt(prompt_num, args.prompt_type) + else: + test_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, 
test_index, istrain=False, types=args.types) + test_data_list.append(test_dataset.get_samples()) + + train_dataset = data_loader.Data_Loader(args.batch_size, dataset, path, train_index, istrain=False, types=args.types) + test_prompt_list[dataset] = train_dataset.get_prompt(prompt_num, args.prompt_type) + print('args.prompt_type', args.prompt_type) + + combined_test_samples = ConcatDataset(test_data_list) + print("test_dataset", len(combined_test_samples)) + test_sampler = torch.utils.data.distributed.DistributedSampler(combined_test_samples) + + test_loader = torch.utils.data.DataLoader( + combined_test_samples, + batch_size=1, + shuffle=(test_sampler is None), + num_workers=args.workers, + sampler=test_sampler, + drop_last=False, + pin_memory=True, + ) + + if args.distributed: + test_sampler.set_epoch(0) + + for idxsa in range(1): + test_srocc, test_plcc, pred_scores, gt_scores, path = test( + test_loader, model, test_prompt_list, reverse=args.reverse + ) + print('gt_scores', len(pred_scores), len(gt_scores)) + print('Summary---') + + gt_scores = gather_together(gt_scores) # 进行汇总,得到一个list + pred_scores = gather_together(pred_scores) # 进行汇总,得到一个list + + gt_score_dict, pred_score_dict = {}, {} + for sublist in gt_scores: + for k, v in sublist.items(): + if k not in gt_score_dict: + gt_score_dict[k] = v + else: + gt_score_dict[k] = gt_score_dict[k] + v + + for sublist in pred_scores: + for k, v in sublist.items(): + if k not in pred_score_dict: + pred_score_dict[k] = v + else: + pred_score_dict[k] = pred_score_dict[k] + v + + gt_score_dict = dict(sorted(gt_score_dict.items())) + test_srocc, test_plcc = 0, 0 + for k, v in gt_score_dict.items(): + test_srocc_, test_plcc_ = cal_srocc_plcc(gt_score_dict[k], pred_score_dict[k]) + print('\t{} Test SROCC: {}, PLCC: {}'.format(k, round(test_srocc_, 4), round(test_plcc_, 4))) + # print('Pred: ', pred_score_dict[k][:10]) + # print('GT: ', gt_score_dict[k][:10]) + # print('-----') + + with open('{}_{}.csv'.format(idxsa, k), 'w') as f: + for i, j in zip(gt_score_dict[k], pred_score_dict[k]): + f.write('{},{}\n'.format(i, j)) + test_srocc += test_srocc_ + test_plcc += test_plcc_ + + +def test(test_loader, MoNet, promt_data_loader, reverse=False): + """Training""" + pred_scores = {} + gt_scores = {} + path = [] + + batch_time = AverageMeter("Time", ":6.3f") + srocc = AverageMeter("SROCC", ":6.2f") + plcc = AverageMeter("PLCC", ":6.2f") + progress = ProgressMeter( + len(test_loader), + [batch_time, srocc, plcc], + prefix="Testing ", + ) + + print('reverse ----', reverse) + MoNet.train(False) + with torch.no_grad(): + for index, (img_or, label_or, paths, dataset_type) in enumerate(test_loader): + # print(dataset_type) + t = time.time() + dataset_type = dataset_type[0] + + has_prompt = False + if hasattr(MoNet.module, 'check_prompt'): + has_prompt = MoNet.module.check_prompt(dataset_type) + + if not has_prompt: + print('Load Prompt For ', dataset_type) + prompt_dataset = promt_data_loader[dataset_type] + for img, label in prompt_dataset: + img = img.squeeze(0).cuda() + label = label.squeeze(0).cuda() + if reverse == 2: + # label = torch.tensor([random.random() for i in range(len(label[:, -1]))]).cuda() + # + label = torch.rand_like(label[:, -1]).cuda() + print(label) + elif reverse == 3: + print('Total Random') + label = torch.rand_like(label[:, -1]).cuda() + img = torch.rand_like(img).cuda() + else: + label = label[:, -1].cuda() if not reverse else (1 - label[:, -1].cuda()) + MoNet.module.forward_prompt(img, label.reshape(-1, 1), dataset_type) + + img = 
img_or.squeeze(0).cuda() + label = label_or.squeeze(0).cuda()[:, 2] + + # print(img.shape) + + pred = MoNet.module.inference(img, dataset_type) + + if dataset_type not in pred_scores: + pred_scores[dataset_type] = [] + + if dataset_type not in gt_scores: + gt_scores[dataset_type] = [] + + pred_scores[dataset_type] = pred_scores[dataset_type] + pred.cpu().tolist() + gt_scores[dataset_type] = gt_scores[dataset_type] + label.cpu().tolist() + path = path + list(paths) + + batch_time.update(time.time() - t) + + if index % 100 == 0: + for k, v in pred_scores.items(): + test_srocc, test_plcc = cal_srocc_plcc(pred_scores[k], gt_scores[k]) + # print('\t{}, SROCC: {}, PLCC: {}'.format(k, round(test_srocc, 4), round(test_plcc, 4))) + srocc.update(test_srocc) + plcc.update(test_plcc) + + progress.display(index) + + MoNet.module.clear() + # MoNet.train(True) + return 'test_srocc', 'test_plcc', pred_scores, gt_scores, path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--seed", + dest="seed", + type=int, + default=570908, + help="Random seeds for result reproduction.", + ) + + parser.add_argument( + "--mal_num", + dest="mal_num", + type=int, + default=2, + help="The number of the MAL modules.", + ) + + # data related + parser.add_argument( + "--dataset", + dest="dataset", + nargs='+', default=None, + help="Support datasets: livec|koniq10k|bid|spaq", + ) + + # training related + parser.add_argument( + "--queue_ratio", + dest="queue_ratio", + type=float, + default=0.6, + help="Ratio of queue length used in GC loss to training set length.", + ) + + parser.add_argument( + "--loss", + dest="loss", + type=str, + default="MSE", + help="Loss function to use. Support losses: GC|MAE|MSE.", + ) + + parser.add_argument( + "--lr", dest="lr", type=float, default=1e-5, help="Learning rate" + ) + + parser.add_argument( + "--weight_decay", + dest="weight_decay", + type=float, + default=1e-5, + help="Weight decay", + ) + parser.add_argument( + "--batch_size", dest="batch_size", type=int, default=11, help="Batch size" + ) + parser.add_argument( + "--epochs", dest="epochs", type=int, default=50, help="Epochs for training" + ) + parser.add_argument( + "--T_max", + dest="T_max", + type=int, + default=50, + help="Hyper-parameter for CosineAnnealingLR", + ) + parser.add_argument( + "--eta_min", + dest="eta_min", + type=int, + default=0, + help="Hyper-parameter for CosineAnnealingLR", + ) + + parser.add_argument( + "-j", + "--workers", + default=32, + type=int, + metavar="N", + help="number of data loading workers (default: 32)", + ) + + # result related + parser.add_argument( + "--save_path", + dest="save_path", + type=str, + default="./save_logs/Matrix_Comparation_Koniq_bs_25", + help="The path where the model and logs will be saved.", + ) + + parser.add_argument( + "--world-size", + default=-1, + type=int, + help="number of nodes for distributed training", + ) + parser.add_argument( + "--rank", default=-1, type=int, help="node rank for distributed training" + ) + parser.add_argument( + "--dist-url", + default="tcp://224.66.41.62:23456", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--multiprocessing-distributed", + action="store_true", + help="Use multi-processing distributed training to launch " + "N processes per node, which has N GPUs. 
This is the " + "fastest way to use PyTorch for either single node or " + "multi node data parallel training", + ) + + parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.") + parser.add_argument("--pkl_path", required=True, type=str) + parser.add_argument("--prompt_type", required=True, type=str) + parser.add_argument("--reverse", required=True, type=int) + parser.add_argument("--types", default='SSIM', type=str) + + config = parser.parse_args() + + config.save_path = os.path.dirname(config.pkl_path) + + main(config) diff --git a/PromptIQA/test.sh b/PromptIQA/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..7023e9e85959db0855dfe679d24751d3b7258785 --- /dev/null +++ b/PromptIQA/test.sh @@ -0,0 +1,9 @@ +# python test.py --dist-url 'tcp://localhost:10055' --dataset spaq tid2013 livec bid spaq flive --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 0 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar +# python test.py --dist-url 'tcp://localhost:12755' --dataset csiq --batch_size 50 --prompt_type fix --multiprocessing-distributed --world-size 1 --rank 0 --reverse 3 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Training_log/FourTask/N_F_A_U_RandomScale_MAE_loaderDebug_Rate95/best_model_five_52.pth.tar +python test.py --dist-url 'tcp://localhost:12755' --dataset livec bid csiq --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar +# reverse 0 no, 1 yes, 2 random + +python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2026 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/PromptIQA_2026/best_model_five_92.pth.tar --types 'SSIM' + + +CUDA_VISIBLE_DEVICES="0" python test.py --dist-url 'tcp://localhost:12755' --dataset tid2013_other --batch_size 50 --prompt_type random --multiprocessing-distributed --world-size 1 --rank 0 --reverse 2 --seed 2024 --pkl_path /disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Publication/PromptIQA_2024_WO_Norm_Score/best_model_five_22.pth.tar --types 'SSIM' \ No newline at end of file diff --git a/PromptIQA/utils/__pycache__/iqa_solver.cpython-311.pyc b/PromptIQA/utils/__pycache__/iqa_solver.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..978019ba7e373a9875194d7f6e750145f3851f52 Binary files /dev/null and b/PromptIQA/utils/__pycache__/iqa_solver.cpython-311.pyc differ diff --git a/PromptIQA/utils/__pycache__/iqa_solver.cpython-38.pyc b/PromptIQA/utils/__pycache__/iqa_solver.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f51710e13bc6f6f5009bee6c9904398696d7febc Binary files /dev/null and b/PromptIQA/utils/__pycache__/iqa_solver.cpython-38.pyc differ diff --git a/PromptIQA/utils/__pycache__/log_writer.cpython-311.pyc b/PromptIQA/utils/__pycache__/log_writer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..804b6a1365a03d2cde1ffa5515b8578a884fabea Binary files /dev/null and b/PromptIQA/utils/__pycache__/log_writer.cpython-311.pyc differ diff --git a/PromptIQA/utils/__pycache__/log_writer.cpython-312.pyc b/PromptIQA/utils/__pycache__/log_writer.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3539fc36f81cf12ce875524c1ffdcfb4cc3203b7 Binary files /dev/null and b/PromptIQA/utils/__pycache__/log_writer.cpython-312.pyc differ diff --git a/PromptIQA/utils/__pycache__/log_writer.cpython-37.pyc b/PromptIQA/utils/__pycache__/log_writer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38dbd0d62c22065c45898878b9cfed1bdbdd74cb Binary files /dev/null and b/PromptIQA/utils/__pycache__/log_writer.cpython-37.pyc differ diff --git a/PromptIQA/utils/__pycache__/log_writer.cpython-38.pyc b/PromptIQA/utils/__pycache__/log_writer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d34cdf0be326ef885d4d08e9fff62b67cb23919 Binary files /dev/null and b/PromptIQA/utils/__pycache__/log_writer.cpython-38.pyc differ diff --git a/PromptIQA/utils/__pycache__/toolkit.cpython-37.pyc b/PromptIQA/utils/__pycache__/toolkit.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3326ce3aa7be2705da15524ed10b93d0bd25444c Binary files /dev/null and b/PromptIQA/utils/__pycache__/toolkit.cpython-37.pyc differ diff --git a/PromptIQA/utils/__pycache__/toolkit.cpython-38.pyc b/PromptIQA/utils/__pycache__/toolkit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdd1c41712190409e215bdf36173b0ec62fee817 Binary files /dev/null and b/PromptIQA/utils/__pycache__/toolkit.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-311.pyc b/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1de6aef8c63ae143f31d4108568549f1e3e9c1c Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-311.pyc differ diff --git a/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-37.pyc b/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dc2ca9f123923de7a1bb247d365b55ef5894c85 Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-37.pyc differ diff --git a/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-38.pyc b/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7656416e76bbb24547188d61066adaf0e100cb73 Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/data_loader.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset/__pycache__/folders.cpython-37.pyc b/PromptIQA/utils/dataset/__pycache__/folders.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe2accd9b43671ddb269f579c1c9e14d4a845665 Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/folders.cpython-37.pyc differ diff --git a/PromptIQA/utils/dataset/__pycache__/folders.cpython-38.pyc b/PromptIQA/utils/dataset/__pycache__/folders.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92b9bc2464cfd83ccc64116fb09e1253d6bf6f35 Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/folders.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset/__pycache__/process.cpython-37.pyc b/PromptIQA/utils/dataset/__pycache__/process.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f0ca8e76f71989af0d0e02a39316c04fc37713 Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/process.cpython-37.pyc differ diff --git 
a/PromptIQA/utils/dataset/__pycache__/process.cpython-38.pyc b/PromptIQA/utils/dataset/__pycache__/process.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0df91d6233cf8a7fa9b699164c729d8ea808ffb Binary files /dev/null and b/PromptIQA/utils/dataset/__pycache__/process.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset/data_loader.py b/PromptIQA/utils/dataset/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..70f395c6d29a3451ee604d9dabb12401283cf61c --- /dev/null +++ b/PromptIQA/utils/dataset/data_loader.py @@ -0,0 +1,76 @@ +import torch +import torchvision + +from utils.dataset import folders +from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip + +class Data_Loader(): + """Dataset class for IQA databases""" + + def __init__(self, batch_size, dataset, path, img_indx, istrain=True, column=2, dist_type=None, types='SSIM'): + + self.batch_size = batch_size + self.istrain = istrain + # print('1. column --------', column, dataset) + + if istrain: + transforms=torchvision.transforms.Compose([Normalize(0.5, 0.5), RandHorizontalFlip(prob_aug=0.5), ToTensor()]) + else: + transforms=torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()]) + + if dataset == 'livec': + self.data = folders.LIVEC(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'koniq10k': + self.data = folders.Koniq10k(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'uhdiqa': + self.data = folders.uhdiqa(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'bid': + self.data = folders.BID(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'spaq': + self.data = folders.SPAQ(root=path, index=img_indx, transform=transforms, column=column, batch_size=batch_size, istrain=istrain) + elif dataset == 'flive': + self.data = folders.FLIVE(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'csiq': + self.data = folders.CSIQ(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain, dist_type=dist_type) + elif dataset == 'csiq_other': + self.data = folders.csiq_other(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain, dist_type=dist_type, types=types) + elif dataset == 'live': + self.data = folders.LIVEFolder(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'tid2013': + self.data = folders.TID2013Folder(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'tid2013_other': + self.data = folders.TID2013Folder_Other_Type(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain, types=types) + elif dataset == 'kadid': + self.data = folders.KADID(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'kadid_other': + self.data = folders.KADID_Other(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain, types=types) + elif dataset == 'PIQ2023': + self.data = folders.PIQ2023(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'GFIQA_20k': + self.data = folders.GFIQA_20k(root=path, index=img_indx, transform=transforms, 
batch_size=batch_size, istrain=istrain) + elif dataset == 'AGIQA_3k': + self.data = folders.AGIQA_3k(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'AIGCIQA2023': + self.data = folders.AIGCIQA2023(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'UWIQA': + self.data = folders.UWIQA(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'CGIQA6k': + self.data = folders.CGIQA6k(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'AIGCIQA3W': + self.data = folders.AIGCIQA3W(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + elif dataset == 'tid2013_multi_dim': + self.data = folders.TID_Multi_Dim(root=path, index=img_indx, transform=transforms, batch_size=batch_size, istrain=istrain) + else: + raise Exception("Only support livec, koniq10k, bid, spaq.") + + def get_data(self): + dataloader = torch.utils.data.DataLoader(self.data, batch_size=self.batch_size, shuffle=self.istrain, num_workers=16, drop_last=self.istrain) + return dataloader + + def get_samples(self): + return self.data + + def get_prompt(self, n=5, sample_type='fix'): + # print('Get {} images for prompting.'.format(n)) + prompt_data = self.data.get_promt(n=n, sample_type=sample_type) + return torch.utils.data.DataLoader(prompt_data, batch_size=prompt_data.__len__(), shuffle=False) \ No newline at end of file diff --git a/PromptIQA/utils/dataset/dataset_info.json b/PromptIQA/utils/dataset/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..a5f4272df78765a58bee4ac77fdcf2f64228eb67 --- /dev/null +++ b/PromptIQA/utils/dataset/dataset_info.json @@ -0,0 +1,23 @@ +{ + "livec": ["/disk1/Datasets/IQA/ChallengeDB_release", 1162], + "koniq10k": ["/disk1/Datasets/IQA/koniq-10k", 10073], + "flive": ["/disk1/Datasets/IQA/FLIVE_Database/database", 39810], + "bid": ["/disk1/Datasets/IQA/BID/ImageDatabase", 586], + "spaq": ["/disk1/Datasets/IQA/SPAQ", 11125], + "live": ["/disk1/Datasets/IQA/LIVE", 29], + "tid2013": ["/disk1/Datasets/IQA/TID2013", 25], + "tid2013_other": ["/disk1/Datasets/IQA/TID2013", 25], + "csiq": ["/disk1/Datasets/IQA/CSIQ", 30], + "csiq_other": ["/disk1/Datasets/IQA/CSIQ", 30], + "kadid": ["/disk1/Datasets/IQA/kadid10k", 81], + "kadid_other": ["/disk1/Datasets/IQA/kadid10k", 81], + "PIQ2023": ["/disk1/Datasets/IQA/2_portrait_image/PIQ2023", 81], + "GFIQA_20k": ["/disk1/Datasets/IQA/2_portrait_image/GFIQA-20k", 19998], + "AGIQA_3k": ["/disk1/Datasets/IQA/3_AI_Generate/AGIQA_3k", 2982], + "UWIQA": ["/disk1/Datasets/IQA/4_underwater/UWIQA", 890], + "AIGCIQA2023": ["/disk1/Datasets/IQA/3_AI_Generate/AIGCIQA2023", 2400], + "tid2013_multi_dim": ["/disk1/Datasets/IQA/TID2013", 25], + "AIGCIQA3W": ["/disk1/Datasets/IQA/3_AI_Generate/AIGCQA-30K-Image", 14000], + "CGIQA6k": ["/disk1/Datasets/IQA/6_Screen/CGIQA6k", 6000], + "uhdiqa": ["/disk1/chenzewen/Competition/ECCV_2024_IQA/uhd_iqa", 4269] +} \ No newline at end of file diff --git a/PromptIQA/utils/dataset/folders.py b/PromptIQA/utils/dataset/folders.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd833bdcf4eb19a34c7857bdbc637cd5d217e26 --- /dev/null +++ b/PromptIQA/utils/dataset/folders.py @@ -0,0 +1,1259 @@ +import torch.utils.data as data +import torch + +import os +import scipy.io +import numpy as np +import csv +from openpyxl import load_workbook +import cv2 
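+# NOTE: each dataset class below pairs image paths with (roughly) normalized quality scores, packs them into fixed-size groups via split_array(), and exposes get_promt(n, sample_type) to draw n image-score pairs that serve as the prompt set (ISPP) at inference time.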
+import random +import torch +import torchvision + +from utils.dataset import folders +from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip + +def get_prompt(samples_p, gt_p, transform, n, length, sample_type='fix'): + # transform = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()]) + combined_data = list(zip(samples_p, gt_p)) + + if sample_type == 'fix': + combined_data.sort(key=lambda x: x[1]) + elif sample_type == 'random': + random.seed() + random.shuffle(combined_data) + else: + raise NotImplementedError('Only Support fix | random') + + length = len(samples_p) + sample, gt = [], [] + if n == 2: + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + return prompt_data(sample, gt, transform) + data_len = (length - 2) // (n - 2) + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + for i in range(data_len, length, data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + + assert len(sample) == n + return prompt_data(sample, gt, transform) + + +class prompt_data(data.Dataset): + def __init__(self, sample, gt, transform, div=1): + self.samples, self.gt = [sample], [gt] + self.transform = transform + self.div = div + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, self.div) + + return img_tensor, gt_tensor + + def __len__(self): + length = len(self.samples) + return length + +def reshuffle(sample:list, gt:list): + combine = list(zip(sample.copy(), gt.copy())) + random.shuffle(combine) + + sample_new, gt_new = [], [] + for i, j in combine: + sample_new.append(i) + gt_new.append(j) + + return sample_new, gt_new + +class LIVEC(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgpath = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat')) + imgpath = imgpath['AllImages_release'] + imgpath = imgpath[7:1169] + mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat')) + labels = mos['AllMOS_release'].astype(np.float32) + labels = labels[0][7:1169] + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, 'Images', imgpath[item][0][0])) + gt.append(labels[item]) + # gt = normalization(gt) + gt = list((np.array(gt) - 1) / 100) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + return img_tensor, gt_tensor, self.samples[index], 'livec' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class AIGCIQA3W(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, 
istrain=True): + + imgname = [] + mos_all = [] + + xls_file = os.path.join(root, 'info_train.xlsx') + workbook = load_workbook(xls_file) + booksheet = workbook.active + rows = booksheet.rows + count = 1 + prompt = [] + for row in rows: + count += 1 + img_name = booksheet.cell(row=count, column=1).value + imgname.append(img_name) + mos = booksheet.cell(row=count, column=3).value + mos = np.array(mos) + mos = mos.astype(np.float32) + mos_all.append(mos) + prompt.append(str(booksheet.cell(row=count, column=2).value)) + if count == 14002: + break + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, 'train', imgname[item])) + gt.append(mos_all[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 1) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + return img_tensor, gt_tensor, "", 'AIGCIQA3W' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class AIGCIQA2023(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + mos = scipy.io.loadmat(os.path.join(root, 'DATA', 'MOS', 'mosz1.mat')) + labels = mos['MOSz'].astype(np.float32) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, 'Image', 'allimg', f'{item}.png')) + gt.append(labels[item][0]) + # gt = normalization(gt) + gt = list(np.array(gt) / 100) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'AIGCIQA2023' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class Koniq10k(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgname = [] + mos_all = [] + csv_file = os.path.join(root, 'koniq10k_distributions_sets.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['image_name']) + mos = 
np.array(float(row['MOS'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, '1024x768', imgname[item])) + gt.append(mos_all[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 100) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'koniq10k' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class uhdiqa(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgname = [] + mos_all = [] + csv_file = os.path.join(root, 'uhd-iqa-training-metadata.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['image_name']) + mos = np.array(float(row['quality_mos'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, 'challenge/training', imgname[item])) + gt.append(mos_all[item]) + # gt = normalization(gt) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'uhdiqa' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class CGIQA6k(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgname = [] + mos_all = [] + csv_file = os.path.join(root, 'mos.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['Image']) + mos = np.array(float(row['MOS'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, 'database', imgname[item])) + gt.append(mos_all[item]) + gt = normalization(gt) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, 
batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'CGIQA6k' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class CSIQ(data.Dataset): + def __init__(self, root, index, transform, patch_num=1, batch_size=11, istrain=True, dist_type=None): + + refpath = os.path.join(root, 'src_imgs') + refname = getFileName(refpath, '.png') + txtpath = os.path.join(root, 'csiq_label.txt') + fh = open(txtpath, 'r') + imgnames = [] + target = [] + refnames_all = [] + for line in fh: + line = line.split('\n') + words = line[0].split() + if dist_type is None: + imgnames.append((words[0])) + target.append(words[1]) + ref_temp = words[0].split(".") + refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) + else: + if words[0].split('.')[1] == dist_type: + imgnames.append((words[0])) + target.append(words[1]) + ref_temp = words[0].split(".") + refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) + + labels = np.array(target).astype(np.float32) + refnames_all = np.array(refnames_all) + + sample = [] + gt = [] + + for i, item in enumerate(index): + train_sel = (refname[index[i]] == refnames_all) + train_sel = np.where(train_sel == True) + train_sel = train_sel[0].tolist() + for j, item in enumerate(train_sel): + for aug in range(patch_num): + sample.append(os.path.join(root, 'dst_imgs_all', imgnames[item])) + gt.append(labels[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 1) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + self.dist_type = dist_type + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + if self.dist_type is None: + return img_tensor, gt_tensor, self.samples[index], 'csiq' + else: + return img_tensor, gt_tensor, "", 'csiq_' + self.dist_type + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class csiq_other(data.Dataset): + def __init__(self, root, index, transform, patch_num=1, batch_size=11, istrain=True, dist_type=None, types='SSIM'): + + refpath = os.path.join(root, 'src_imgs') + refname = 
getFileName(refpath, '.png') + txtpath = os.path.join(root, 'csiq_label.txt') + fh = open(txtpath, 'r') + imgnames = [] + target = [] + refnames_all = [] + + idx = { + 'SSIM': 2, + 'FSIM': 3, + 'LPIPS': 4 + } + print('Get type ', types) + + for line in fh: + line = line.split('\n') + words = line[0].split() + if dist_type is None: + imgnames.append((words[0])) + target.append(words[idx[types]]) + ref_temp = words[0].split(".") + refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) + else: + if words[0].split('.')[1] == dist_type: + imgnames.append((words[0])) + target.append(words[idx[types]]) + ref_temp = words[0].split(".") + refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) + + labels = np.array(target).astype(np.float32) + refnames_all = np.array(refnames_all) + + sample = [] + gt = [] + + for i, item in enumerate(index): + train_sel = (refname[index[i]] == refnames_all) + train_sel = np.where(train_sel == True) + train_sel = train_sel[0].tolist() + for j, item in enumerate(train_sel): + for aug in range(patch_num): + sample.append(os.path.join(root, 'dst_imgs_all', imgnames[item])) + gt.append(labels[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 1) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + self.dist_type = dist_type + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + if self.dist_type is None: + return img_tensor, gt_tensor, "", 'csiq_other' + else: + return img_tensor, gt_tensor, "", 'csiq_' + self.dist_type + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class TID2013Folder_Other_Type(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True, types='SSIM'): + print('TID Type: ', types) + print('index', index) + imgpath = os.path.join(root, 'distorted_images') + csv_file = os.path.join(root, 'resNEWTest.csv') + + sample, gt = [], [] + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + if (int(row['fileName'].split('_')[0][1:]) - 1) in index: + sample.append(os.path.join(imgpath, row['fileName'])) + mos = np.array(float(row[types])).astype(np.float32) + gt.append(mos) + # gt = normalization(gt) + gt = list(np.array(gt) / 9) + + self.samples_p, self.gt_p = sample, gt + print('gt', len(gt)) + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'tid2013_other' + + def __len__(self): + length = len(self.samples) + return length + + 
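# get_promt() draws n (image, score) pairs through the module-level get_prompt() helper: sample_type 'fix' sorts by score and keeps an evenly spaced subset (including the lowest- and highest-scored items), while 'random' shuffles the pool first.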
def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class TID2013Folder(data.Dataset): + def __init__(self, root, index, transform, patch_num=1, batch_size=11, istrain=False): + refpath = os.path.join(root, 'reference_images') + refname = getTIDFileName(refpath, '.bmp.BMP') + txtpath = os.path.join(root, 'mos_with_names.txt') + fh = open(txtpath, 'r') + imgnames = [] + target = [] + refnames_all = [] + for line in fh: + line = line.split('\n') + words = line[0].split() + imgnames.append((words[1])) + target.append(words[0]) + ref_temp = words[1].split("_") + refnames_all.append(ref_temp[0][1:]) + labels = np.array(target).astype(np.float32) + refnames_all = np.array(refnames_all) + + sample = [] + gt = [] + for i, item in enumerate(index): + train_sel = (refname[index[i]] == refnames_all) + train_sel = np.where(train_sel == True) + train_sel = train_sel[0].tolist() + for j, item in enumerate(train_sel): + for aug in range(patch_num): + sample.append(os.path.join(root, 'distorted_images', imgnames[item])) + gt.append(labels[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 9) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'tid2013' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class KADID(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgpath = os.path.join(root, 'images') + + csv_file = os.path.join(root, 'dmos.csv') + + sample, gt = [], [] + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + if (int(row['dist_img'].split('_')[0].replace('I', '')) - 1) in index: + sample.append(os.path.join(imgpath, row['dist_img'])) + mos = np.array(float(row['dmos'])).astype(np.float32) + gt.append(mos) + # gt = normalization(gt) + gt = list((np.array(gt) - 1) / 5) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, 
"", 'kadid' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class KADID_Other(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True, types='SSIM'): + imgpath = os.path.join(root, 'images') + csv_file = os.path.join(root, 'dmos.csv') + + print("Get Type: ", types) + + sample, gt = [], [] + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + if (int(row['dist_img'].split('_')[0].replace('I', '')) - 1) in index: + sample.append(os.path.join(imgpath, row['dist_img'])) + mos = np.array(float(row[types])).astype(np.float32) + gt.append(mos) + # gt = normalization(gt) + gt = list((np.array(gt) - 1) / 5) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'kadid_other' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class TID_Multi_Dim(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True, types='LPIPS'): + print('TID Type: ', types) + imgpath = os.path.join(root, 'distorted_images') + + csv_file = os.path.join(root, 'resNEWTest.csv') + + sample, gt = [], [] + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + if (int(row['fileName'].split('_')[0][1:]) - 1) in index: + sample.append(os.path.join(imgpath, row['fileName'])) + mos = np.array(float(row[types])).astype(np.float32) + gt.append(mos) + # gt = normalization(gt) + gt = list((np.array(gt) - 1) / 5) + + self.samples_p, self.gt_p = sample, gt + print('gt', len(gt)) + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'tid2013_multi_dim' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class 
LIVEFolder(data.Dataset): + + def __init__(self, root, index, transform, batch_size=11, istrain=False): + + refpath = os.path.join(root, 'refimgs') + refname = getFileName(refpath, '.bmp') + + jp2kroot = os.path.join(root, 'jp2k') + jp2kname = self.getDistortionTypeFileName(jp2kroot, 227) + + jpegroot = os.path.join(root, 'jpeg') + jpegname = self.getDistortionTypeFileName(jpegroot, 233) + + wnroot = os.path.join(root, 'wn') + wnname = self.getDistortionTypeFileName(wnroot, 174) + + gblurroot = os.path.join(root, 'gblur') + gblurname = self.getDistortionTypeFileName(gblurroot, 174) + + fastfadingroot = os.path.join(root, 'fastfading') + fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174) + + imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname + + dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat')) + labels = dmos['dmos'].astype(np.float32) + + orgs = dmos['orgs'] + refnames_all = scipy.io.loadmat(os.path.join(root, 'refnames_all.mat')) + refnames_all = refnames_all['refnames_all'] + + sample = [] + gt = [] + + for i in range(0, len(index)): + train_sel = (refname[index[i]] == refnames_all) + train_sel = train_sel * ~orgs.astype(np.bool_) + train_sel = np.where(train_sel == True) + train_sel = train_sel[1].tolist() + for j, item in enumerate(train_sel): + sample.append(imgpath[item]) + gt.append(labels[0][item]) + + # print(self.imgpath[item]) + # gt = normalization(gt) + gt =list((np.array(gt) - 1) / 100) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'live' + + def __len__(self): + length = len(self.samples) + return length + + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + + def getDistortionTypeFileName(self, path, num): + filename = [] + index = 1 + for i in range(0, num): + name = '%s%s%s' % ('img', str(index), '.bmp') + filename.append(os.path.join(path, name)) + index = index + 1 + return filename + + +class FLIVE(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=False): + imgname = [] + mos_all = [] + csv_file = os.path.join(root, 'labels_image.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['name']) + mos = np.array(float(row['mos'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, imgname[item])) + gt.append(mos_all[item]) + gt = normalization(gt) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size 
= batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'flive' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class SPAQ(data.Dataset): + def __init__(self, root, index, transform, column=6, batch_size=11, istrain=False): + sample = [] + gt = [] + + xls_file = os.path.join(root, 'Annotations', 'MOS_and_Image_attribute_scores.xlsx') + workbook = load_workbook(xls_file) + booksheet = workbook.active + rows = booksheet.rows + + for count, row in enumerate(rows, 2): + if count - 2 in index: + sample.append(os.path.join(root, 'img_resize', booksheet.cell(row=count, column=1).value)) + mos = booksheet.cell(row=count, column=column).value + mos = np.array(mos) + mos = mos.astype(np.float32) + gt.append(mos) + if count == 11126: + break + # gt = normalization(gt) + gt = list(np.array(gt) / 100) + self.samples_p, self.gt_p = sample, gt + self.column = column + # print('get column', self.column) + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'spaq_{}'.format(self.column) + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class UWIQA(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=False): + sample = [] + gt = [] + + xls_file = os.path.join(root, 'IQA_Value.xlsx') + workbook = load_workbook(xls_file) + booksheet = workbook.active + rows = booksheet.rows + + for count, row in enumerate(rows, 2): + if count - 2 in index: + sample.append(os.path.join(root, 'img', '{}.png'.format(str(booksheet.cell(row=count, column=1).value).zfill(4)))) + mos = booksheet.cell(row=count, column=2).value + mos = np.array(mos) + mos = mos.astype(np.float32) + gt.append(mos) + + # gt = normalization(gt) + gt = list(np.array(gt) / 1) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, 
self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, self.samples[index], 'UWIQA' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +def split_array(arr, m): + return [arr[i:i + m] for i in range(0, len(arr), m)] + + +class BID(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=False): + + imgname = [] + mos_all = [] + + xls_file = os.path.join(root, 'DatabaseGrades.xlsx') + workbook = load_workbook(xls_file) + booksheet = workbook.active + rows = booksheet.rows + count = 1 + for row in rows: + count += 1 + img_num = booksheet.cell(row=count, column=1).value + img_name = "DatabaseImage%04d.JPG" % (img_num) + imgname.append(img_name) + mos = booksheet.cell(row=count, column=2).value + mos = np.array(mos) + mos = mos.astype(np.float32) + mos_all.append(mos) + if count == 587: + break + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, imgname[item])) + gt.append(mos_all[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 9) + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'bid' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +class PIQ2023(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + # index is not used. 
+ imgname = [] + mos_all = [] + img_path = os.path.join(root, 'img') + if istrain: + csv_file = os.path.join(root, 'Test split', 'Device Split', 'DeviceSplit_Train_Scores_Exposure.csv') + else: + csv_file = os.path.join(root, 'Test split', 'Device Split', 'DeviceSplit_Test_Scores_Exposure.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['IMAGE PATH'].replace('\\', '/')) + mos = np.array(float(row['JOD'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(range(len(imgname))): + sample.append(os.path.join(img_path, imgname[item])) + gt.append(mos_all[item]) + gt = normalization(gt) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'PIQ2023' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class GFIQA_20k(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgname = [] + mos_all = [] + img_path = os.path.join(root, 'image') + csv_file = os.path.join(root, 'mos.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['img_name']) + mos = np.array(float(row['mos'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(img_path, imgname[item])) + gt.append(mos_all[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 1) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "", 'GFIQA_20k' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + +class AGIQA_3k(data.Dataset): + def __init__(self, root, index, transform, batch_size=11, istrain=True): + imgname = [] + mos_all = [] + img_path = os.path.join(root, 'img') + csv_file = os.path.join(root, 'data.csv') + with 
open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['name']) + mos = np.array(float(row['mos_quality'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(img_path, imgname[item])) + gt.append(mos_all[item]) + # gt = normalization(gt) + gt = list(np.array(gt) / 5) + + self.samples_p, self.gt_p = sample, gt + + self.samples, self.gt = split_array(sample, batch_size), split_array(gt, batch_size) + if len(self.samples[-1]) != batch_size and istrain: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + self.transform = transform + self.batch_size = batch_size + + def reshuffle(self): + shuffle_sample, shuffle_gt = reshuffle(self.samples_p, self.gt_p) + self.samples, self.gt = split_array(shuffle_sample, self.batch_size), split_array(shuffle_gt, self.batch_size) + if len(self.samples[-1]) != self.batch_size: + self.samples = self.samples[:-1] + self.gt = self.gt[:-1] + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, self.samples[index], 'AGIQA_3k' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10, sample_type='fix'): + return get_prompt(self.samples_p, self.gt_p, self.transform, n, self.__len__(), sample_type=sample_type) + + +def get_item(samples, gt, index, transform, div=1): + div = 1 + path_list, target_list = samples[index], gt[index] + img_tensor, gt_tensor = None, None + for path, target in zip(path_list, target_list): + target = [target / div] + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + target = np.insert(target, position_to_insert, values_to_insert) + + sample = load_image(path) + samples = {'img': sample, 'gt': target} + samples = transform(samples) + + if img_tensor is None: + img_tensor = samples['img'].unsqueeze(0) + gt_tensor = samples['gt'].type(torch.FloatTensor).unsqueeze(0) + else: + img_tensor = torch.cat((img_tensor, samples['img'].unsqueeze(0)), dim=0) + gt_tensor = torch.cat((gt_tensor, samples['gt'].type(torch.FloatTensor).unsqueeze(0)), dim=0) + + return img_tensor, gt_tensor + + +def getFileName(path, suffix): + filename = [] + f_list = os.listdir(path) + for i in f_list: + if os.path.splitext(i)[1] == suffix: + filename.append(i) + return filename + + +def load_image(img_path, size=224): + try: + d_img = cv2.imread(img_path, cv2.IMREAD_COLOR) + d_img = cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC) + d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB) + d_img = np.array(d_img).astype('float32') / 255 + d_img = np.transpose(d_img, (2, 0, 1)) + except: + print(img_path) + + return d_img + + +def normalization(data): + data = np.array(data) + range = np.max(data) - np.min(data) + data = (data - np.min(data)) / range + data = list(data.astype('float').reshape(-1, 1)) + + return data + + +def getTIDFileName(path, suffix): + filename = [] + f_list = os.listdir(path) + for i in f_list: + if suffix.find(os.path.splitext(i)[1]) != -1: + filename.append(i[1:3]) + return filename diff --git a/PromptIQA/utils/dataset/process.py b/PromptIQA/utils/dataset/process.py new file mode 100644 index 0000000000000000000000000000000000000000..de6bc4a066b578783c9d4da818a7c8e30a2c7ba8 --- /dev/null +++ b/PromptIQA/utils/dataset/process.py @@ -0,0 +1,104 @@ +import torch +import numpy as np + + +class Normalize(object): + def __init__(self, mean, var): + self.mean = mean + self.var 
= var + + def __call__(self, sample): + if isinstance(sample, dict): + img = sample['img'] + gt = sample['gt'] + img = (img - self.mean) / self.var + sample = {'img': img, 'gt': gt} + else: + sample = (sample - self.mean) / self.var + + return sample + +class TwentyCrop(object): + def __init__(self, crop_height=224, crop_width=224): + self.crop_height = crop_height + self.crop_width = crop_width + + def __call__(self, sample): + n, m = sample.shape[1], sample.shape[2] # image height and width + h_stride = (n - self.crop_height) // 4 # vertical stride + w_stride = (m - self.crop_width) // 4 # horizontal stride + + crops = [] + for h_step in range(5): + for w_step in range(4): + # starting point of each crop + h_start = h_step * h_stride + w_start = w_step * w_stride + # crop the image + crop = sample[:, h_start:h_start+self.crop_height, w_start:w_start+self.crop_width] + crops.append(crop) + + # stack the crop list into a numpy array + crops = np.stack(crops) # this yields an array of shape [20, 3, crop_height, crop_width] + return crops + +class FiveCrop(object): + def __init__(self, size=224): + self.size = size # crop size + + def __call__(self, sample): + # make sure the input sample has the expected format + if isinstance(sample, np.ndarray) and sample.shape[0] == 3: + c, h, w = sample.shape + crop_h, crop_w = self.size, self.size + + # compute the crop corners + tl = sample[:, 0:crop_h, 0:crop_w] # top-left + tr = sample[:, 0:crop_h, w - crop_w:] # top-right + bl = sample[:, h - crop_h:, 0:crop_w] # bottom-left + br = sample[:, h - crop_h:, w - crop_w:] # bottom-right + center = sample[:, h//2 - crop_h//2:h//2 + crop_h//2, w//2 - crop_w//2:w//2 + crop_w//2] # center + + # stack the five crops into one numpy array + crops = np.stack([tl, tr, bl, br, center]) + return crops + else: + raise ValueError("Input sample does not have the expected format or size.") + +class RandHorizontalFlip(object): + def __init__(self, prob_aug): + self.prob_aug = prob_aug + + def __call__(self, sample): + p_aug = np.array([self.prob_aug, 1 - self.prob_aug]) + prob_lr = np.random.choice([1, 0], p=p_aug.ravel()) + + if isinstance(sample, dict): + img = sample['img'] + gt = sample['gt'] + + if prob_lr > 0.5: + img = np.fliplr(img).copy() + sample = {'img': img, 'gt': gt} + else: + if prob_lr > 0.5: + sample = np.fliplr(sample).copy() + return sample + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, sample): + if isinstance(sample, dict): + img = sample['img'] + gt = np.array(sample['gt']) + img = torch.from_numpy(img).type(torch.FloatTensor) + gt = torch.from_numpy(gt).type(torch.FloatTensor) + sample = {'img': img, 'gt': gt} + else: + sample = torch.from_numpy(sample).type(torch.FloatTensor) + return sample diff --git a/PromptIQA/utils/dataset_test/__pycache__/data_loader.cpython-38.pyc b/PromptIQA/utils/dataset_test/__pycache__/data_loader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f08d1d5f59e88b72c10cc9164c9d2c47f30401e8 Binary files /dev/null and b/PromptIQA/utils/dataset_test/__pycache__/data_loader.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset_test/__pycache__/folders.cpython-38.pyc b/PromptIQA/utils/dataset_test/__pycache__/folders.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..688a632fd7e12b59beebe1020c4ea8f88ea9e6b2 Binary files /dev/null and b/PromptIQA/utils/dataset_test/__pycache__/folders.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset_test/__pycache__/process.cpython-38.pyc b/PromptIQA/utils/dataset_test/__pycache__/process.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c09bac50d0c6bac1979257e77fbc1e5a78f65502 Binary files /dev/null
and b/PromptIQA/utils/dataset_test/__pycache__/process.cpython-38.pyc differ diff --git a/PromptIQA/utils/dataset_test/data_loader.py b/PromptIQA/utils/dataset_test/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ef17085b2a32effaa001bc20c094fbdca8227f72 --- /dev/null +++ b/PromptIQA/utils/dataset_test/data_loader.py @@ -0,0 +1,52 @@ +import torch +import torchvision + +from utils.dataset_test import folders +from utils.dataset_test.process import ToTensor, Normalize, RandHorizontalFlip + +class Data_Loader(object): + """Dataset class for IQA databases""" + + def __init__(self, config, path, img_indx, istrain=True, column=2): + + self.batch_size = config.batch_size + self.istrain = istrain + dataset = config.dataset + + if istrain: + transforms=torchvision.transforms.Compose([Normalize(0.5, 0.5), RandHorizontalFlip(prob_aug=0.5), ToTensor()]) + else: + transforms=torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()]) + + if dataset == 'livec': + self.data = folders.LIVEC(root=path, index=img_indx, transform=transforms) + elif dataset == 'koniq10k': + self.data = folders.Koniq10k(root=path, index=img_indx, transform=transforms) + elif dataset == 'bid': + self.data = folders.BID(root=path, index=img_indx, transform=transforms) + elif dataset == 'spaq': + self.data = folders.SPAQ(root=path, index=img_indx, transform=transforms, column=column) + elif dataset == 'flive': + self.data = folders.FLIVE(root=path, index=img_indx, transform=transforms) + elif dataset == 'csiq': + self.data = folders.CSIQ(root=path, index=img_indx, transform=transforms) + elif dataset == 'live': + self.data = folders.LIVEFolder(root=path, index=img_indx, transform=transforms) + elif dataset == 'tid2013': + self.data = folders.TID2013Folder(root=path, index=img_indx, transform=transforms) + elif dataset == 'kadid': + self.data = folders.KADID(root=path, index=img_indx, transform=transforms) + else: + raise Exception("Only support livec, koniq10k, bid, spaq.") + + def get_data(self): + dataloader = torch.utils.data.DataLoader(self.data, batch_size=self.batch_size, shuffle=self.istrain, num_workers=16, drop_last=self.istrain) + return dataloader + + def get_samples(self): + return self.data + + def get_prompt(self, n=5): + print('Get {} images for prompting.'.format(n)) + prompt_data = self.data.get_promt(n=n) + return torch.utils.data.DataLoader(prompt_data, batch_size=prompt_data.__len__(), shuffle=False) \ No newline at end of file diff --git a/PromptIQA/utils/dataset_test/dataset_info.json b/PromptIQA/utils/dataset_test/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..86eb08c88fa3dbd85ab44e16db4c73ede430598f --- /dev/null +++ b/PromptIQA/utils/dataset_test/dataset_info.json @@ -0,0 +1,11 @@ +{ + "livec": ["/disk1/Datasets/IQA/ChallengeDB_release", 1162], + "koniq10k": ["/disk1/Datasets/IQA/koniq-10k", 10073], + "bid": ["/disk1/Datasets/IQA/BID/ImageDatabase", 586], + "spaq": ["/disk1/Datasets/IQA/SPAQ", 11125], + "flive": ["/disk1/Datasets/IQA/FLIVE_Database/database", 39810], + "csiq": ["/disk1/Datasets/IQA/CSIQ", 30], + "live": ["/disk1/Datasets/IQA/LIVE", 29], + "tid2013": ["/disk1/Datasets/IQA/TID2013", 25], + "kadid": ["/disk1/Datasets/IQA/kadid10k", 10125] +} \ No newline at end of file diff --git a/PromptIQA/utils/dataset_test/folders.py b/PromptIQA/utils/dataset_test/folders.py new file mode 100644 index 0000000000000000000000000000000000000000..855d39bd9daff5b99747a578d6e58b7bf4f6516e --- /dev/null +++ 
b/PromptIQA/utils/dataset_test/folders.py @@ -0,0 +1,621 @@ +import torch.utils.data as data +import torch + +from PIL import Image +import os +import scipy.io +import numpy as np +import csv +from openpyxl import load_workbook +import cv2 +import random + +class LIVEC(data.Dataset): + def __init__(self, root, index, transform): + imgpath = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat')) + imgpath = imgpath['AllImages_release'] + imgpath = imgpath[7:1169] + mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat')) + labels = mos['AllMOS_release'].astype(np.float32) + labels = labels[0][7:1169] + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, 'Images', imgpath[item][0][0])) + gt.append(labels[item] / 100.) + # gt = normalization(gt) + + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, '' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + +class prompt_data(data.Dataset): + def __init__(self, sample, gt, transform, div=1): + self.samples, self.gt = sample, gt + self.transform = transform + self.div = div + print('prompt GT is', gt) + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, self.div) + + return img_tensor, gt_tensor + + def __len__(self): + length = len(self.samples) + return length + + +class Koniq10k(data.Dataset): + def __init__(self, root, index, transform): + imgname = [] + mos_all = [] + csv_file = os.path.join(root, 'koniq10k_distributions_sets.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['image_name']) + mos = np.array(float(row['MOS'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, '1024x768', imgname[item])) + gt.append(mos_all[item] / 100.) 
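The `get_promt` methods in `dataset_test/folders.py` all share one selection strategy: sort the (image, score) pairs by ground truth, keep the lowest- and highest-scored images, and fill the remaining prompt slots with evenly spaced samples in between. A condensed, standalone sketch of that logic follows; the helper name is illustrative, and it assumes at least `n` samples, mirroring the methods in this file rather than adding new behaviour.

```python
# Sketch of the ISPP selection used by the get_promt() methods in this file:
# sort by GT score, keep the extremes, sample the interior evenly.
def select_prompt_pairs(samples, gt, n=10):
    combined = sorted(zip(samples, gt), key=lambda x: x[1])
    step = (len(combined) - 2) // (n - 2)          # spacing between interior picks
    picked = [combined[0]]                         # lowest-scored image
    for i in range(step, len(combined), step):
        picked.append(combined[i])
        if len(picked) == n - 1:
            break
    picked.append(combined[-1])                    # highest-scored image
    return [p[0] for p in picked], [p[1] for p in picked]
```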
+ # gt = normalization(gt) + + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, self.samples[index] + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + +class CSIQ(data.Dataset): + def __init__(self, root, index, transform, patch_num=1): + + refpath = os.path.join(root, 'src_imgs') + refname = getFileName(refpath,'.png') + txtpath = os.path.join(root, 'csiq_label.txt') + fh = open(txtpath, 'r') + imgnames = [] + target = [] + refnames_all = [] + for line in fh: + line = line.split('\n') + words = line[0].split() + imgnames.append((words[0])) + target.append(words[1]) + ref_temp = words[0].split(".") + refnames_all.append(ref_temp[0] + '.' + ref_temp[-1]) + + labels = np.array(target).astype(np.float32) + refnames_all = np.array(refnames_all) + + sample = [] + gt = [] + + for i, item in enumerate(index): + train_sel = (refname[index[i]] == refnames_all) + train_sel = np.where(train_sel == True) + train_sel = train_sel[0].tolist() + for j, item in enumerate(train_sel): + for aug in range(patch_num): + sample.append(os.path.join(root, 'dst_imgs_all', imgnames[item])) + gt.append(labels[item]) + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, '' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + # random.shuffle(combined_data) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + +class TID2013Folder(data.Dataset): + + def __init__(self, root, index, transform, patch_num=1): + refpath = os.path.join(root, 'reference_images') + refname = getTIDFileName(refpath,'.bmp.BMP') + txtpath = os.path.join(root, 'mos_with_names.txt') + fh = open(txtpath, 'r') + imgnames = [] + target = [] + refnames_all = [] + for line in fh: + line = line.split('\n') + words = line[0].split() + imgnames.append((words[1])) + target.append(words[0]) + ref_temp = words[1].split("_") + refnames_all.append(ref_temp[0][1:]) + labels = np.array(target).astype(np.float32) + refnames_all = np.array(refnames_all) + + sample = [] + gt = [] + for i, item in enumerate(index): + train_sel = (refname[index[i]] == 
refnames_all) + train_sel = np.where(train_sel == True) + train_sel = train_sel[0].tolist() + for j, item in enumerate(train_sel): + for aug in range(patch_num): + sample.append(os.path.join(root, 'distorted_images', imgnames[item])) + gt.append(labels[item] / 8) + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, '' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + + +class KADID(data.Dataset): + def __init__(self, root, index, transform): + imgpath = os.path.join(root, 'images') + + csv_file = os.path.join(root, 'dmos.csv') + + imgname = [] + mos_all = [] + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['dist_img']) + mos = np.array(float(row['dmos'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(imgpath, imgname[item])) + gt.append(mos_all[item] / 5.) + # gt = normalization(gt) + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, '' + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + +class LIVEFolder(data.Dataset): + + def __init__(self, root, index, transform): + + refpath = os.path.join(root, 'refimgs') + refname = getFileName(refpath, '.bmp') + + jp2kroot = os.path.join(root, 'jp2k') + jp2kname = self.getDistortionTypeFileName(jp2kroot, 227) + + jpegroot = os.path.join(root, 'jpeg') + jpegname = self.getDistortionTypeFileName(jpegroot, 233) + + wnroot = os.path.join(root, 'wn') + wnname = self.getDistortionTypeFileName(wnroot, 174) + + gblurroot = os.path.join(root, 'gblur') + gblurname = self.getDistortionTypeFileName(gblurroot, 174) + + fastfadingroot = os.path.join(root, 'fastfading') + fastfadingname = self.getDistortionTypeFileName(fastfadingroot, 174) + + imgpath = jp2kname + jpegname + wnname + gblurname + fastfadingname + + dmos = scipy.io.loadmat(os.path.join(root, 'dmos_realigned.mat')) + labels = dmos['dmos'].astype(np.float32) + + orgs = dmos['orgs'] + refnames_all = 
scipy.io.loadmat(os.path.join(root, 'refnames_all.mat')) + refnames_all = refnames_all['refnames_all'] + + sample = [] + gt = [] + + for i in range(0, len(index)): + train_sel = (refname[index[i]] == refnames_all) + train_sel = train_sel * ~orgs.astype(np.bool_) + train_sel = np.where(train_sel == True) + train_sel = train_sel[1].tolist() + for j, item in enumerate(train_sel): + sample.append(imgpath[item]) + gt.append(labels[0][item] / 100.) + # print(self.imgpath[item]) + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, '' + + def __len__(self): + length = len(self.samples) + return length + + def getDistortionTypeFileName(self, path, num): + filename = [] + index = 1 + for i in range(0, num): + name = '%s%s%s' % ('img', str(index), '.bmp') + filename.append(os.path.join(path, name)) + index = index + 1 + return filename + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + + +class FLIVE(data.Dataset): + def __init__(self, root, index, transform): + imgname = [] + mos_all = [] + csv_file = os.path.join(root, 'labels_image.csv') + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + imgname.append(row['name']) + mos = np.array(float(row['mos'])).astype(np.float32) + mos_all.append(mos) + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, imgname[item])) + gt.append(mos_all[item] / 100.) 
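Note that the `get_item` helpers in both folder modules do not hand the bare MOS to the transform: they first prepend the constants 0.0 and 1.0, so every ground-truth vector has the form [0.0, 1.0, score]. A minimal illustration of that encoding is below; the helper name `encode_target` is mine, not part of the repository.

```python
import numpy as np

def encode_target(mos):
    """Build the [0.0, 1.0, mos] vector that get_item() passes to the transform."""
    target = np.array([mos], dtype=np.float32)
    # insert the two fixed anchor values at position 0, as in the loaders
    return np.insert(target, 0, np.array([0.0, 1.0]))

print(encode_target(0.42))  # roughly: [0.  1.  0.42]
```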
+ # gt = normalization(gt) + + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, self.samples[index] + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + +class SPAQ(data.Dataset): + def __init__(self, root, index, transform, column=6): + sample = [] + gt = [] + + xls_file = os.path.join(root, 'Annotations', 'MOS_and_Image_attribute_scores.xlsx') + workbook = load_workbook(xls_file) + booksheet = workbook.active + rows = booksheet.rows + print('column', column) + for count, row in enumerate(rows, 2): + if count - 2 in index: + sample.append(os.path.join(root, 'img_resize', booksheet.cell(row=count, column=1).value)) + mos = booksheet.cell(row=count, column=column).value + mos = np.array(mos) + mos = mos.astype(np.float32) + gt.append(mos / 100.) + if count == 11126: + break + # gt = normalization(gt) + + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform) + + return img_tensor, gt_tensor, self.samples[index] + + def __len__(self): + length = len(self.samples) + return length + + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + + +class BID(data.Dataset): + def __init__(self, root, index, transform): + + imgname = [] + mos_all = [] + + xls_file = os.path.join(root, 'DatabaseGrades.xlsx') + workbook = load_workbook(xls_file) + booksheet = workbook.active + rows = booksheet.rows + count = 1 + for row in rows: + count += 1 + img_num = booksheet.cell(row=count, column=1).value + img_name = "DatabaseImage%04d.JPG" % (img_num) + imgname.append(img_name) + mos = booksheet.cell(row=count, column=2).value + mos = np.array(mos) + mos = mos.astype(np.float32) + mos_all.append(mos / 5.) 
+ if count == 587: + break + + sample, gt = [], [] + for i, item in enumerate(index): + sample.append(os.path.join(root, imgname[item])) + gt.append(mos_all[item]) + + self.samples, self.gt = sample, gt + self.transform = transform + + def __getitem__(self, index): + img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform, div=5) + + return img_tensor, gt_tensor, "" + + def __len__(self): + length = len(self.samples) + return length + def get_promt(self, n=10): + combined_data = list(zip(self.samples, self.gt)) + combined_data.sort(key=lambda x: x[1]) + + data_len = (self.__len__() - 2) // (n - 2) + sample, gt = [], [] + sample.append(combined_data[0][0]) + gt.append(combined_data[0][1]) + + for i in range(data_len, self.__len__(), data_len): + sample.append(combined_data[i][0]) + gt.append(combined_data[i][1]) + if len(sample) == n - 1: + break + sample.append(combined_data[-1][0]) + gt.append(combined_data[-1][1]) + return prompt_data(sample, gt, self.transform) + +def get_item(samples, gt, index, transform, div=1): + div = 1 + try: + path, target = samples[index], gt[index] + target = [target / div] + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + target = np.insert(target, position_to_insert, values_to_insert) + + sample = load_image(path) + samples = {'img': sample, 'gt': target} + samples = transform(samples) + except: + path, target = samples[0], gt[0] + target = target / div + + values_to_insert = np.array([0.0, 1.0]) + position_to_insert = 0 + target = np.insert(target, position_to_insert, values_to_insert) + + sample = load_image(path) + samples = {'img': sample, 'gt': target} + samples = transform(samples) + print('ERROR.') + + return samples['img'], samples['gt'].type(torch.FloatTensor) + + +def getFileName(path, suffix): + filename = [] + f_list = os.listdir(path) + for i in f_list: + if os.path.splitext(i)[1] == suffix: + filename.append(i) + return filename + + +def load_image(img_path): + d_img = cv2.imread(img_path, cv2.IMREAD_COLOR) + d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC) + d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB) + d_img = np.array(d_img).astype('float32') / 255 + d_img = np.transpose(d_img, (2, 0, 1)) + + return d_img + +def normalization(data): + data = np.array(data) + range = np.max(data) - np.min(data) + data = (data - np.min(data)) / range + data = list(data.astype('float').reshape(-1, 1)) + + return data + +def getTIDFileName(path, suffix): + filename = [] + f_list = os.listdir(path) + for i in f_list: + if suffix.find(os.path.splitext(i)[1]) != -1: + filename.append(i[1:3]) + return filename \ No newline at end of file diff --git a/PromptIQA/utils/dataset_test/process.py b/PromptIQA/utils/dataset_test/process.py new file mode 100644 index 0000000000000000000000000000000000000000..e54af122fc9d112138c0b659a9fc4db286dbff0c --- /dev/null +++ b/PromptIQA/utils/dataset_test/process.py @@ -0,0 +1,57 @@ +import torch +import numpy as np + + +class Normalize(object): + def __init__(self, mean, var): + self.mean = mean + self.var = var + + def __call__(self, sample): + if isinstance(sample, dict): + img = sample['img'] + gt = sample['gt'] + img = (img - self.mean) / self.var + sample = {'img': img, 'gt': gt} + else: + sample = (sample - self.mean) / self.var + + return sample + + + +class RandHorizontalFlip(object): + def __init__(self, prob_aug): + self.prob_aug = prob_aug + + def __call__(self, sample): + p_aug = np.array([self.prob_aug, 1 - self.prob_aug]) + prob_lr = np.random.choice([1, 
0], p=p_aug.ravel()) + + if isinstance(sample, dict): + img = sample['img'] + gt = sample['gt'] + + if prob_lr > 0.5: + img = np.fliplr(img).copy() + sample = {'img': img, 'gt': gt} + else: + if prob_lr > 0.5: + sample = np.fliplr(sample).copy() + return sample + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, sample): + if isinstance(sample, dict): + img = sample['img'] + gt = np.array(sample['gt']) + img = torch.from_numpy(img).type(torch.FloatTensor) + gt = torch.from_numpy(gt).type(torch.FloatTensor) + sample = {'img': img, 'gt': gt} + else: + sample = torch.from_numpy(sample).type(torch.FloatTensor) + return sample diff --git a/PromptIQA/utils/get_other_metric.py b/PromptIQA/utils/get_other_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa0d7b3069016bac1b881a9329bd1e231250369 --- /dev/null +++ b/PromptIQA/utils/get_other_metric.py @@ -0,0 +1,18 @@ +import os +import numpy as np + +path = '/disk1/chenzewen/OurIdeas/GIQA/GIQA_2024/Formal/Full_model' +file_path = 'Other_Metric_Inference/tid2013_other/inference_log_fix_0_SSIM.log' + +for seed in os.listdir(path): + file = os.path.join(path, seed, file_path) + + srocc, plcc = [], [] + with open(file, 'r') as f: + info = f.readlines()[-1].split() + print(info) + srocc.append(float(info[-3][:-1])) + plcc.append(float(info[-1])) + +print(np.median(srocc)) +print(np.median(plcc)) \ No newline at end of file diff --git a/PromptIQA/utils/iqa_solver.py b/PromptIQA/utils/iqa_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..8593dcb04870f2b2bf06ef609b9f3626e844ab3e --- /dev/null +++ b/PromptIQA/utils/iqa_solver.py @@ -0,0 +1,145 @@ +import torch +from scipy import stats +import numpy as np +from models import monet as MoNet +from models import gc_loss as GC_Loss +from utils.dataset import data_loader +import json +import random +import os +from tqdm import tqdm + + +def get_data(dataset, data_path='./utils/dataset/dataset_info.json'): + ''' + Load dataset information from the json file. + ''' + with open(data_path, 'r') as data_info: + data_info = json.load(data_info) + path, img_num = data_info[dataset] + img_num = list(range(img_num)) + + # Random choose 80% for traning and 20% for testing. + random.shuffle(img_num) + train_index = img_num[0:int(round(0.8 * len(img_num)))] + test_index = img_num[int(round(0.8 * len(img_num))):len(img_num)] + + return path, train_index, test_index + + +def cal_srocc_plcc(pred_score, gt_score): + srocc, _ = stats.spearmanr(pred_score, gt_score) + plcc, _ = stats.pearsonr(pred_score, gt_score) + + return srocc, plcc + + +class Solver: + def __init__(self, config): + + path, train_index, test_index = get_data(dataset=config.dataset) + + train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True) + test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False) + self.train_data = train_loader.get_data() + self.test_data = test_loader.get_data() + + self.promt_data_loader = train_loader.get_prompt() + # for i, j in self.promt_data_loader: + # print(j) + print('Traning data number: ', len(train_index)) + print('Testing data number: ', len(test_index)) + + if config.loss == 'MAE': + self.loss = torch.nn.L1Loss().cuda() + elif config.loss == 'MSE': + self.loss = torch.nn.MSELoss().cuda() + elif config.loss == 'GC': + self.loss = GC_Loss.GC_Loss(queue_len=int(len(train_index) * config.queue_ratio)) + else: + raise 'Only Support MAE, MSE and GC loss.' 
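The SROCC/PLCC figures this solver reports come straight from `scipy.stats`, as in `cal_srocc_plcc` above. A self-contained check on made-up scores (values are purely illustrative):

```python
from scipy import stats

pred = [0.12, 0.40, 0.35, 0.80, 0.95]   # hypothetical model outputs
gt   = [0.10, 0.45, 0.30, 0.75, 0.99]   # hypothetical MOS values

srocc, _ = stats.spearmanr(pred, gt)    # rank correlation (monotonicity)
plcc, _ = stats.pearsonr(pred, gt)      # linear correlation
print(round(srocc, 4), round(plcc, 4))
```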
+ + print('Loading MoNet...') + self.MoNet = MoNet.MoNet().cuda() + self.MoNet.train(True) + + self.epochs = config.epochs + self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr, weight_decay=config.weight_decay) + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=config.T_max, eta_min=config.eta_min) + + self.model_save_path = os.path.join(config.save_path, 'best_model.pkl') + + def train(self): + """Training""" + best_srocc = 0.0 + best_plcc = 0.0 + print('----------------------------------') + print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC') + for t in range(self.epochs): + epoch_loss = [] + pred_scores = [] + gt_scores = [] + + for index, (img, label) in enumerate(tqdm(self.train_data)): + img = img.cuda() + label = label.cuda() + + # last_item = label[:, -1] + # sorted_last_item, indices = torch.sort(last_item, descending=False) + # img = img[indices].cuda() + # label = label[indices].cuda() + + self.optimizer.zero_grad() + pred, label = self.MoNet(img, label) + + pred_scores = pred_scores + pred.cpu().tolist() + gt_scores = gt_scores + label.cpu().tolist() + + loss = self.loss(pred.squeeze(), label.float().detach()) + epoch_loss.append(loss.item()) + + loss.backward() + self.optimizer.step() + + self.scheduler.step() + + train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores) + test_srocc, test_plcc = self.test() + if test_srocc + test_plcc > best_srocc + best_plcc: + best_srocc = test_srocc + best_plcc = test_plcc + torch.save(self.MoNet.state_dict(), self.model_save_path) + print('Model saved in: ', self.model_save_path) + + print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(np.mean(epoch_loss), 4), round(train_srocc, 4), + round(train_plcc, 4), round(test_srocc, 4), round(test_plcc, 4))) + + print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 6), round(best_plcc, 6))) + + return best_srocc, best_plcc + + def test(self): + """Testing""" + self.MoNet.train(False) + pred_scores, gt_scores = [], [] + + with torch.no_grad(): + for img, label in self.promt_data_loader: + img = img.cuda() + label = label.cuda() + + # print(img.shape) + self.MoNet.forward_prompt(img, label) + + for img, label in tqdm(self.test_data): + img = img.cuda() + label = label.cuda()[:, 2] + + pred = self.MoNet.inference(img) + + pred_scores = pred_scores + pred.cpu().tolist() + gt_scores = gt_scores + label.cpu().tolist() + + test_srocc, test_plcc = cal_srocc_plcc(pred_scores, gt_scores) + self.MoNet.train(True) + return test_srocc, test_plcc diff --git a/PromptIQA/utils/log_writer.py b/PromptIQA/utils/log_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa26e0044692db82dc96b023182495cf63fa2906 --- /dev/null +++ b/PromptIQA/utils/log_writer.py @@ -0,0 +1,14 @@ +import sys + +class Logger(object): + def __init__(self, filename="Default.log"): + self.terminal = sys.stdout + self.log = open(filename, "w") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + self.flush() + + def flush(self): + self.log.flush() \ No newline at end of file diff --git a/PromptIQA/utils/toolkit.py b/PromptIQA/utils/toolkit.py new file mode 100644 index 0000000000000000000000000000000000000000..8acac7d691ccb8e9995248eaee892aabfde555eb --- /dev/null +++ b/PromptIQA/utils/toolkit.py @@ -0,0 +1,140 @@ +import math +import os +import random +import warnings + +import torch +import torch.backends.cudnn as cudnn +import torch.nn.parallel +import torch.optim +import torch.utils.data 
+import torch.utils.data.distributed +from scipy import stats +import json + +def printArgs(args, savePath): + with open(os.path.join(savePath, "args_info.log"), "w") as f: + print("--------------args----------------") + f.write("--------------args----------------\n") + for arg in vars(args): + print( + format(arg, "<20"), format(str(getattr(args, arg)), "<") + ) # str, arg_type + f.write( + "{}\t{}\n".format( + format(arg, "<20"), format(str(getattr(args, arg)), "<") + ) + ) # str, arg_type + + print("----------------------------------") + f.write("----------------------------------") + + +def setup_seed(seed): + random.seed(seed) + torch.manual_seed(seed) + cudnn.deterministic = True + warnings.warn( + "You have chosen to seed training. " + "This will turn on the CUDNN deterministic setting, " + "which can slow down your training considerably! " + "You may see unexpected behavior when restarting " + "from checkpoints." + ) + + +def get_data(dataset, data_path='./utils/dataset/dataset_info.json', split_seed=2023): + """ + Load dataset information from the json file. + """ + with open(data_path, "r") as data_info: + data_info = json.load(data_info) + path, img_num = data_info[dataset] + img_num = list(range(img_num)) + + random.seed(split_seed) + random.shuffle(img_num) + + if True or dataset != 'flive': + train_index = img_num[0: int(round(0.8 * len(img_num)))] + test_index = img_num[int(round(0.8 * len(img_num))): len(img_num)] + else: + print('Load FLIVE.') + with open('train_data.json') as f: + res = f.readlines() + + train_index = eval(res[0].strip()) + test_index = eval(res[1].strip()) + + print('Split_seed', split_seed) + print('train_index', train_index[:10], len(train_index)) + print('test_index', test_index[:10], len(test_index)) + + return path, train_index, test_index + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate based on schedule""" + lr = args.lr + lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) + + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + torch.save(state, filename) + # if is_best: + # shutil.copyfile(filename, "model_best.pth.tar") + + +def cal_srocc_plcc(pred_score, gt_score): + try: + srocc, _ = stats.spearmanr(pred_score, gt_score) + plcc, _ = stats.pearsonr(pred_score, gt_score) + except: + srocc, plcc = 0, 0 + + return srocc, plcc + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class ProgressMeter: + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" diff --git a/README.md b/README.md index 
ca491cb422b812bf622a3efa1d3b5618ef1d7dc5..fb5d46d6117b3d38b30a0c2631ba118aacd60056 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ colorTo: indigo sdk: gradio sdk_version: 4.44.0 app_file: app.py -pinned: false +pinned: true license: mit --- diff --git a/UI/PromptIQA_Show.py b/UI/PromptIQA_Show.py new file mode 100644 index 0000000000000000000000000000000000000000..c74a942e69b38a1ad75d0ed7023c2edc1ac25995 --- /dev/null +++ b/UI/PromptIQA_Show.py @@ -0,0 +1,79 @@ +import gradio as gr +import json +from PromptIQA import run_promptIQA + +class Main_ui(): + def __init__(self) -> None: # stage_Koniq + self.json_path = 'example.json' + self.promptiqa = run_promptIQA.PromptIQA() + + def load_example(self): + with open(self.json_path, 'r') as f: + info = json.load(f) + + examples = [] + remarks = [] + + for exp in info: + ISPP = exp['ISPP'] + Image = exp['Image'] + Remark = exp['Remark'] + + image, score = [], [] + for ISP_Image, ISP_Score in ISPP: + image.append(ISP_Image) + score.append(float(ISP_Score)) + example = [item for pair in zip(image, score) for item in pair] + + example.append(Image[0]) + example.append(float(Image[1])) + + examples.append(example) + remarks.append(Remark) + + return examples, remarks + + def load_demo(self): + def get_iq_score(*args): + ISPP_I, ISPP_S, image = args[:10], args[10:20], args[-1] + res = self.promptiqa.run(ISPP_I, ISPP_S, image) + return res + + image_components = [] + score_components = [] + + with gr.Blocks() as demo: + gr.Markdown("# PromptIQA: Boosting the Performance and Generalization for No-Reference Image Quality Assessment via Prompts") + gr.Markdown("## 1. Upload the Image-Score Pairs Prompts") + + ISP_idx = 1 + for row_num in [10]: + with gr.Row(): + for i in range(row_num): + with gr.Column(scale=1): + ISP_Image = gr.Image(label=f'Image {ISP_idx}', width=448, height=448) + ISP_Score = gr.Slider(0, 1, label=f"Score {ISP_idx}") + ISP_idx += 1 + + image_components.append(ISP_Image) + score_components.append(ISP_Score) + gr.Markdown("---------------------------------------") + + gr.Markdown("## 2. Upload the image to be evaluated.") + with gr.Row(): + Image_To_Be_Evaluated = gr.Image(label=f'Image To Be Evaluated.', width=512, height=512) + with gr.Column(): + quality_score = gr.Textbox(label='Predicted Quality Score') + pre_button = gr.Button("Get Quality Score") + + pre_button.click(get_iq_score, inputs=image_components + score_components + [Image_To_Be_Evaluated], outputs=[quality_score]) + + examples, remarks = self.load_example() + + gr.Markdown("Examples") + for idx, (remark, example) in enumerate(zip(remarks, examples)): + gr.Markdown(f"### Example{idx + 1}: {remark}") + gr.Examples(examples=[example], inputs=[item for pair in zip(image_components, score_components) for item in pair] + [Image_To_Be_Evaluated, quality_score]) + + + return demo \ No newline at end of file diff --git a/UI/__pycache__/PromptIQA_Show.cpython-38.pyc b/UI/__pycache__/PromptIQA_Show.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd6f40e50ea687ebcb3b4e0dadd5265d5251613c Binary files /dev/null and b/UI/__pycache__/PromptIQA_Show.cpython-38.pyc differ diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..88758e2f78a5d9c58be2985914856787701b8bdb 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,4 @@ -import gradio as gr - -def greet(name): - return "Hello " + name + "!!" 
- -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() +if __name__ == '__main__': + from UI.PromptIQA_Show import * + demo = Main_ui().load_demo() + demo.launch() \ No newline at end of file diff --git a/example.json b/example.json new file mode 100644 index 0000000000000000000000000000000000000000..65fe40fd6adb5eef3c160212e8c7a6faa90ee8c6 --- /dev/null +++ b/example.json @@ -0,0 +1,352 @@ +[ + { + "Example_id": "Example1", + "ISPP": [ + [ + "./Examples/Example1/ISPP/1600.AWGN.1.png", + 0.062 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.2.png", + 0.206 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.3.png", + 0.262 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.4.png", + 0.375 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.5.png", + 0.467 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.1.png", + 0.043 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.2.png", + 0.142 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.3.png", + 0.341 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.4.png", + 0.471 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.5.png", + 0.750 + ] + ], + "Image": [ + "./Examples/Example1/cactus.png", + 0.0538 + ], + "Remark": "Scores are from CSIQ." + }, + { + "Example_id": "Example2", + "ISPP": [ + [ + "./Examples/Example1/ISPP/1600.AWGN.1.png", + 0.938 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.2.png", + 0.794 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.3.png", + 0.738 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.4.png", + 0.625 + ], + [ + "./Examples/Example1/ISPP/1600.AWGN.5.png", + 0.533 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.1.png", + 0.957 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.2.png", + 0.858 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.3.png", + 0.659 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.4.png", + 0.529 + ], + [ + "./Examples/Example1/ISPP/1600.BLUR.5.png", + 0.250 + ] + ], + "Image": [ + "./Examples/Example2/cactus.png", + 0.9508 + ], + "Remark": "Scores are reversed compared with Example1." + }, + { + "Example_id": "Example3", + "ISPP": [ + [ + "./Examples/Example3/ISPP/10986.png", + 0.1557 + ], + [ + "./Examples/Example3/ISPP/10989.png", + 0.5486 + ], + [ + "./Examples/Example3/ISPP/10990.png", + 0.7629 + ], + [ + "./Examples/Example3/ISPP/10992.png", + 0.4214 + ], + [ + "./Examples/Example3/ISPP/10993.png", + 0.8186 + ], + [ + "./Examples/Example3/ISPP/10994.png", + 0.7514 + ], + [ + "./Examples/Example3/ISPP/10995.png", + 0.8286 + ], + [ + "./Examples/Example3/ISPP/10996.png", + 0.3014 + ], + [ + "./Examples/Example3/ISPP/10997.png", + 0.3657 + ], + [ + "./Examples/Example3/ISPP/10998.png", + 0.6371 + ] + ], + "Image": [ + "./Examples/Example3/11074.png", + 0.8501 + ], + "Remark": "Scores are from SPAQ - MOS." + }, + { + "Example_id": "Example4", + "ISPP": [ + [ + "./Examples/Example4/ISPP/10986.png", + 0.8443 + ], + [ + "./Examples/Example4/ISPP/10989.png", + 0.4514 + ], + [ + "./Examples/Example4/ISPP/10990.png", + 0.2371 + ], + [ + "./Examples/Example4/ISPP/10992.png", + 0.5786 + ], + [ + "./Examples/Example4/ISPP/10993.png", + 0.1814 + ], + [ + "./Examples/Example4/ISPP/10994.png", + 0.2486 + ], + [ + "./Examples/Example4/ISPP/10995.png", + 0.1714 + ], + [ + "./Examples/Example4/ISPP/10996.png", + 0.6986 + ], + [ + "./Examples/Example4/ISPP/10997.png", + 0.6343 + ], + [ + "./Examples/Example4/ISPP/10998.png", + 0.3629 + ] + ], + "Image": [ + "./Examples/Example4/11074.png", + 0.0928 + ], + "Remark": "The scores are reversed compared with Example3." 
+ }, + { + "Example_id": "Example5", + "ISPP": [ + [ + "./Examples/Example5/ISPP/10986.png", + 0.2733 + ], + [ + "./Examples/Example5/ISPP/10989.png", + 0.7017 + ], + [ + "./Examples/Example5/ISPP/10990.png", + 0.7883 + ], + [ + "./Examples/Example5/ISPP/10992.png", + 0.7717 + ], + [ + "./Examples/Example5/ISPP/10993.png", + 0.81 + ], + [ + "./Examples/Example5/ISPP/10994.png", + 0.835 + ], + [ + "./Examples/Example5/ISPP/10995.png", + 0.89 + ], + [ + "./Examples/Example5/ISPP/10996.png", + 0.7617 + ], + [ + "./Examples/Example5/ISPP/10997.png", + 0.8333 + ], + [ + "./Examples/Example5/ISPP/10998.png", + 0.7967 + ] + ], + "Image": [ + "./Examples/Example5/11074.png", + 0.9649 + ], + "Remark": "Scores are from SPAQ - Colorfulness." + }, + { + "Example_id": "Example6", + "ISPP": [ + [ + "./Examples/Example5/ISPP/10986.png", + 0.0786 + ], + [ + "./Examples/Example5/ISPP/10989.png", + 0.5929 + ], + [ + "./Examples/Example5/ISPP/10990.png", + 0.6543 + ], + [ + "./Examples/Example5/ISPP/10992.png", + 0.6957 + ], + [ + "./Examples/Example5/ISPP/10993.png", + 0.7929 + ], + [ + "./Examples/Example5/ISPP/10994.png", + 0.7343 + ], + [ + "./Examples/Example5/ISPP/10995.png", + 0.85 + ], + [ + "./Examples/Example5/ISPP/10996.png", + 0.6786 + ], + [ + "./Examples/Example5/ISPP/10997.png", + 0.7883 + ], + [ + "./Examples/Example5/ISPP/10998.png", + 0.5929 + ] + ], + "Image": [ + "./Examples/Example5/11074.png", + 0.9232 + ], + "Remark": "Scores are from SPAQ - Brightness." + }, + { + "Example_id": "Example7", + "ISPP": [ + [ + "./Examples/Example6/ISPP/188.bmp", + 0.6108 + ], + [ + "./Examples/Example6/ISPP/189.bmp", + 0.7761 + ], + [ + "./Examples/Example6/ISPP/190.bmp", + 0.2057 + ], + [ + "./Examples/Example6/ISPP/191.bmp", + 0.8206 + ], + [ + "./Examples/Example6/ISPP/192.bmp", + 0.434 + ], + [ + "./Examples/Example6/ISPP/193.bmp", + 0.6433 + ], + [ + "./Examples/Example6/ISPP/194.bmp", + 0.4919 + ], + [ + "./Examples/Example6/ISPP/195.bmp", + 0.4603 + ], + [ + "./Examples/Example6/ISPP/196.bmp", + 0.7064 + ], + [ + "./Examples/Example6/ISPP/197.bmp", + 0.6741 + ] + ], + "Image": [ + "./Examples/Example6/198.bmp", + 7 + ], + "Remark": "Scores are from LIVEC." 
+ } +] \ No newline at end of file diff --git a/get_examplt.py b/get_examplt.py new file mode 100644 index 0000000000000000000000000000000000000000..1c04d4a8a1f3badebd451b0583d5d50e616698a1 --- /dev/null +++ b/get_examplt.py @@ -0,0 +1,27 @@ +import os +from copy import deepcopy + + +isp_json = [] +path = './Examples' +for img_dir in sorted(os.listdir(path)): + if os.path.isdir(os.path.join(path, img_dir)): + ISPP = os.path.join(path, img_dir, 'ISPP') + + ispp = {} + ispp['Example_id'] = img_dir + ispp['ISPP'] = [] + img_list = [] + for idx, img in enumerate(sorted(os.listdir(ISPP))): + ispp['ISPP'].append([os.path.join(ISPP, img), idx / 10 if '1' in img_dir else 1 - idx / 10]) + + for file in os.listdir(os.path.join(path, img_dir)): + if os.path.isfile(os.path.join(path, img_dir, file)): + img_list.append(file) + ispp['Image'] = [os.path.join(path, img_dir, file), 7] + ispp['Remark'] = [] + isp_json.append(deepcopy(ispp)) + +with open('example2.json', 'w') as f: + import json + json.dump(isp_json, f, indent=4) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cc44f8f5afded3e79e79a3f2bbf9e60127d60c90 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +einops==0.6.1 +numpy==1.24.3 +opencv_python==4.8.0.76 +openpyxl==3.0.9 +Pillow==10.0.0 +scipy +timm==0.5.4 +tqdm==4.61.2 +torch==1.11.0+cu113 \ No newline at end of file
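
For reference, the sketch below shows how an entry in `example.json` maps onto the inference call that the demo wires up: each `ISPP` item is an `[image_path, score]` prompt pair, `Image` holds the image to be evaluated plus its dataset score, and `UI/PromptIQA_Show.py` forwards the ten prompt images, the ten prompt scores, and the target image to `PromptIQA.run`. This is a minimal sketch rather than part of the committed code: it assumes `PromptIQA.run` accepts image file paths directly, whereas the Gradio UI passes the decoded outputs of `gr.Image` components.

```python
# Minimal sketch (assumption: PromptIQA.run accepts image paths; the Gradio UI
# passes decoded gr.Image outputs instead, so adapt the loading step if needed).
import json

from PromptIQA import run_promptIQA


def score_with_prompts(entry, model):
    # entry['ISPP'] is a list of [image_path, score] image-score pairs (the prompt).
    prompt_images = [pair[0] for pair in entry['ISPP']]
    prompt_scores = [float(pair[1]) for pair in entry['ISPP']]
    # entry['Image'] is [image_path, dataset_score] for the image to be evaluated.
    target_image, dataset_score = entry['Image'][0], float(entry['Image'][1])
    predicted = model.run(prompt_images, prompt_scores, target_image)
    return predicted, dataset_score


if __name__ == '__main__':
    with open('example.json') as f:
        examples = json.load(f)
    model = run_promptIQA.PromptIQA()
    predicted, dataset_score = score_with_prompts(examples[0], model)
    print(f'Predicted quality score: {predicted} (dataset score: {dataset_score})')
```

To try the Space locally, install the pinned dependencies from `requirements.txt` and run `python app.py`, which builds and launches the `Main_ui` demo defined in `UI/PromptIQA_Show.py`.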