diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5dc76a3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,63 @@ +Copyright (c) 2024, NVIDIA Corporation. All rights reserved. + +Nvidia Source Code License-NC + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. + +“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, +or other files, and (b) any additions to or derivative works thereof that are made available under this license. + +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. +copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that +remain separable from, or merely link (or bind by name) to the interfaces of, the Work. + +Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing +the applicability of this license to the Work, or (b) a copy of this license. + +2. License Grant + +2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, +worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly +display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a +complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, +trademark, or attribution notices that are present in the Work. + +3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution +of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 +applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. +Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply +to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. +Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. +As used herein, “non-commercially” means for research or evaluation purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim +or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under +this license from such Licensor (including the grant in Section 2.1) will terminate immediately. + +3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, +except as necessary to reproduce the notices described in this license. + +3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) +will terminate immediately. + +4. Disclaimer of Warranty. 
+ +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES +OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING +ANY ACTIVITIES UNDER THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, +OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL +DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, +BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR +HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. diff --git a/config.json b/config.json new file mode 100644 index 0000000..7738929 --- /dev/null +++ b/config.json @@ -0,0 +1,1051 @@ +{ + "architectures": [ + "MambaVisionModelForImageClassification" + ], + "auto_map": { + "AutoConfig": "configuration_mambavision.MambaVisionConfig", + "AutoModel": "modeling_mambavision.MambaVisionModel", + "AutoModelForImageClassification": "modeling_mambavision.MambaVisionModelForImageClassification" + }, + "crop_mode": "center", + "crop_pct": 1.0, + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ], + "depths": [ + 3, + 3, + 12, + 5 + ], + "id2label": { + "0": "tench, Tinca tinca", + "1": "goldfish, Carassius auratus", + "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "3": "tiger shark, Galeocerdo cuvieri", + "4": "hammerhead, hammerhead shark", + "5": "electric ray, crampfish, numbfish, torpedo", + "6": "stingray", + "7": "cock", + "8": "hen", + "9": "ostrich, Struthio camelus", + "10": "brambling, Fringilla montifringilla", + "11": "goldfinch, Carduelis carduelis", + "12": "house finch, linnet, Carpodacus mexicanus", + "13": "junco, snowbird", + "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "15": "robin, American robin, Turdus migratorius", + "16": "bulbul", + "17": "jay", + "18": "magpie", + "19": "chickadee", + "20": "water ouzel, dipper", + "21": "kite", + "22": "bald eagle, American eagle, Haliaeetus leucocephalus", + "23": "vulture", + "24": "great grey owl, great gray owl, Strix nebulosa", + "25": "European fire salamander, Salamandra salamandra", + "26": "common newt, Triturus vulgaris", + "27": "eft", + "28": "spotted salamander, Ambystoma maculatum", + "29": "axolotl, mud puppy, Ambystoma mexicanum", + "30": "bullfrog, Rana catesbeiana", + "31": "tree frog, tree-frog", + "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "33": "loggerhead, loggerhead turtle, Caretta caretta", + "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "35": "mud turtle", + "36": "terrapin", + "37": "box turtle, box tortoise", + "38": "banded gecko", + "39": "common iguana, iguana, Iguana iguana", + "40": "American chameleon, anole, Anolis carolinensis", + "41": "whiptail, whiptail lizard", + "42": "agama", + "43": "frilled lizard, Chlamydosaurus kingi", + "44": "alligator lizard", + "45": "Gila monster, Heloderma suspectum", + "46": "green lizard, Lacerta viridis", + "47": "African chameleon, Chamaeleo chamaeleon", + "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "49": 
"African crocodile, Nile crocodile, Crocodylus niloticus", + "50": "American alligator, Alligator mississipiensis", + "51": "triceratops", + "52": "thunder snake, worm snake, Carphophis amoenus", + "53": "ringneck snake, ring-necked snake, ring snake", + "54": "hognose snake, puff adder, sand viper", + "55": "green snake, grass snake", + "56": "king snake, kingsnake", + "57": "garter snake, grass snake", + "58": "water snake", + "59": "vine snake", + "60": "night snake, Hypsiglena torquata", + "61": "boa constrictor, Constrictor constrictor", + "62": "rock python, rock snake, Python sebae", + "63": "Indian cobra, Naja naja", + "64": "green mamba", + "65": "sea snake", + "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "68": "sidewinder, horned rattlesnake, Crotalus cerastes", + "69": "trilobite", + "70": "harvestman, daddy longlegs, Phalangium opilio", + "71": "scorpion", + "72": "black and gold garden spider, Argiope aurantia", + "73": "barn spider, Araneus cavaticus", + "74": "garden spider, Aranea diademata", + "75": "black widow, Latrodectus mactans", + "76": "tarantula", + "77": "wolf spider, hunting spider", + "78": "tick", + "79": "centipede", + "80": "black grouse", + "81": "ptarmigan", + "82": "ruffed grouse, partridge, Bonasa umbellus", + "83": "prairie chicken, prairie grouse, prairie fowl", + "84": "peacock", + "85": "quail", + "86": "partridge", + "87": "African grey, African gray, Psittacus erithacus", + "88": "macaw", + "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "90": "lorikeet", + "91": "coucal", + "92": "bee eater", + "93": "hornbill", + "94": "hummingbird", + "95": "jacamar", + "96": "toucan", + "97": "drake", + "98": "red-breasted merganser, Mergus serrator", + "99": "goose", + "100": "black swan, Cygnus atratus", + "101": "tusker", + "102": "echidna, spiny anteater, anteater", + "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "104": "wallaby, brush kangaroo", + "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "106": "wombat", + "107": "jellyfish", + "108": "sea anemone, anemone", + "109": "brain coral", + "110": "flatworm, platyhelminth", + "111": "nematode, nematode worm, roundworm", + "112": "conch", + "113": "snail", + "114": "slug", + "115": "sea slug, nudibranch", + "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "117": "chambered nautilus, pearly nautilus, nautilus", + "118": "Dungeness crab, Cancer magister", + "119": "rock crab, Cancer irroratus", + "120": "fiddler crab", + "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "124": "crayfish, crawfish, crawdad, crawdaddy", + "125": "hermit crab", + "126": "isopod", + "127": "white stork, Ciconia ciconia", + "128": "black stork, Ciconia nigra", + "129": "spoonbill", + "130": "flamingo", + "131": "little blue heron, Egretta caerulea", + "132": "American egret, great white heron, Egretta albus", + "133": "bittern", + "134": "crane", + "135": "limpkin, Aramus pictus", + "136": "European gallinule, Porphyrio porphyrio", + "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", + "138": "bustard", + "139": "ruddy turnstone, Arenaria interpres", + "140": "red-backed 
sandpiper, dunlin, Erolia alpina", + "141": "redshank, Tringa totanus", + "142": "dowitcher", + "143": "oystercatcher, oyster catcher", + "144": "pelican", + "145": "king penguin, Aptenodytes patagonica", + "146": "albatross, mollymawk", + "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "149": "dugong, Dugong dugon", + "150": "sea lion", + "151": "Chihuahua", + "152": "Japanese spaniel", + "153": "Maltese dog, Maltese terrier, Maltese", + "154": "Pekinese, Pekingese, Peke", + "155": "Shih-Tzu", + "156": "Blenheim spaniel", + "157": "papillon", + "158": "toy terrier", + "159": "Rhodesian ridgeback", + "160": "Afghan hound, Afghan", + "161": "basset, basset hound", + "162": "beagle", + "163": "bloodhound, sleuthhound", + "164": "bluetick", + "165": "black-and-tan coonhound", + "166": "Walker hound, Walker foxhound", + "167": "English foxhound", + "168": "redbone", + "169": "borzoi, Russian wolfhound", + "170": "Irish wolfhound", + "171": "Italian greyhound", + "172": "whippet", + "173": "Ibizan hound, Ibizan Podenco", + "174": "Norwegian elkhound, elkhound", + "175": "otterhound, otter hound", + "176": "Saluki, gazelle hound", + "177": "Scottish deerhound, deerhound", + "178": "Weimaraner", + "179": "Staffordshire bullterrier, Staffordshire bull terrier", + "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "181": "Bedlington terrier", + "182": "Border terrier", + "183": "Kerry blue terrier", + "184": "Irish terrier", + "185": "Norfolk terrier", + "186": "Norwich terrier", + "187": "Yorkshire terrier", + "188": "wire-haired fox terrier", + "189": "Lakeland terrier", + "190": "Sealyham terrier, Sealyham", + "191": "Airedale, Airedale terrier", + "192": "cairn, cairn terrier", + "193": "Australian terrier", + "194": "Dandie Dinmont, Dandie Dinmont terrier", + "195": "Boston bull, Boston terrier", + "196": "miniature schnauzer", + "197": "giant schnauzer", + "198": "standard schnauzer", + "199": "Scotch terrier, Scottish terrier, Scottie", + "200": "Tibetan terrier, chrysanthemum dog", + "201": "silky terrier, Sydney silky", + "202": "soft-coated wheaten terrier", + "203": "West Highland white terrier", + "204": "Lhasa, Lhasa apso", + "205": "flat-coated retriever", + "206": "curly-coated retriever", + "207": "golden retriever", + "208": "Labrador retriever", + "209": "Chesapeake Bay retriever", + "210": "German short-haired pointer", + "211": "vizsla, Hungarian pointer", + "212": "English setter", + "213": "Irish setter, red setter", + "214": "Gordon setter", + "215": "Brittany spaniel", + "216": "clumber, clumber spaniel", + "217": "English springer, English springer spaniel", + "218": "Welsh springer spaniel", + "219": "cocker spaniel, English cocker spaniel, cocker", + "220": "Sussex spaniel", + "221": "Irish water spaniel", + "222": "kuvasz", + "223": "schipperke", + "224": "groenendael", + "225": "malinois", + "226": "briard", + "227": "kelpie", + "228": "komondor", + "229": "Old English sheepdog, bobtail", + "230": "Shetland sheepdog, Shetland sheep dog, Shetland", + "231": "collie", + "232": "Border collie", + "233": "Bouvier des Flandres, Bouviers des Flandres", + "234": "Rottweiler", + "235": "German shepherd, German shepherd dog, German police dog, alsatian", + "236": "Doberman, Doberman pinscher", + "237": "miniature pinscher", + "238": "Greater Swiss Mountain dog", + "239": "Bernese mountain dog", + "240": 
"Appenzeller", + "241": "EntleBucher", + "242": "boxer", + "243": "bull mastiff", + "244": "Tibetan mastiff", + "245": "French bulldog", + "246": "Great Dane", + "247": "Saint Bernard, St Bernard", + "248": "Eskimo dog, husky", + "249": "malamute, malemute, Alaskan malamute", + "250": "Siberian husky", + "251": "dalmatian, coach dog, carriage dog", + "252": "affenpinscher, monkey pinscher, monkey dog", + "253": "basenji", + "254": "pug, pug-dog", + "255": "Leonberg", + "256": "Newfoundland, Newfoundland dog", + "257": "Great Pyrenees", + "258": "Samoyed, Samoyede", + "259": "Pomeranian", + "260": "chow, chow chow", + "261": "keeshond", + "262": "Brabancon griffon", + "263": "Pembroke, Pembroke Welsh corgi", + "264": "Cardigan, Cardigan Welsh corgi", + "265": "toy poodle", + "266": "miniature poodle", + "267": "standard poodle", + "268": "Mexican hairless", + "269": "timber wolf, grey wolf, gray wolf, Canis lupus", + "270": "white wolf, Arctic wolf, Canis lupus tundrarum", + "271": "red wolf, maned wolf, Canis rufus, Canis niger", + "272": "coyote, prairie wolf, brush wolf, Canis latrans", + "273": "dingo, warrigal, warragal, Canis dingo", + "274": "dhole, Cuon alpinus", + "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "276": "hyena, hyaena", + "277": "red fox, Vulpes vulpes", + "278": "kit fox, Vulpes macrotis", + "279": "Arctic fox, white fox, Alopex lagopus", + "280": "grey fox, gray fox, Urocyon cinereoargenteus", + "281": "tabby, tabby cat", + "282": "tiger cat", + "283": "Persian cat", + "284": "Siamese cat, Siamese", + "285": "Egyptian cat", + "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "287": "lynx, catamount", + "288": "leopard, Panthera pardus", + "289": "snow leopard, ounce, Panthera uncia", + "290": "jaguar, panther, Panthera onca, Felis onca", + "291": "lion, king of beasts, Panthera leo", + "292": "tiger, Panthera tigris", + "293": "cheetah, chetah, Acinonyx jubatus", + "294": "brown bear, bruin, Ursus arctos", + "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", + "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "297": "sloth bear, Melursus ursinus, Ursus ursinus", + "298": "mongoose", + "299": "meerkat, mierkat", + "300": "tiger beetle", + "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "302": "ground beetle, carabid beetle", + "303": "long-horned beetle, longicorn, longicorn beetle", + "304": "leaf beetle, chrysomelid", + "305": "dung beetle", + "306": "rhinoceros beetle", + "307": "weevil", + "308": "fly", + "309": "bee", + "310": "ant, emmet, pismire", + "311": "grasshopper, hopper", + "312": "cricket", + "313": "walking stick, walkingstick, stick insect", + "314": "cockroach, roach", + "315": "mantis, mantid", + "316": "cicada, cicala", + "317": "leafhopper", + "318": "lacewing, lacewing fly", + "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "320": "damselfly", + "321": "admiral", + "322": "ringlet, ringlet butterfly", + "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "324": "cabbage butterfly", + "325": "sulphur butterfly, sulfur butterfly", + "326": "lycaenid, lycaenid butterfly", + "327": "starfish, sea star", + "328": "sea urchin", + "329": "sea cucumber, holothurian", + "330": "wood rabbit, cottontail, cottontail rabbit", + "331": "hare", + "332": "Angora, Angora rabbit", + "333": "hamster", + "334": 
"porcupine, hedgehog", + "335": "fox squirrel, eastern fox squirrel, Sciurus niger", + "336": "marmot", + "337": "beaver", + "338": "guinea pig, Cavia cobaya", + "339": "sorrel", + "340": "zebra", + "341": "hog, pig, grunter, squealer, Sus scrofa", + "342": "wild boar, boar, Sus scrofa", + "343": "warthog", + "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "345": "ox", + "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "347": "bison", + "348": "ram, tup", + "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "350": "ibex, Capra ibex", + "351": "hartebeest", + "352": "impala, Aepyceros melampus", + "353": "gazelle", + "354": "Arabian camel, dromedary, Camelus dromedarius", + "355": "llama", + "356": "weasel", + "357": "mink", + "358": "polecat, fitch, foulmart, foumart, Mustela putorius", + "359": "black-footed ferret, ferret, Mustela nigripes", + "360": "otter", + "361": "skunk, polecat, wood pussy", + "362": "badger", + "363": "armadillo", + "364": "three-toed sloth, ai, Bradypus tridactylus", + "365": "orangutan, orang, orangutang, Pongo pygmaeus", + "366": "gorilla, Gorilla gorilla", + "367": "chimpanzee, chimp, Pan troglodytes", + "368": "gibbon, Hylobates lar", + "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "370": "guenon, guenon monkey", + "371": "patas, hussar monkey, Erythrocebus patas", + "372": "baboon", + "373": "macaque", + "374": "langur", + "375": "colobus, colobus monkey", + "376": "proboscis monkey, Nasalis larvatus", + "377": "marmoset", + "378": "capuchin, ringtail, Cebus capucinus", + "379": "howler monkey, howler", + "380": "titi, titi monkey", + "381": "spider monkey, Ateles geoffroyi", + "382": "squirrel monkey, Saimiri sciureus", + "383": "Madagascar cat, ring-tailed lemur, Lemur catta", + "384": "indri, indris, Indri indri, Indri brevicaudatus", + "385": "Indian elephant, Elephas maximus", + "386": "African elephant, Loxodonta africana", + "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "389": "barracouta, snoek", + "390": "eel", + "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "392": "rock beauty, Holocanthus tricolor", + "393": "anemone fish", + "394": "sturgeon", + "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", + "396": "lionfish", + "397": "puffer, pufferfish, blowfish, globefish", + "398": "abacus", + "399": "abaya", + "400": "academic gown, academic robe, judge's robe", + "401": "accordion, piano accordion, squeeze box", + "402": "acoustic guitar", + "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "404": "airliner", + "405": "airship, dirigible", + "406": "altar", + "407": "ambulance", + "408": "amphibian, amphibious vehicle", + "409": "analog clock", + "410": "apiary, bee house", + "411": "apron", + "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "413": "assault rifle, assault gun", + "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "415": "bakery, bakeshop, bakehouse", + "416": "balance beam, beam", + "417": "balloon", + "418": "ballpoint, ballpoint pen, ballpen, Biro", + "419": "Band Aid", + "420": "banjo", + "421": "bannister, banister, balustrade, balusters, handrail", + "422": "barbell", + "423": "barber chair", + "424": "barbershop", + "425": "barn", + "426": 
"barometer", + "427": "barrel, cask", + "428": "barrow, garden cart, lawn cart, wheelbarrow", + "429": "baseball", + "430": "basketball", + "431": "bassinet", + "432": "bassoon", + "433": "bathing cap, swimming cap", + "434": "bath towel", + "435": "bathtub, bathing tub, bath, tub", + "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "437": "beacon, lighthouse, beacon light, pharos", + "438": "beaker", + "439": "bearskin, busby, shako", + "440": "beer bottle", + "441": "beer glass", + "442": "bell cote, bell cot", + "443": "bib", + "444": "bicycle-built-for-two, tandem bicycle, tandem", + "445": "bikini, two-piece", + "446": "binder, ring-binder", + "447": "binoculars, field glasses, opera glasses", + "448": "birdhouse", + "449": "boathouse", + "450": "bobsled, bobsleigh, bob", + "451": "bolo tie, bolo, bola tie, bola", + "452": "bonnet, poke bonnet", + "453": "bookcase", + "454": "bookshop, bookstore, bookstall", + "455": "bottlecap", + "456": "bow", + "457": "bow tie, bow-tie, bowtie", + "458": "brass, memorial tablet, plaque", + "459": "brassiere, bra, bandeau", + "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "461": "breastplate, aegis, egis", + "462": "broom", + "463": "bucket, pail", + "464": "buckle", + "465": "bulletproof vest", + "466": "bullet train, bullet", + "467": "butcher shop, meat market", + "468": "cab, hack, taxi, taxicab", + "469": "caldron, cauldron", + "470": "candle, taper, wax light", + "471": "cannon", + "472": "canoe", + "473": "can opener, tin opener", + "474": "cardigan", + "475": "car mirror", + "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", + "477": "carpenter's kit, tool kit", + "478": "carton", + "479": "car wheel", + "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "481": "cassette", + "482": "cassette player", + "483": "castle", + "484": "catamaran", + "485": "CD player", + "486": "cello, violoncello", + "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "488": "chain", + "489": "chainlink fence", + "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "491": "chain saw, chainsaw", + "492": "chest", + "493": "chiffonier, commode", + "494": "chime, bell, gong", + "495": "china cabinet, china closet", + "496": "Christmas stocking", + "497": "church, church building", + "498": "cinema, movie theater, movie theatre, movie house, picture palace", + "499": "cleaver, meat cleaver, chopper", + "500": "cliff dwelling", + "501": "cloak", + "502": "clog, geta, patten, sabot", + "503": "cocktail shaker", + "504": "coffee mug", + "505": "coffeepot", + "506": "coil, spiral, volute, whorl, helix", + "507": "combination lock", + "508": "computer keyboard, keypad", + "509": "confectionery, confectionary, candy store", + "510": "container ship, containership, container vessel", + "511": "convertible", + "512": "corkscrew, bottle screw", + "513": "cornet, horn, trumpet, trump", + "514": "cowboy boot", + "515": "cowboy hat, ten-gallon hat", + "516": "cradle", + "517": "crane", + "518": "crash helmet", + "519": "crate", + "520": "crib, cot", + "521": "Crock Pot", + "522": "croquet ball", + "523": "crutch", + "524": "cuirass", + "525": "dam, dike, dyke", + "526": "desk", + "527": "desktop computer", + "528": "dial telephone, dial phone", + "529": "diaper, nappy, napkin", + "530": "digital clock", + "531": "digital watch", + "532": "dining 
table, board", + "533": "dishrag, dishcloth", + "534": "dishwasher, dish washer, dishwashing machine", + "535": "disk brake, disc brake", + "536": "dock, dockage, docking facility", + "537": "dogsled, dog sled, dog sleigh", + "538": "dome", + "539": "doormat, welcome mat", + "540": "drilling platform, offshore rig", + "541": "drum, membranophone, tympan", + "542": "drumstick", + "543": "dumbbell", + "544": "Dutch oven", + "545": "electric fan, blower", + "546": "electric guitar", + "547": "electric locomotive", + "548": "entertainment center", + "549": "envelope", + "550": "espresso maker", + "551": "face powder", + "552": "feather boa, boa", + "553": "file, file cabinet, filing cabinet", + "554": "fireboat", + "555": "fire engine, fire truck", + "556": "fire screen, fireguard", + "557": "flagpole, flagstaff", + "558": "flute, transverse flute", + "559": "folding chair", + "560": "football helmet", + "561": "forklift", + "562": "fountain", + "563": "fountain pen", + "564": "four-poster", + "565": "freight car", + "566": "French horn, horn", + "567": "frying pan, frypan, skillet", + "568": "fur coat", + "569": "garbage truck, dustcart", + "570": "gasmask, respirator, gas helmet", + "571": "gas pump, gasoline pump, petrol pump, island dispenser", + "572": "goblet", + "573": "go-kart", + "574": "golf ball", + "575": "golfcart, golf cart", + "576": "gondola", + "577": "gong, tam-tam", + "578": "gown", + "579": "grand piano, grand", + "580": "greenhouse, nursery, glasshouse", + "581": "grille, radiator grille", + "582": "grocery store, grocery, food market, market", + "583": "guillotine", + "584": "hair slide", + "585": "hair spray", + "586": "half track", + "587": "hammer", + "588": "hamper", + "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "590": "hand-held computer, hand-held microcomputer", + "591": "handkerchief, hankie, hanky, hankey", + "592": "hard disc, hard disk, fixed disk", + "593": "harmonica, mouth organ, harp, mouth harp", + "594": "harp", + "595": "harvester, reaper", + "596": "hatchet", + "597": "holster", + "598": "home theater, home theatre", + "599": "honeycomb", + "600": "hook, claw", + "601": "hoopskirt, crinoline", + "602": "horizontal bar, high bar", + "603": "horse cart, horse-cart", + "604": "hourglass", + "605": "iPod", + "606": "iron, smoothing iron", + "607": "jack-o'-lantern", + "608": "jean, blue jean, denim", + "609": "jeep, landrover", + "610": "jersey, T-shirt, tee shirt", + "611": "jigsaw puzzle", + "612": "jinrikisha, ricksha, rickshaw", + "613": "joystick", + "614": "kimono", + "615": "knee pad", + "616": "knot", + "617": "lab coat, laboratory coat", + "618": "ladle", + "619": "lampshade, lamp shade", + "620": "laptop, laptop computer", + "621": "lawn mower, mower", + "622": "lens cap, lens cover", + "623": "letter opener, paper knife, paperknife", + "624": "library", + "625": "lifeboat", + "626": "lighter, light, igniter, ignitor", + "627": "limousine, limo", + "628": "liner, ocean liner", + "629": "lipstick, lip rouge", + "630": "Loafer", + "631": "lotion", + "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "633": "loupe, jeweler's loupe", + "634": "lumbermill, sawmill", + "635": "magnetic compass", + "636": "mailbag, postbag", + "637": "mailbox, letter box", + "638": "maillot", + "639": "maillot, tank suit", + "640": "manhole cover", + "641": "maraca", + "642": "marimba, xylophone", + "643": "mask", + "644": "matchstick", + "645": "maypole", + "646": "maze, labyrinth", + "647": "measuring cup", + 
"648": "medicine chest, medicine cabinet", + "649": "megalith, megalithic structure", + "650": "microphone, mike", + "651": "microwave, microwave oven", + "652": "military uniform", + "653": "milk can", + "654": "minibus", + "655": "miniskirt, mini", + "656": "minivan", + "657": "missile", + "658": "mitten", + "659": "mixing bowl", + "660": "mobile home, manufactured home", + "661": "Model T", + "662": "modem", + "663": "monastery", + "664": "monitor", + "665": "moped", + "666": "mortar", + "667": "mortarboard", + "668": "mosque", + "669": "mosquito net", + "670": "motor scooter, scooter", + "671": "mountain bike, all-terrain bike, off-roader", + "672": "mountain tent", + "673": "mouse, computer mouse", + "674": "mousetrap", + "675": "moving van", + "676": "muzzle", + "677": "nail", + "678": "neck brace", + "679": "necklace", + "680": "nipple", + "681": "notebook, notebook computer", + "682": "obelisk", + "683": "oboe, hautboy, hautbois", + "684": "ocarina, sweet potato", + "685": "odometer, hodometer, mileometer, milometer", + "686": "oil filter", + "687": "organ, pipe organ", + "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "689": "overskirt", + "690": "oxcart", + "691": "oxygen mask", + "692": "packet", + "693": "paddle, boat paddle", + "694": "paddlewheel, paddle wheel", + "695": "padlock", + "696": "paintbrush", + "697": "pajama, pyjama, pj's, jammies", + "698": "palace", + "699": "panpipe, pandean pipe, syrinx", + "700": "paper towel", + "701": "parachute, chute", + "702": "parallel bars, bars", + "703": "park bench", + "704": "parking meter", + "705": "passenger car, coach, carriage", + "706": "patio, terrace", + "707": "pay-phone, pay-station", + "708": "pedestal, plinth, footstall", + "709": "pencil box, pencil case", + "710": "pencil sharpener", + "711": "perfume, essence", + "712": "Petri dish", + "713": "photocopier", + "714": "pick, plectrum, plectron", + "715": "pickelhaube", + "716": "picket fence, paling", + "717": "pickup, pickup truck", + "718": "pier", + "719": "piggy bank, penny bank", + "720": "pill bottle", + "721": "pillow", + "722": "ping-pong ball", + "723": "pinwheel", + "724": "pirate, pirate ship", + "725": "pitcher, ewer", + "726": "plane, carpenter's plane, woodworking plane", + "727": "planetarium", + "728": "plastic bag", + "729": "plate rack", + "730": "plow, plough", + "731": "plunger, plumber's helper", + "732": "Polaroid camera, Polaroid Land camera", + "733": "pole", + "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "735": "poncho", + "736": "pool table, billiard table, snooker table", + "737": "pop bottle, soda bottle", + "738": "pot, flowerpot", + "739": "potter's wheel", + "740": "power drill", + "741": "prayer rug, prayer mat", + "742": "printer", + "743": "prison, prison house", + "744": "projectile, missile", + "745": "projector", + "746": "puck, hockey puck", + "747": "punching bag, punch bag, punching ball, punchball", + "748": "purse", + "749": "quill, quill pen", + "750": "quilt, comforter, comfort, puff", + "751": "racer, race car, racing car", + "752": "racket, racquet", + "753": "radiator", + "754": "radio, wireless", + "755": "radio telescope, radio reflector", + "756": "rain barrel", + "757": "recreational vehicle, RV, R.V.", + "758": "reel", + "759": "reflex camera", + "760": "refrigerator, icebox", + "761": "remote control, remote", + "762": "restaurant, eating house, eating place, eatery", + "763": "revolver, six-gun, six-shooter", + "764": "rifle", + "765": "rocking chair, rocker", + 
"766": "rotisserie", + "767": "rubber eraser, rubber, pencil eraser", + "768": "rugby ball", + "769": "rule, ruler", + "770": "running shoe", + "771": "safe", + "772": "safety pin", + "773": "saltshaker, salt shaker", + "774": "sandal", + "775": "sarong", + "776": "sax, saxophone", + "777": "scabbard", + "778": "scale, weighing machine", + "779": "school bus", + "780": "schooner", + "781": "scoreboard", + "782": "screen, CRT screen", + "783": "screw", + "784": "screwdriver", + "785": "seat belt, seatbelt", + "786": "sewing machine", + "787": "shield, buckler", + "788": "shoe shop, shoe-shop, shoe store", + "789": "shoji", + "790": "shopping basket", + "791": "shopping cart", + "792": "shovel", + "793": "shower cap", + "794": "shower curtain", + "795": "ski", + "796": "ski mask", + "797": "sleeping bag", + "798": "slide rule, slipstick", + "799": "sliding door", + "800": "slot, one-armed bandit", + "801": "snorkel", + "802": "snowmobile", + "803": "snowplow, snowplough", + "804": "soap dispenser", + "805": "soccer ball", + "806": "sock", + "807": "solar dish, solar collector, solar furnace", + "808": "sombrero", + "809": "soup bowl", + "810": "space bar", + "811": "space heater", + "812": "space shuttle", + "813": "spatula", + "814": "speedboat", + "815": "spider web, spider's web", + "816": "spindle", + "817": "sports car, sport car", + "818": "spotlight, spot", + "819": "stage", + "820": "steam locomotive", + "821": "steel arch bridge", + "822": "steel drum", + "823": "stethoscope", + "824": "stole", + "825": "stone wall", + "826": "stopwatch, stop watch", + "827": "stove", + "828": "strainer", + "829": "streetcar, tram, tramcar, trolley, trolley car", + "830": "stretcher", + "831": "studio couch, day bed", + "832": "stupa, tope", + "833": "submarine, pigboat, sub, U-boat", + "834": "suit, suit of clothes", + "835": "sundial", + "836": "sunglass", + "837": "sunglasses, dark glasses, shades", + "838": "sunscreen, sunblock, sun blocker", + "839": "suspension bridge", + "840": "swab, swob, mop", + "841": "sweatshirt", + "842": "swimming trunks, bathing trunks", + "843": "swing", + "844": "switch, electric switch, electrical switch", + "845": "syringe", + "846": "table lamp", + "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", + "848": "tape player", + "849": "teapot", + "850": "teddy, teddy bear", + "851": "television, television system", + "852": "tennis ball", + "853": "thatch, thatched roof", + "854": "theater curtain, theatre curtain", + "855": "thimble", + "856": "thresher, thrasher, threshing machine", + "857": "throne", + "858": "tile roof", + "859": "toaster", + "860": "tobacco shop, tobacconist shop, tobacconist", + "861": "toilet seat", + "862": "torch", + "863": "totem pole", + "864": "tow truck, tow car, wrecker", + "865": "toyshop", + "866": "tractor", + "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "868": "tray", + "869": "trench coat", + "870": "tricycle, trike, velocipede", + "871": "trimaran", + "872": "tripod", + "873": "triumphal arch", + "874": "trolleybus, trolley coach, trackless trolley", + "875": "trombone", + "876": "tub, vat", + "877": "turnstile", + "878": "typewriter keyboard", + "879": "umbrella", + "880": "unicycle, monocycle", + "881": "upright, upright piano", + "882": "vacuum, vacuum cleaner", + "883": "vase", + "884": "vault", + "885": "velvet", + "886": "vending machine", + "887": "vestment", + "888": "viaduct", + "889": "violin, fiddle", + "890": "volleyball", + "891": "waffle iron", + 
"892": "wall clock", + "893": "wallet, billfold, notecase, pocketbook", + "894": "wardrobe, closet, press", + "895": "warplane, military plane", + "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "897": "washer, automatic washer, washing machine", + "898": "water bottle", + "899": "water jug", + "900": "water tower", + "901": "whiskey jug", + "902": "whistle", + "903": "wig", + "904": "window screen", + "905": "window shade", + "906": "Windsor tie", + "907": "wine bottle", + "908": "wing", + "909": "wok", + "910": "wooden spoon", + "911": "wool, woolen, woollen", + "912": "worm fence, snake fence, snake-rail fence, Virginia fence", + "913": "wreck", + "914": "yawl", + "915": "yurt", + "916": "web site, website, internet site, site", + "917": "comic book", + "918": "crossword puzzle, crossword", + "919": "street sign", + "920": "traffic light, traffic signal, stoplight", + "921": "book jacket, dust cover, dust jacket, dust wrapper", + "922": "menu", + "923": "plate", + "924": "guacamole", + "925": "consomme", + "926": "hot pot, hotpot", + "927": "trifle", + "928": "ice cream, icecream", + "929": "ice lolly, lolly, lollipop, popsicle", + "930": "French loaf", + "931": "bagel, beigel", + "932": "pretzel", + "933": "cheeseburger", + "934": "hotdog, hot dog, red hot", + "935": "mashed potato", + "936": "head cabbage", + "937": "broccoli", + "938": "cauliflower", + "939": "zucchini, courgette", + "940": "spaghetti squash", + "941": "acorn squash", + "942": "butternut squash", + "943": "cucumber, cuke", + "944": "artichoke, globe artichoke", + "945": "bell pepper", + "946": "cardoon", + "947": "mushroom", + "948": "Granny Smith", + "949": "strawberry", + "950": "orange", + "951": "lemon", + "952": "fig", + "953": "pineapple, ananas", + "954": "banana", + "955": "jackfruit, jak, jack", + "956": "custard apple", + "957": "pomegranate", + "958": "hay", + "959": "carbonara", + "960": "chocolate sauce, chocolate syrup", + "961": "dough", + "962": "meat loaf, meatloaf", + "963": "pizza, pizza pie", + "964": "potpie", + "965": "burrito", + "966": "red wine", + "967": "espresso", + "968": "cup", + "969": "eggnog", + "970": "alp", + "971": "bubble", + "972": "cliff, drop, drop-off", + "973": "coral reef", + "974": "geyser", + "975": "lakeside, lakeshore", + "976": "promontory, headland, head, foreland", + "977": "sandbar, sand bar", + "978": "seashore, coast, seacoast, sea-coast", + "979": "valley, vale", + "980": "volcano", + "981": "ballplayer, baseball player", + "982": "groom, bridegroom", + "983": "scuba diver", + "984": "rapeseed", + "985": "daisy", + "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "987": "corn", + "988": "acorn", + "989": "hip, rose hip, rosehip", + "990": "buckeye, horse chestnut, conker", + "991": "coral fungus", + "992": "agaric", + "993": "gyromitra", + "994": "stinkhorn, carrion fungus", + "995": "earthstar", + "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "997": "bolete", + "998": "ear, spike, capitulum", + "999": "toilet tissue, toilet paper, bathroom tissue" + }, + "dim": 196, + "drop_path_rate": 0.3, + "in_dim": 64, + "layer_scale": 1e-05, + "layer_scale_conv": null, + "mlp_ratio": 4, + "model_type": "mambavision", + "num_heads": [ + 4, + 8, + 16, + 32 + ], + "torch_dtype": "float32", + "transformers_version": "4.36.2", + "window_size": [ + 8, + 8, + 14, + 7 + ] +} diff --git a/configuration_mambavision.py b/configuration_mambavision.py new file mode 100644 
index 0000000..fafbfc9
--- /dev/null
+++ b/configuration_mambavision.py
@@ -0,0 +1,28 @@
+from transformers import PretrainedConfig
+
+class MambaVisionConfig(PretrainedConfig):
+    model_type = "mambavision"
+
+    def __init__(
+        self,
+        depths=[3, 3, 12, 5],
+        num_heads=[4, 8, 16, 32],
+        window_size=[8, 8, 14, 7],
+        dim=196,
+        in_dim=64,
+        mlp_ratio=4,
+        drop_path_rate=0.3,
+        layer_scale=1e-5,
+        layer_scale_conv=None,
+        **kwargs,
+    ):
+        self.depths = depths
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.dim = dim
+        self.in_dim = in_dim
+        self.mlp_ratio = mlp_ratio
+        self.drop_path_rate = drop_path_rate
+        self.layer_scale = layer_scale
+        self.layer_scale_conv = layer_scale_conv
+        super().__init__(**kwargs)
\ No newline at end of file
diff --git a/mamba_vision.py b/mamba_vision.py
new file mode 100644
index 0000000..9dcd83c
--- /dev/null
+++ b/mamba_vision.py
@@ -0,0 +1,865 @@
+#!/usr/bin/env python3

+
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import math
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from huggingface_hub import PyTorchModelHubMixin
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
+from timm.models.vision_transformer import Mlp
+from timm.models._builder import resolve_pretrained_cfg
+try:
+    from timm.models._builder import _update_default_kwargs as update_args
+except ImportError:
+    # newer timm versions renamed this private helper
+    from timm.models._builder import _update_default_model_kwargs as update_args
+
+
+def _cfg(url='', **kwargs):
+    return {'url': url,
+            'num_classes': 1000,
+            'input_size': (3, 224, 224),
+            'pool_size': None,
+            'crop_pct': 0.875,
+            'interpolation': 'bicubic',
+            'fixed_input_size': True,
+            'mean': (0.485, 0.456, 0.406),
+            'std': (0.229, 0.224, 0.225),
+            **kwargs
+            }
+
+
+default_cfgs = {
+    'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
+                           crop_pct=1.0,
+                           input_size=(3, 224, 224),
+                           crop_mode='center'),
+    'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
+                            crop_pct=0.98,
+                            input_size=(3, 224, 224),
+                            crop_mode='center'),
+    'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
+                           crop_pct=0.93,
+                           input_size=(3, 224, 224),
+                           crop_mode='center'),
+    'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
+                           crop_pct=1.0,
+                           input_size=(3, 224, 224),
+                           crop_mode='center'),
+    'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
+                           crop_pct=1.0,
+                           input_size=(3, 224, 224),
+                           crop_mode='center'),
+    'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
+                            crop_pct=1.0,
+                            input_size=(3, 224, 224),
+                            crop_mode='center')
+}
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, C, H, W)
+        window_size: window size
+    Returns:
+        local window features (num_windows*B, window_size*window_size, C)
+    """
+    B, C, H, W = x.shape
+    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
+    windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: local window features (num_windows*B, window_size*window_size, C)
+        window_size: window size
+        H: height of image
+        W: width of image
+    Returns:
+        x: (B, C, H, W)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
+    return x
+
+
+def _load_state_dict(module, state_dict, strict=False, logger=None):
+    """Load state_dict to a module.
+
+    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+    Default value for ``strict`` is set to ``False`` and the message for
+    param mismatch will be shown even if strict is False.
+
+    Args:
+        module (Module): Module that receives the state_dict.
+        state_dict (OrderedDict): Weights.
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+        logger (:obj:`logging.Logger`, optional): Logger to log the error
+            message. If not specified, print function will be used.
+    """
+    unexpected_keys = []
+    all_missing_keys = []
+    err_msg = []
+
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    def load(module, prefix=''):
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                     all_missing_keys, unexpected_keys,
+                                     err_msg)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(module)
+    load = None  # break load->load reference cycle
+    missing_keys = [
+        key for key in all_missing_keys if 'num_batches_tracked' not in key
+    ]
+
+    if unexpected_keys:
+        err_msg.append('unexpected key in source '
+                       f'state_dict: {", ".join(unexpected_keys)}\n')
+    if missing_keys:
+        err_msg.append(
+            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+    if len(err_msg) > 0:
+        err_msg.insert(
+            0, 'The model and loaded state dict do not match exactly\n')
+        err_msg = '\n'.join(err_msg)
+        if strict:
+            raise RuntimeError(err_msg)
+        elif logger is not None:
+            logger.warning(err_msg)
+        else:
+            print(err_msg)
+
+
+def _load_checkpoint(model,
+                     filename,
+                     map_location='cpu',
+                     strict=False,
+                     logger=None):
+    """Load checkpoint from a file or URI.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to allow different params for the model and
+            checkpoint.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
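+
+    Example (a hedged sketch; reuses the default checkpoint path from
+    ``mamba_vision_T`` below)::
+
+        >>> ckpt = _load_checkpoint(model, '/tmp/mamba_vision_T.pth.tar')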
+ """ + checkpoint = torch.load(filename, map_location=map_location) + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + _load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +class Downsample(nn.Module): + """ + Down-sampling block" + """ + + def __init__(self, + dim, + keep_dim=False, + ): + """ + Args: + dim: feature size dimension. + norm_layer: normalization layer. + keep_dim: bool argument for maintaining the resolution. + """ + + super().__init__() + if keep_dim: + dim_out = dim + else: + dim_out = 2 * dim + self.reduction = nn.Sequential( + nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False), + ) + + def forward(self, x): + x = self.reduction(x) + return x + + +class PatchEmbed(nn.Module): + """ + Patch embedding block" + """ + + def __init__(self, in_chans=3, in_dim=64, dim=96): + """ + Args: + in_chans: number of input channels. + dim: feature size dimension. + """ + # in_dim = 1 + super().__init__() + self.proj = nn.Identity() + self.conv_down = nn.Sequential( + nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False), + nn.BatchNorm2d(in_dim, eps=1e-4), + nn.ReLU(), + nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False), + nn.BatchNorm2d(dim, eps=1e-4), + nn.ReLU() + ) + + def forward(self, x): + x = self.proj(x) + x = self.conv_down(x) + return x + + +class ConvBlock(nn.Module): + + def __init__(self, dim, + drop_path=0., + layer_scale=None, + kernel_size=3): + super().__init__() + + self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1) + self.norm1 = nn.BatchNorm2d(dim, eps=1e-5) + self.act1 = nn.GELU(approximate= 'tanh') + self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1) + self.norm2 = nn.BatchNorm2d(dim, eps=1e-5) + self.layer_scale = layer_scale + if layer_scale is not None and type(layer_scale) in [int, float]: + self.gamma = nn.Parameter(layer_scale * torch.ones(dim)) + self.layer_scale = True + else: + self.layer_scale = False + self.drop_path = DropPath(drop_path) if drop_path > 0. 
+
+    def forward(self, x):
+        input = x
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.act1(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        if self.layer_scale:
+            x = x * self.gamma.view(1, -1, 1, 1)
+        x = input + self.drop_path(x)
+        return x
+
+
+class MambaVisionMixer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        d_state=16,
+        d_conv=4,
+        expand=2,
+        dt_rank="auto",
+        dt_min=0.001,
+        dt_max=0.1,
+        dt_init="random",
+        dt_scale=1.0,
+        dt_init_floor=1e-4,
+        conv_bias=True,
+        bias=False,
+        use_fast_path=True,
+        layer_idx=None,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.d_model = d_model
+        self.d_state = d_state
+        self.d_conv = d_conv
+        self.expand = expand
+        self.d_inner = int(self.expand * self.d_model)
+        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+        self.use_fast_path = use_fast_path
+        self.layer_idx = layer_idx
+        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
+        self.x_proj = nn.Linear(
+            self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
+        )
+        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
+        dt_init_std = self.dt_rank**-0.5 * dt_scale
+        if dt_init == "constant":
+            nn.init.constant_(self.dt_proj.weight, dt_init_std)
+        elif dt_init == "random":
+            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+        else:
+            raise NotImplementedError
+        dt = torch.exp(
+            torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
+        # inverse of softplus, so that softplus(dt_proj.bias) recovers dt
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        with torch.no_grad():
+            self.dt_proj.bias.copy_(inv_dt)
+        self.dt_proj.bias._no_reinit = True
+        A = repeat(
+            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
+            "n -> d n",
+            d=self.d_inner//2,
+        ).contiguous()
+        A_log = torch.log(A)
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+        self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
+        self.D._no_weight_decay = True
+        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
+        self.conv1d_x = nn.Conv1d(
+            in_channels=self.d_inner//2,
+            out_channels=self.d_inner//2,
+            bias=conv_bias//2,  # note: True//2 == 0, so the default conv_bias=True still disables the bias
+            kernel_size=d_conv,
+            groups=self.d_inner//2,
+            **factory_kwargs,
+        )
+        self.conv1d_z = nn.Conv1d(
+            in_channels=self.d_inner//2,
+            out_channels=self.d_inner//2,
+            bias=conv_bias//2,
+            kernel_size=d_conv,
+            groups=self.d_inner//2,
+            **factory_kwargs,
+        )
+
+    def forward(self, hidden_states):
+        """
+        hidden_states: (B, L, D)
+        Returns: same shape as hidden_states
+        """
+        _, seqlen, _ = hidden_states.shape
+        xz = self.in_proj(hidden_states)
+        xz = rearrange(xz, "b l d -> b d l")
+        x, z = xz.chunk(2, dim=1)
+        A = -torch.exp(self.A_log.float())
+        x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
+        z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
+        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
+        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
+        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+        y = selective_scan_fn(x,
+                              dt,
+                              A,
+                              B,
+                              C,
+                              self.D.float(),
+                              z=None,
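+                              # shapes here: x and dt are (B, d_inner/2, L);
+                              # A is (d_inner/2, d_state); B and C are (B, d_state, L).
+                              # z is None because gating happens after the scan,
+                              # via the separately convolved z branch below.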
+                              delta_bias=self.dt_proj.bias.float(),
+                              delta_softplus=True,
+                              return_last_state=None)
+        y = torch.cat([y, z], dim=1)
+        y = rearrange(y, "b d l -> b l d")
+        out = self.out_proj(y)
+        return out
+
+
+class Attention(nn.Module):
+
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_norm=False,
+        attn_drop=0.,
+        proj_drop=0.,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+        assert dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = True
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 counter,
+                 transformer_blocks,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=False,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 Mlp_block=Mlp,
+                 layer_scale=None,
+                 ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        if counter in transformer_blocks:
+            self.mixer = Attention(
+                dim,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_scale,
+                attn_drop=attn_drop,
+                proj_drop=drop,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.mixer = MambaVisionMixer(d_model=dim,
+                                          d_state=8,
+                                          d_conv=3,
+                                          expand=1
+                                          )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
+        self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
+        self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
+
+    def forward(self, x):
+        x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
+        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class MambaVisionLayer(nn.Module):
+    """
+    MambaVision layer.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size,
+                 conv=False,
+                 downsample=True,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 layer_scale=None,
+                 layer_scale_conv=None,
+                 transformer_blocks=[],
+                 ):
+        """
+        Args:
+            dim: feature size dimension.
+            depth: number of layers in each stage.
+            window_size: window size in each stage.
+            conv: bool argument for conv stage flag.
+            downsample: bool argument for down-sampling.
+            mlp_ratio: MLP ratio.
+            num_heads: number of heads in each stage.
+            qkv_bias: bool argument for query, key, value learnable bias.
+            qk_scale: bool argument to scaling query, key.
+            drop: dropout rate.
+            attn_drop: attention dropout rate.
+            drop_path: drop path rate.
+            layer_scale: layer scaling coefficient.
+            layer_scale_conv: conv layer scaling coefficient.
+            transformer_blocks: list of transformer blocks.
+        """
+
+        super().__init__()
+        self.conv = conv
+        if conv:
+            self.transformer_block = False
+            self.blocks = nn.ModuleList([ConvBlock(dim=dim,
+                                                   drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                                   layer_scale=layer_scale_conv)
+                                         for i in range(depth)])
+        else:
+            self.transformer_block = True
+            self.blocks = nn.ModuleList([Block(dim=dim,
+                                               counter=i,
+                                               transformer_blocks=transformer_blocks,
+                                               num_heads=num_heads,
+                                               mlp_ratio=mlp_ratio,
+                                               qkv_bias=qkv_bias,
+                                               qk_scale=qk_scale,
+                                               drop=drop,
+                                               attn_drop=attn_drop,
+                                               drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                               layer_scale=layer_scale)
+                                         for i in range(depth)])
+
+        self.downsample = None if not downsample else Downsample(dim=dim)
+        self.do_gt = False
+        self.window_size = window_size
+
+    def forward(self, x):
+        _, _, H, W = x.shape
+
+        if self.transformer_block:
+            pad_r = (self.window_size - W % self.window_size) % self.window_size
+            pad_b = (self.window_size - H % self.window_size) % self.window_size
+            if pad_r > 0 or pad_b > 0:
+                x = torch.nn.functional.pad(x, (0, pad_r, 0, pad_b))
+                _, _, Hp, Wp = x.shape
+            else:
+                Hp, Wp = H, W
+            x = window_partition(x, self.window_size)
+
+        for blk in self.blocks:
+            x = blk(x)
+        if self.transformer_block:
+            x = window_reverse(x, self.window_size, Hp, Wp)
+            if pad_r > 0 or pad_b > 0:
+                x = x[:, :, :H, :W].contiguous()
+        if self.downsample is None:
+            return x
+        return self.downsample(x)
+
+
+class MambaVision(nn.Module, PyTorchModelHubMixin):
+    """
+    MambaVision.
+    """
+
+    def __init__(self,
+                 dim,
+                 in_dim,
+                 depths,
+                 window_size,
+                 mlp_ratio,
+                 num_heads,
+                 drop_path_rate=0.2,
+                 in_chans=3,
+                 num_classes=1000,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 layer_scale=None,
+                 layer_scale_conv=None,
+                 **kwargs):
+        """
+        Args:
+            dim: feature size dimension.
+            in_dim: intermediate feature size dimension of the patch embedding.
+            depths: number of layers in each stage.
+            window_size: window size in each stage.
+            mlp_ratio: MLP ratio.
+            num_heads: number of heads in each stage.
+            drop_path_rate: drop path rate.
+            in_chans: number of input channels.
+            num_classes: number of classes.
+            qkv_bias: bool argument for query, key, value learnable bias.
+            qk_scale: bool argument to scaling query, key.
+            drop_rate: dropout rate.
+            attn_drop_rate: attention dropout rate.
+            layer_scale: layer scaling coefficient.
+            layer_scale_conv: conv layer scaling coefficient.
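+
+        Example (a hedged sketch using the tiny configuration from
+        ``mamba_vision_T`` below)::
+
+            >>> model = MambaVision(dim=80, in_dim=32, depths=[1, 3, 8, 4],
+            ...                     window_size=[8, 8, 14, 7], mlp_ratio=4,
+            ...                     num_heads=[2, 4, 8, 16])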
+ """ + super().__init__() + num_features = int(dim * 2 ** (len(depths) - 1)) + self.num_classes = num_classes + self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.levels = nn.ModuleList() + for i in range(len(depths)): + conv = True if (i == 0 or i == 1) else False + level = MambaVisionLayer(dim=int(dim * 2 ** i), + depth=depths[i], + num_heads=num_heads[i], + window_size=window_size[i], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + conv=conv, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=(i < 3), + layer_scale=layer_scale, + layer_scale_conv=layer_scale_conv, + transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])), + ) + self.levels.append(level) + self.norm = nn.BatchNorm2d(num_features) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, LayerNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'rpb'} + + def forward_features(self, x): + x = self.patch_embed(x) + for level in self.levels: + x = level(x) + x = self.norm(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def _load_state_dict(self, + pretrained, + strict: bool = False): + _load_checkpoint(self, + pretrained, + strict=strict) + + +@register_model +def mamba_vision_T(pretrained=False, **kwargs): + model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar") + pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict() + update_args(pretrained_cfg, kwargs, kwargs_filter=None) + model = MambaVision(depths=[1, 3, 8, 4], + num_heads=[2, 4, 8, 16], + window_size=[8, 8, 14, 7], + dim=80, + in_dim=32, + mlp_ratio=4, + resolution=224, + drop_path_rate=0.2, + **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg + if pretrained: + if not Path(model_path).is_file(): + url = model.default_cfg['url'] + torch.hub.download_url_to_file(url=url, dst=model_path) + model._load_state_dict(model_path) + return model + + +@register_model +def mamba_vision_T2(pretrained=False, **kwargs): + model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar") + pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict() + update_args(pretrained_cfg, kwargs, kwargs_filter=None) + model = MambaVision(depths=[1, 3, 11, 4], + num_heads=[2, 4, 8, 16], + window_size=[8, 8, 14, 7], + dim=80, + in_dim=32, + mlp_ratio=4, + resolution=224, + drop_path_rate=0.2, + **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg + if pretrained: + if not Path(model_path).is_file(): + url = model.default_cfg['url'] + torch.hub.download_url_to_file(url=url, dst=model_path) + model._load_state_dict(model_path) + return model + + 
+@register_model +def mamba_vision_S(pretrained=False, **kwargs): + model_path = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar") + pretrained_cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict() + update_args(pretrained_cfg, kwargs, kwargs_filter=None) + model = MambaVision(depths=[3, 3, 7, 5], + num_heads=[2, 4, 8, 16], + window_size=[8, 8, 14, 7], + dim=96, + in_dim=64, + mlp_ratio=4, + resolution=224, + drop_path_rate=0.2, + **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg + if pretrained: + if not Path(model_path).is_file(): + url = model.default_cfg['url'] + torch.hub.download_url_to_file(url=url, dst=model_path) + model._load_state_dict(model_path) + return model + + +@register_model +def mamba_vision_B(pretrained=False, **kwargs): + model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar") + pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict() + update_args(pretrained_cfg, kwargs, kwargs_filter=None) + model = MambaVision(depths=[3, 3, 10, 5], + num_heads=[2, 4, 8, 16], + window_size=[8, 8, 14, 7], + dim=128, + in_dim=64, + mlp_ratio=4, + resolution=224, + drop_path_rate=0.3, + layer_scale=1e-5, + layer_scale_conv=None, + **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg + if pretrained: + if not Path(model_path).is_file(): + url = model.default_cfg['url'] + torch.hub.download_url_to_file(url=url, dst=model_path) + model._load_state_dict(model_path) + return model + + +@register_model +def mamba_vision_L(pretrained=False, **kwargs): + model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar") + pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict() + update_args(pretrained_cfg, kwargs, kwargs_filter=None) + model = MambaVision(depths=[3, 3, 10, 5], + num_heads=[4, 8, 16, 32], + window_size=[8, 8, 14, 7], + dim=196, + in_dim=64, + mlp_ratio=4, + resolution=224, + drop_path_rate=0.3, + layer_scale=1e-5, + layer_scale_conv=None, + **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg + if pretrained: + if not Path(model_path).is_file(): + url = model.default_cfg['url'] + torch.hub.download_url_to_file(url=url, dst=model_path) + model._load_state_dict(model_path) + return model + + +@register_model +def mamba_vision_L2(pretrained=False, **kwargs): + model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar") + pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict() + update_args(pretrained_cfg, kwargs, kwargs_filter=None) + model = MambaVision(depths=[3, 3, 12, 5], + num_heads=[4, 8, 16, 32], + window_size=[8, 8, 14, 7], + dim=196, + in_dim=64, + mlp_ratio=4, + resolution=224, + drop_path_rate=0.3, + layer_scale=1e-5, + layer_scale_conv=None, + **kwargs) + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg + if pretrained: + if not Path(model_path).is_file(): + url = model.default_cfg['url'] + torch.hub.download_url_to_file(url=url, dst=model_path) + model._load_state_dict(model_path) + return model \ No newline at end of file diff --git a/mambavision_large2_1k.pth.tar b/mambavision_large2_1k.pth.tar new file mode 100644 index 0000000..e6ff99d --- /dev/null +++ b/mambavision_large2_1k.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:498582771558c79263cc1b61a86acd8c0e7767738d6a7a2d38fc1b3af149b52e +size 2899063127 diff --git a/model.safetensors b/model.safetensors new file mode 100644 index 0000000..2b674c4 --- /dev/null 
+++ b/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56a23220d4f17a3977fe4032cd564b6d54e29f5fe263530cdb3c94af2ee3f220 +size 966304200 diff --git a/modeling_mambavision.py b/modeling_mambavision.py new file mode 100644 index 0000000..8d1a501 --- /dev/null +++ b/modeling_mambavision.py @@ -0,0 +1,764 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + + +import torch +import torch.nn as nn +from timm.models.registry import register_model +import math +from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d +from timm.models._builder import resolve_pretrained_cfg +try: + from timm.models._builder import _update_default_kwargs as update_args +except: + from timm.models._builder import _update_default_model_kwargs as update_args +from timm.models.vision_transformer import Mlp, PatchEmbed +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.registry import register_model +import torch.nn.functional as F +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from einops import rearrange, repeat + +from transformers import PreTrainedModel + +from configuration_mambavision import MambaVisionConfig + + +def _cfg(url='', **kwargs): + return {'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': 0.875, + 'interpolation': 'bicubic', + 'fixed_input_size': True, + 'mean': (0.485, 0.456, 0.406), + 'std': (0.229, 0.224, 0.225), + **kwargs + } + + +default_cfgs = { + 'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar', + crop_pct=1.0, + input_size=(3, 224, 224), + crop_mode='center'), + 'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar', + crop_pct=0.98, + input_size=(3, 224, 224), + crop_mode='center'), + 'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar', + crop_pct=0.93, + input_size=(3, 224, 224), + crop_mode='center'), + 'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar', + crop_pct=1.0, + input_size=(3, 224, 224), + crop_mode='center'), + 'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar', + crop_pct=1.0, + input_size=(3, 224, 224), + crop_mode='center'), + 'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar', + crop_pct=1.0, + input_size=(3, 224, 224), + crop_mode='center') +} + + +def window_partition(x, window_size): + """ + Args: + x: (B, C, H, W) + window_size: window size + h_w: Height of window + w_w: Width of window + Returns: + local window features (num_windows*B, window_size*window_size, C) + """ + B, C, H, W = x.shape + x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) + windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + 
Args: + windows: local window features (num_windows*B, window_size, window_size, C) + window_size: Window size + H: Height of image + W: Width of image + Returns: + x: (B, C, H, W) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W) + return x + + +def _load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + + if len(err_msg) > 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def _load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + checkpoint = torch.load(filename, map_location=map_location) + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + _load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +class Downsample(nn.Module): + """ + Down-sampling block" + """ + + def __init__(self, + dim, + keep_dim=False, + ): + """ + Args: + dim: feature size dimension. + norm_layer: normalization layer. + keep_dim: bool argument for maintaining the resolution. + """ + + super().__init__() + if keep_dim: + dim_out = dim + else: + dim_out = 2 * dim + self.reduction = nn.Sequential( + nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False), + ) + + def forward(self, x): + x = self.reduction(x) + return x + + +class PatchEmbed(nn.Module): + """ + Patch embedding block" + """ + + def __init__(self, in_chans=3, in_dim=64, dim=96): + """ + Args: + in_chans: number of input channels. + dim: feature size dimension. + """ + # in_dim = 1 + super().__init__() + self.proj = nn.Identity() + self.conv_down = nn.Sequential( + nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False), + nn.BatchNorm2d(in_dim, eps=1e-4), + nn.ReLU(), + nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False), + nn.BatchNorm2d(dim, eps=1e-4), + nn.ReLU() + ) + + def forward(self, x): + x = self.proj(x) + x = self.conv_down(x) + return x + + +class ConvBlock(nn.Module): + + def __init__(self, dim, + drop_path=0., + layer_scale=None, + kernel_size=3): + super().__init__() + + self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1) + self.norm1 = nn.BatchNorm2d(dim, eps=1e-5) + self.act1 = nn.GELU(approximate= 'tanh') + self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1) + self.norm2 = nn.BatchNorm2d(dim, eps=1e-5) + self.layer_scale = layer_scale + if layer_scale is not None and type(layer_scale) in [int, float]: + self.g = nn.Parameter(layer_scale * torch.ones(dim)) + self.layer_scale = True + else: + self.layer_scale = False + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + input = x + x = self.conv1(x) + x = self.norm1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.norm2(x) + if self.layer_scale: + x = x * self.g.view(1, -1, 1, 1) + x = input + self.drop_path(x) + return x + + +class MambaVisionMixer(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, + layer_idx=None, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs) + self.x_proj = nn.Linear( + self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs) + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + dt = torch.exp( + torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + self.dt_proj.bias._no_reinit = True + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner//2, + ).contiguous() + A_log = torch.log(A) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device)) + self.D._no_weight_decay = True + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.conv1d_x = nn.Conv1d( + in_channels=self.d_inner//2, + out_channels=self.d_inner//2, + bias=conv_bias//2, + kernel_size=d_conv, + groups=self.d_inner//2, + **factory_kwargs, + ) + self.conv1d_z = nn.Conv1d( + in_channels=self.d_inner//2, + out_channels=self.d_inner//2, + bias=conv_bias//2, + kernel_size=d_conv, + groups=self.d_inner//2, + **factory_kwargs, + ) + + def forward(self, hidden_states): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + _, seqlen, _ = hidden_states.shape + xz = self.in_proj(hidden_states) + xz = rearrange(xz, "b l d -> b d l") + x, z = xz.chunk(2, dim=1) + A = -torch.exp(self.A_log.float()) + x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2)) + z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2)) + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + y = selective_scan_fn(x, + dt, + A, + B, + C, + self.D.float(), + z=None, + 
delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=None) + + y = torch.cat([y, z], dim=1) + y = rearrange(y, "b d l -> b l d") + out = self.out_proj(y) + return out + + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0., + proj_drop=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, + dim, + num_heads, + counter, + transformer_blocks, + mlp_ratio=4., + qkv_bias=False, + qk_scale=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + Mlp_block=Mlp, + layer_scale=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + if counter in transformer_blocks: + self.mixer = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + norm_layer=norm_layer, + ) + else: + self.mixer = MambaVisionMixer(d_model=dim, + d_state=8, + d_conv=3, + expand=1 + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float] + self.g_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1 + self.g_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1 + + def forward(self, x): + x = x + self.drop_path(self.g_1 * self.mixer(self.norm1(x))) + x = x + self.drop_path(self.g_2 * self.mlp(self.norm2(x))) + return x + + +class MambaVisionLayer(nn.Module): + """ + MambaVision layer" + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size, + conv=False, + downsample=True, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + layer_scale=None, + layer_scale_conv=None, + transformer_blocks = [], + ): + """ + Args: + dim: feature size dimension. + depth: number of layers in each stage. + window_size: window size in each stage. + conv: bool argument for conv stage flag. + downsample: bool argument for down-sampling. + mlp_ratio: MLP ratio. + num_heads: number of heads in each stage. + qkv_bias: bool argument for query, key, value learnable bias. + qk_scale: bool argument to scaling query, key. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: drop path rate. 
+            layer_scale: layer scaling coefficient.
+            layer_scale_conv: conv layer scaling coefficient.
+            transformer_blocks: list of transformer blocks.
+        """
+
+        super().__init__()
+        self.conv = conv
+        if conv:
+            self.transformer_block = False
+            self.blocks = nn.ModuleList([ConvBlock(dim=dim,
+                                                   drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                                   layer_scale=layer_scale_conv)
+                                         for i in range(depth)])
+        else:
+            self.transformer_block = True
+            self.blocks = nn.ModuleList([Block(dim=dim,
+                                               counter=i,
+                                               transformer_blocks=transformer_blocks,
+                                               num_heads=num_heads,
+                                               mlp_ratio=mlp_ratio,
+                                               qkv_bias=qkv_bias,
+                                               qk_scale=qk_scale,
+                                               drop=drop,
+                                               attn_drop=attn_drop,
+                                               drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                               layer_scale=layer_scale)
+                                         for i in range(depth)])
+
+        self.downsample = None if not downsample else Downsample(dim=dim)
+        self.do_gt = False
+        self.window_size = window_size
+
+    def forward(self, x):
+        _, _, H, W = x.shape
+
+        if self.transformer_block:
+            pad_r = (self.window_size - W % self.window_size) % self.window_size
+            pad_b = (self.window_size - H % self.window_size) % self.window_size
+            if pad_r > 0 or pad_b > 0:
+                x = torch.nn.functional.pad(x, (0, pad_r, 0, pad_b))
+                _, _, Hp, Wp = x.shape
+            else:
+                Hp, Wp = H, W
+            x = window_partition(x, self.window_size)
+
+        for blk in self.blocks:
+            x = blk(x)
+        if self.transformer_block:
+            x = window_reverse(x, self.window_size, Hp, Wp)
+            if pad_r > 0 or pad_b > 0:
+                x = x[:, :, :H, :W].contiguous()
+        if self.downsample is None:
+            return x, x
+        return self.downsample(x), x
+
+
+class MambaVision(nn.Module):
+    """
+    MambaVision backbone.
+    """
+
+    def __init__(self,
+                 dim,
+                 in_dim,
+                 depths,
+                 window_size,
+                 mlp_ratio,
+                 num_heads,
+                 drop_path_rate=0.2,
+                 in_chans=3,
+                 num_classes=1000,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 layer_scale=None,
+                 layer_scale_conv=None,
+                 **kwargs):
+        """
+        Args:
+            dim: feature size dimension.
+            in_dim: patch embedding intermediate dimension.
+            depths: number of layers in each stage.
+            window_size: window size in each stage.
+            mlp_ratio: MLP ratio.
+            num_heads: number of heads in each stage.
+            drop_path_rate: drop path rate.
+            in_chans: number of input channels.
+            num_classes: number of classes.
+            qkv_bias: bool argument for query, key, value learnable bias.
+            qk_scale: bool argument to scaling query, key.
+            drop_rate: dropout rate.
+            attn_drop_rate: attention dropout rate.
+            layer_scale: layer scaling coefficient.
+            layer_scale_conv: conv layer scaling coefficient.
+        """
+        super().__init__()
+        num_features = int(dim * 2 ** (len(depths) - 1))
+        self.num_classes = num_classes
+        self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        self.levels = nn.ModuleList()
+        for i in range(len(depths)):
+            conv = i in (0, 1)  # the first two stages are convolutional
+            level = MambaVisionLayer(dim=int(dim * 2 ** i),
+                                     depth=depths[i],
+                                     num_heads=num_heads[i],
+                                     window_size=window_size[i],
+                                     mlp_ratio=mlp_ratio,
+                                     qkv_bias=qkv_bias,
+                                     qk_scale=qk_scale,
+                                     conv=conv,
+                                     drop=drop_rate,
+                                     attn_drop=attn_drop_rate,
+                                     drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+                                     downsample=(i < 3),
+                                     layer_scale=layer_scale,
+                                     layer_scale_conv=layer_scale_conv,
+                                     transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i] % 2 != 0 else list(range(depths[i]//2, depths[i])),
+                                     )
+            self.levels.append(level)
+        self.norm = nn.BatchNorm2d(num_features)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, (nn.LayerNorm, LayerNorm2d)):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+        elif isinstance(m, nn.BatchNorm2d):
+            nn.init.ones_(m.weight)
+            nn.init.zeros_(m.bias)
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'rpb'}
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        outs = []
+        for level in self.levels:
+            x, xo = level(x)
+            outs.append(xo)
+        x = self.norm(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        return x, outs
+
+    def forward(self, x):
+        x, _ = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+    def _load_state_dict(self,
+                         pretrained,
+                         strict: bool = False):
+        _load_checkpoint(self,
+                         pretrained,
+                         strict=strict)
+
+
+class MambaVisionModel(PreTrainedModel):
+    config_class = MambaVisionConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = MambaVision(
+            depths=config.depths,
+            num_heads=config.num_heads,
+            window_size=config.window_size,
+            dim=config.dim,
+            in_dim=config.in_dim,
+            mlp_ratio=config.mlp_ratio,
+            layer_scale=config.layer_scale,
+            layer_scale_conv=config.layer_scale_conv
+        )
+
+    def forward(self, tensor):
+        return self.model.forward_features(tensor)
+
+
+class MambaVisionModelForImageClassification(PreTrainedModel):
+    config_class = MambaVisionConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = MambaVision(
+            depths=config.depths,
+            num_heads=config.num_heads,
+            window_size=config.window_size,
+            dim=config.dim,
+            in_dim=config.in_dim,
+            mlp_ratio=config.mlp_ratio,
+            layer_scale=config.layer_scale,
+            layer_scale_conv=config.layer_scale_conv
+        )
+
+    def forward(self, tensor, labels=None):
+        logits = self.model(tensor)
+        if labels is not None:
+            # torch.nn has no `cross_entropy` function; use the functional API.
+            loss = F.cross_entropy(logits, labels)
+            return {"loss": loss, "logits": logits}
+        return {"logits": logits}
\ No newline at end of file