{
  "id": "pytorch/mps-float64-cast-fallback",
  "signature": "RuntimeError: MPS backend does not support float64. Falling back to float32. Please cast your tensors to float32 explicitly.",
  "signature_zh": "RuntimeError: MPS 后端不支持 float64，回退到 float32。请将张量显式转换为 float32。",
  "regex": "MPS backend does not support float64",
  "domain": "pytorch",
  "category": "type_error",
  "subcategory": null,
  "root_cause": "The Metal Performance Shaders (MPS) backend on Apple Silicon does not support float64 (double precision) tensors; any operation involving float64 tensors triggers a fallback to CPU or an error, causing performance degradation or crashes.",
  "root_cause_type": "generic",
  "root_cause_zh": "Apple Silicon 上的 Metal Performance Shaders (MPS) 后端不支持 float64（双精度）张量；任何涉及 float64 张量的操作会触发回退到 CPU 或导致错误，引起性能下降或崩溃。",
  "versions": [
    {
      "version": "PyTorch 2.0.0",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    },
    {
      "version": "PyTorch 2.1.0",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    },
    {
      "version": "macOS 13 Ventura",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    },
    {
      "version": "macOS 14 Sonoma",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    },
    {
      "version": "Apple M1",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    },
    {
      "version": "Apple M2",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    },
    {
      "version": "Apple M3",
      "introduced": null,
      "deprecated": null,
      "removed": null,
      "behavior_change": null,
      "status": "active"
    }
  ],
  "os_specific": {},
  "dead_ends": [
    {
      "action": "Set environment variable PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable MPS fallback",
      "why_fails": "This flag controls memory watermark, not dtype support. The float64 issue remains and will cause errors.",
      "fail_rate": 0.95,
      "condition": "",
      "sources": []
    },
    {
      "action": "Use torch.set_default_dtype(torch.float64) to force double precision",
      "why_fails": "This makes the problem worse by creating more float64 tensors, increasing fallback frequency and memory usage.",
      "fail_rate": 0.9,
      "condition": "",
      "sources": []
    },
    {
      "action": "Install PyTorch nightly build for better MPS support",
      "why_fails": "Nightly builds may have improvements but still do not support float64 on MPS as of current versions.",
      "fail_rate": 0.85,
      "condition": "",
      "sources": []
    }
  ],
  "workarounds": [
    {
      "action": "Cast all tensors to float32 explicitly before moving to MPS device: tensor = tensor.float().to('mps'). This ensures compatibility with MPS backend.",
      "success_rate": 0.95,
      "how": "Cast all tensors to float32 explicitly before moving to MPS device: tensor = tensor.float().to('mps'). This ensures compatibility with MPS backend.",
      "condition": "",
      "sources": []
    },
    {
      "action": "Set the default dtype to float32 at the start of the script: torch.set_default_dtype(torch.float32). This prevents accidental creation of float64 tensors from Python floats.",
      "success_rate": 0.9,
      "how": "Set the default dtype to float32 at the start of the script: torch.set_default_dtype(torch.float32). This prevents accidental creation of float64 tensors from Python floats.",
      "condition": "",
      "sources": []
    },
    {
      "action": "Use a custom wrapper function that checks dtype and casts if needed: def to_mps(tensor): return tensor.float().to('mps') if tensor.is_floating_point() else tensor.to('mps')",
      "success_rate": 0.85,
      "how": "Use a custom wrapper function that checks dtype and casts if needed: def to_mps(tensor): return tensor.float().to('mps') if tensor.is_floating_point() else tensor.to('mps')",
      "condition": "",
      "sources": []
    }
  ],
  "workarounds_zh": [
    "Cast all tensors to float32 explicitly before moving to MPS device: tensor = tensor.float().to('mps'). This ensures compatibility with MPS backend.",
    "Set the default dtype to float32 at the start of the script: torch.set_default_dtype(torch.float32). This prevents accidental creation of float64 tensors from Python floats.",
    "Use a custom wrapper function that checks dtype and casts if needed: def to_mps(tensor): return tensor.float().to('mps') if tensor.is_floating_point() else tensor.to('mps')"
  ],
  "transition_graph": {
    "leads_to": [],
    "preceded_by": [],
    "frequently_confused_with": []
  },
  "official_doc_url": "https://pytorch.org/docs/stable/notes/mps.html#un-supported-features",
  "official_doc_section": null,
  "error_code": "MPS_ERROR_UNSUPPORTED_DTYPE",
  "verification_tier": "ai_generated",
  "confidence": 0.88,
  "fix_success_rate": 0.9,
  "resolvable": "true",
  "first_seen": "2023-05-01",
  "last_confirmed": "2024-06-01",
  "last_updated": "2024-06-01",
  "evidence_count": 1,
  "tags": [],
  "locale": "en",
  "aliases": []
}